import os, sys, math
"""
Conversion of T3PA data to single-frame TXT files or multi-frame PMF files
(c) 2025 Pavel Hudecek, Advacam

1. (only if acqTime > 0) Measure to the T3PA file
2. Import the T3PA data to frames (incl. detection whether the max ToA fits into a single frame or needs more frames)
3. Filter noisy pixels
4. Save the output files (TXT for a single frame, PMF for more frames)

 Output frames:
    - ToA:  time of arrival [ns] (first hit)
    - FToA: fine ToA (first hit)
    - ToT:  time over threshold, summed over all hits of the pixel in the frame
    - Cnt:  number of hits in the pixel (within the frame)
    - Pixels with no hit have the noHitValue (except the Cnt frame - no-hit is always 0)
    - Noisy pixels are set to noisyPxValue (except the Cnt frame - noisy pixels are always 0)
"""
PIXETDIR="C:\\Advacam\\pixet-nightly"   # Pixet core directory; used only when a measurement runs (acqTime > 0)
# Input/output data directory, relative to the current working directory.
# Built with os.path.join so it is a nested directory on every OS (a literal "\\" is a
# plain file-name character on POSIX).
data_dir = os.path.join("out-files", "t3pa-test")

# Base name of the output frame files: "_ToA", "_FToA", "_ToT", "_Cnt" suffixes are added,
# and the ".txt" extension becomes ".pmf" when more frames are produced.
fNamTxt = os.path.join(data_dir, "t3pa-test.txt")
# T3PA file: written by the acquisition (acqTime > 0) and read as the processing input.
fNamT3pa = os.path.join(data_dir, "t3pa-test.t3pa")

frameXsize = 256        # frame width [pixels]
frameYsize = 256        # frame height [pixels]
frameCntLimit = 10000   # max number of frames; a hit in frame index >= this limit aborts the script
noisyCntLimit = 5000    # a pixel with more hits than this within one frame is treated as noisy

frameTime = 0.2         # duration [s] of one frame when splitting the data into frames
acqTime = 0.0           # acquisition time [s]; 0.0 = no measurement and no Pixet, only file processing

noHitValue = 0          # value of pixels with no hit, in all frames except Cnt
noisyPxValue = 0        # value of noisy pixels, in all frames except Cnt

# ----------------- end of user settings -------------------------------------------------------------------

# Pixet core state:
#    1: core present at start (script runs from the Pixet Python console) - left running
#    0: core not present at start (script runs from a system console) - started and stopped here
#   -1: Pixet disabled because acqTime is 0 (file processing only, no core)
pixetPresent = 0


def exitCore(ee):
    """Print a separator, stop the Pixet core if this script started it, optionally exit.

    ee: when truthy, terminate the script (SystemExit) after the cleanup.
    """
    print("-----------------------------------------------------------------")
    if pixetPresent == 1:
        # The core belongs to the hosting Pixet application; do not stop it.
        print("Pixet core: remain")
    elif pixetPresent == 0:
        # The core was started by this script via pypixet.start(), so shut it down.
        print("Pixet core: exit...")
        pixet.exitPixet()
        print("pypixet.exit()...")
        pypixet.exit()
        print("Done")
    if ee:
        # sys.exit() rather than the site-injected exit(), which programs should not rely on.
        sys.exit()

# Optional measurement: acquire data into fNamT3pa only when acqTime > 0.
if acqTime > 1e-10:
    try:  # Run from the Py console in the Pixet: the `pixet` object already exists there
        devices = pixet.devices()
        pixetPresent = 1
        print("Pixet core present")
    except NameError:  # Run from the system terminal: `pixet` is not defined, start the core here
        lastDir = os.path.abspath(os.getcwd())
        print("lastDir", lastDir)
        print("PIXETDIR", PIXETDIR)
        if PIXETDIR is not None:
            # pypixet is imported from, and started in, the Pixet core directory
            os.chdir(PIXETDIR)
            sys.path.append(PIXETDIR)
        print("cwd:", os.getcwd())
        print("Pixet core starting...")
        import pypixet
        pypixet.start()
        pixet = pypixet.pixet
        devices = pixet.devices()
        os.chdir(lastDir)  # back to the start directory, the data paths are relative to it
        print("Done")
        print("cwd:", os.getcwd())

    print("-----------------------------------------------------------------")
    print("devs:")
    for n, dev in enumerate(devices):
        print(" ", n, dev.fullName(), dev.chipType(), dev.sensorType(0), dev.chipIDSummary(), dev.chipIDs())
        if dev.chipCount() > 1:
            st = "".join(f" {dev.sensorType(ch)}" for ch in range(dev.chipCount()))
            print("     ", st)
    print("-----------------------------------------------------------------")

    if len(devices) == 0:
        # Nothing to measure with; stop (and shut down a self-started core) cleanly.
        print("Error: no device found")
        exitCore(1)

    dev = devices[0]  # the first device is used for the measurement
    print("\nselected:", dev.fullName(), dev.chipType(), dev.sensorType(0), "\n")
    print("-----------------------------------------------------------------")

    print("pixetType():      ", pixet.pixetType())
    print("pixetVersion():   ", pixet.pixetVersion())
    print("pixetAPIVersion():", pixet.pixetAPIVersion())
    pars = dev.parameters()
    rc = pars.get("HwLibVer").getString()
    print("HwLibVer:          ", rc)

    # Explicit checks instead of assert: asserts vanish under -O, and an assertion
    # failure would leave a core started by this script running.
    rc = dev.loadFactoryConfig()
    if rc != 0:
        print(f"loadFactoryConfig rc:{rc} err:'{dev.lastError()}'")
        exitCore(1)
    rc = dev.setOperationMode(pixet.PX_TPX3_OPM_TOATOT)
    if rc != 0:
        print(f"setOperationMode rc:{rc} err:'{dev.lastError()}'")
        exitCore(1)

    acqType = pixet.PX_ACQTYPE_DATADRIVEN
    acqMode = pixet.PX_ACQMODE_NORMAL

    print(f"doAdvancedAcquisition '{fNamT3pa}'...")
    # doAdvancedAcquisition(count, time, acqType, acqMode, fileType, fileFlags, fileName)
    rc = dev.doAdvancedAcquisition(1, acqTime, acqType, acqMode, pixet.PX_FTYPE_AUTODETECT, 0, fNamT3pa)
    if rc == 0:
        print("done")
    else:
        # NOTE(review): processing continues with whatever T3PA file exists - confirm intended.
        print(f"rc:{rc} err:'{dev.lastError()}'")
else:
    print("acqTime == 0.0 - no Pixet and acquisition, only file processing")
    pixetPresent = -1

# Per-frame pixel maps: one flat list of frameXsize*frameYsize values per frame,
# pixel (x, y) is at index y*frameXsize + x. Frames are appended on demand while reading.
frameToAR = [[noHitValue] * frameXsize * frameYsize]        # raw ToA of the first hit
frameToA = [[float(noHitValue)] * frameXsize * frameYsize]  # ToA [ns] of the first hit
frameFToA = [[noHitValue] * frameXsize * frameYsize]        # fine ToA of the first hit
frameToT = [[noHitValue] * frameXsize * frameYsize]         # ToT summed over all hits
frameCnt = [[0] * frameXsize * frameYsize]                  # number of hits
frameIdx = 0
pxCnt = 0       # all data lines
pxCntOvr = 0    # data lines with Overflow != 0 (not stored to frames)
toaMax = 0.0    # max ToA [ns]
itotMax = 0     # max summed ToT of one pixel in one frame
cntMax = 0      # max hit count of one pixel in one frame
pixelCount = frameXsize * frameYsize
with open(fNamT3pa, "r") as f:
    header = f.readline()  # read file header
    if header != "Index\tMatrix Index\tToA\tToT\tFToA\tOverflow\n":
        print(f"Error: unexpected T3PA header:'{header}'")
        exitCore(1)
    # read data: Index, Matrix Index, ToA, ToT, FToA, Overflow (tab separated)
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line
        lineSp = line.strip().split("\t")
        if len(lineSp) != 6:
            print(f"Error: Invalid data line:'{line}' spLen:{len(lineSp)}")
            exitCore(1)
        matrixIndex = int(lineSp[1])
        toar = int(lineSp[2])
        tot = int(lineSp[3])
        ftoa = int(lineSp[4])
        overflow = int(lineSp[5])
        # ToA [ns]: 25 ns per raw ToA unit, corrected by the fine ToA in 25/16 ns steps
        toa = 25.0*toar - (25.0/16.0)*ftoa
        pxCnt += 1
        if overflow != 0:
            pxCntOvr += 1
            continue
        # Python lists accept negative indices, so check the range explicitly
        if not 0 <= matrixIndex < pixelCount:
            print(f"Error: matrix index:{matrixIndex} out of range 0..{pixelCount - 1}")
            exitCore(1)

        toaSec = toa / 1e9
        frameIdx = int(toaSec / frameTime)
        if frameIdx >= frameCntLimit:
            print(f"Error: frame index:{frameIdx} reach/exceed frameCntLimit:{frameCntLimit}")
            exitCore(1)
        while frameIdx >= len(frameToAR):
            frameToAR.append([noHitValue] * pixelCount)
            frameToA.append([float(noHitValue)] * pixelCount)
            frameFToA.append([noHitValue] * pixelCount)
            frameToT.append([noHitValue] * pixelCount)
            frameCnt.append([0] * pixelCount)

        cntRow = frameCnt[frameIdx]
        totRow = frameToT[frameIdx]
        # First hit is detected by the hit count, not by comparing ToA to noHitValue:
        # a real ToA may equal noHitValue (e.g. 0) and must not be overwritten.
        if cntRow[matrixIndex] == 0:
            frameToAR[frameIdx][matrixIndex] = toar
            frameFToA[frameIdx][matrixIndex] = ftoa
            frameToA[frameIdx][matrixIndex] = toa
            totRow[matrixIndex] = tot  # replaces noHitValue instead of adding to it
        else:
            totRow[matrixIndex] += tot
        cntRow[matrixIndex] += 1

        toaMax = max(toaMax, toa)
        cntMax = max(cntMax, cntRow[matrixIndex])
        itotMax = max(itotMax, totRow[matrixIndex])

# Per-frame number of pixels hit at least once (before noise filtering).
hitPixels = [sum(1 for hits in counts if hits > 0) for counts in frameCnt]

summary = (f"pxCnt:{pxCnt} frCnt:{len(frameToAR)} pxCntOvr:{pxCntOvr} "
           f"toaMax:{toaMax:.2e} itotMax:{itotMax} cntMax:{cntMax}")
s = summary + f"\nhitPixels:{hitPixels}"
print(s)

print("filter noisy pixels...")
# A pixel whose hit count within one frame exceeds noisyCntLimit is noisy:
# its count is cleared and every other map gets noisyPxValue.
for frNo, counts in enumerate(frameCnt):
    noisyPx = [k for k, hits in enumerate(counts) if hits > noisyCntLimit]
    for px in noisyPx:
        counts[px] = 0
        frameToAR[frNo][px] = noisyPxValue
        frameToA[frNo][px] = float(noisyPxValue)
        frameFToA[frNo][px] = noisyPxValue
        frameToT[frNo][px] = noisyPxValue

# Recompute hit statistics after the noise filter (toaMax is kept from reading).
hitPixels = []
cntMax = 0
itotMax = 0
for counts, tots in zip(frameCnt, frameToT):
    hitPixels.append(sum(1 for hits in counts if hits > 0))
    # the running maximum is listed first, so it is kept on ties (as with a strict ">")
    cntMax = max([cntMax, *counts])
    itotMax = max([itotMax, *tots])

s = f"cntMax:{cntMax} itotMax:{itotMax}" + f"\nhitPixels:{hitPixels}"
print(s)

# Output format: more than one frame -> multi-frame PMF, a single frame -> TXT.
acqTime = toaMax / 1e9  # max ToA [s], i.e. the measured acquisition length
# The decision uses the number of frames actually created while reading; comparing
# acqTime to frameTime disagrees with it when the max ToA equals a frame boundary.
frameCount = len(frameToAR)
# splitext only touches the extension, so dots or "txt" elsewhere in the path are safe
baseName, frameTxtExt = os.path.splitext(fNamTxt)
if frameCount > 1:
    frameTxtExt = ".pmf"
    fNamTxt = baseName + frameTxtExt
    print(f"Multiframe detected:\n   Max ToA:{acqTime}, frameTime:{frameTime}, frames:{frameCount}")
    print(f"   file ext:{frameTxtExt[1:]} fNamTxt:{fNamTxt}")

# frameTxtExt includes the leading dot here
fNamToA = f"{baseName}_ToA{frameTxtExt}"
fNamFToA = f"{baseName}_FToA{frameTxtExt}"
fNamToT = f"{baseName}_ToT{frameTxtExt}"
fNamCnt = f"{baseName}_Cnt{frameTxtExt}"

print("saving frames...")
# Each output file gets all frames one after another; a frame is written as
# frameYsize lines of frameXsize space-separated values (pixel index y*frameXsize + x).
# The with-statement closes every file even if writing fails.
with open(fNamToA, "w") as fToA, open(fNamFToA, "w") as fFToA, \
        open(fNamToT, "w") as fToT, open(fNamCnt, "w") as fCnt:
    outputs = ((fToA, frameToA), (fFToA, frameFToA), (fToT, frameToT), (fCnt, frameCnt))
    for frIdx in range(len(frameToAR)):
        for y in range(frameYsize):
            rowStart = y * frameXsize
            for outFile, frames in outputs:
                # one write per image row instead of one per pixel; str() matches f"{v}"
                row = frames[frIdx][rowStart:rowStart + frameXsize]
                outFile.write(" ".join(map(str, row)) + "\n")
print(f"done. frames:{len(frameToAR)}")

exitCore(1)  # stop the Pixet core if this script started it, then exit