import numpy as np
from re import sub
# from glob import glob
from astropy.coordinates import EarthLocation
from traceback import extract_tb
from datetime import datetime, timedelta
from pathlib import Path
from tools import dt2float
from netCDF4 import Dataset

# Directory containing this script; all data and log paths are built from it.
working_dir = Path(__file__).absolute().parent


def log_message(msg, working_leo=''):
    """Append one time-stamped line to the monthly progress log.

    The line ``"HH:MM:SS <msg>"`` is appended to
    ``<working_dir>/log/progress_dcb_<working_leo>_<YYYYMM>.txt``.

    Parameters
    ----------
    msg : str
        Text to record.
    working_leo : str, optional
        Tag inserted into the log file name (empty by default).
    """
    # Take the timestamp once so the file name (month) and the line prefix
    # (time of day) always refer to the same instant.
    now = datetime.now()
    log_dir = working_dir / "log"
    # Opening in append mode does not create missing parent directories, so
    # make sure the log directory exists before writing.
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / f"progress_dcb_{working_leo}_{now:%Y%m}.txt", 'a') as fid:
        fid.write(f"{now:%H:%M:%S} {msg}\n")

# Altitude (km) entering the numerator of the radius ratio used by the
# mapping function in the processing loop.
alti_km = 4000
# Reference epoch from which the TEC files' 'time' values are offset
# (6 January 1980).
dt0_gps = datetime(1980,1,6)
# LEO identifiers processed by this script.
leos = ['GN04', 'GN05']
# TC: updated 2026/05/18 (1 line, fix sign typo)
# The processing loop uses each entry as the zero point of a one-second time
# grid (offsets -7200 .. 179999 s), so it must sit exactly at 00:00:00.
# datetime.today() also carries the current time of day, which would shift
# every time index by the moment the script happens to run.
# NOTE(review): datetime.today() is local time -- confirm whether the UTC date
# is intended here.
today0 = datetime.today().replace(hour=0, minute=0, second=0, microsecond=0)
# The two previous days (day-2 and day-1).
dts = [today0 + timedelta(days=dd) for dd in range(-2,0)]
# Antenna indices processed for every LEO.
ants = [0, 1]

def _assemble_normal_equations(T, M):
    """Build the normal equations ``A @ dcb = B`` for the per-column DCBs.

    For every row ``t`` that has at least two columns where both ``T`` and
    ``M`` are finite (``n`` such columns), the contribution is the one that
    minimises the spread of ``(T[t, j] + dcb[j]) * M[t, j]`` about its row
    mean:

        A[j, j] += M_j**2
        A[j, k] -= M_j * M_k / n
        B[j]    += M_j * sum_k(M_k * T_k) / n - T_j * M_j**2

    Parameters
    ----------
    T, M : 2-D float arrays of identical shape (rows = epochs, columns =
        transmitters); NaN marks an unusable sample.

    Returns
    -------
    (A, B) : square matrix and right-hand side, one row per column of ``T``.
    """
    n_col = T.shape[1]
    A = np.zeros((n_col, n_col))
    B = np.zeros((n_col,))
    usable = np.logical_and(np.isfinite(T), np.isfinite(M))
    counts = np.sum(usable, axis=1)
    # Only rows with two or more usable samples constrain the solution.
    for t in np.flatnonzero(counts >= 2):
        cols = np.flatnonzero(usable[t])
        m = M[t, cols]
        tec = T[t, cols]
        A[cols, cols] += m ** 2
        A[np.ix_(cols, cols)] -= m[:, None] * m[None, :] / counts[t]
        B[cols] += (m @ tec) * m / counts[t] - tec * m ** 2
    return A, B


# One-second time grid used for all per-epoch arrays: row index =
# (seconds relative to dt) + pad_sec, covering -7200 .. 179999 s.
pad_sec = 7200
n_epoch = 187200

for dt in dts:
    # NOTE(review): the row indexing below treats dt as the start of the
    # processing day; confirm that dts holds midnight datetimes.
    # Seconds from dt0_gps to dt (plain datetime arithmetic).
    sec_rr = (dt - dt0_gps) / timedelta(seconds=1)
    for leo in leos:
        for ant in ants:
            tag = f"{dt:%Y-%m-%d}.{leo}.A{ant:02d}"
            log_message(f"- Processing ({tag})")
            try:
                outfile = working_dir / f'p04_teczdcb/teczdcb_{leo}.{dt:%Y-%m-%d}.A{ant:02d}.npz'

                # Start from the most recent earlier solution (up to 365 days
                # back, the current day included) for this LEO and antenna.
                for tdif in range(366):
                    oldfile = working_dir / f'p04_teczdcb/teczdcb_{leo}.{dt-timedelta(days=tdif):%Y-%m-%d}.A{ant:02d}.npz'
                    if oldfile.exists():
                        with np.load(oldfile, allow_pickle=True) as fid:
                            gnsses = fid['gnsses']
                            dcb_sol = fid['dcb_sol']
                        break
                else:
                    # No earlier solution: start with an empty table instead
                    # of reusing (or missing) the arrays left over from the
                    # previous LEO/antenna iteration.
                    gnsses = np.array([], dtype=str)
                    dcb_sol = np.array([], dtype=float)
                    log_message(f"-- No previous DCB solution found ({tag}); starting from scratch")

                # Per-epoch, per-transmitter tables; NaN means "no sample".
                alltec = np.full((n_epoch, len(gnsses)), np.nan)
                allcoszn = np.full((n_epoch, len(gnsses)), np.nan)
                allmeps = np.full((n_epoch, len(gnsses)), np.nan)

                # LEO orbit of the current day, merged with the next day's
                # orbit when that file is available.
                orbfile = working_dir / f'p01_leoOrb/leoOrb_{leo}.{dt:%Y.%j}.npz'
                with np.load(orbfile, allow_pickle=True) as fid:
                    Ldata0 = fid['data']
                    Ldts0 = fid['dts']
                orbfile = working_dir / f'p01_leoOrb/leoOrb_{leo}.{dt+timedelta(days=1):%Y.%j}.npz'
                # TC: updated 2026/05/18 (10 lines, check existence of next day orbit file)
                if orbfile.exists():
                    with np.load(orbfile, allow_pickle=True) as fid:
                        Ldata1 = fid['data']
                        Ldts1 = fid['dts']
                    # Drop current-day epochs that also appear in the next-day
                    # file so each epoch is kept once.
                    # presumably dt2float maps datetimes to numbers relative
                    # to dt -- verify against tools.dt2float
                    idx = ~np.isin(dt2float(Ldts0,dt),dt2float(Ldts1,dt),assume_unique=True)
                    Ldts = np.hstack((Ldts0[idx], Ldts1))
                    Ldata = np.vstack((Ldata0[idx], Ldata1))
                else:
                    Ldts = Ldts0
                    Ldata = Ldata0
                idx = np.argsort(Ldts)
                Ldts = Ldts[idx]
                Ldata = Ldata[idx]
                # Orbit epochs as whole seconds relative to dt.
                sec_ref = np.round(((Ldts - dt) / timedelta(seconds=1)).astype('float')).astype('int')
                idx = sec_ref >= -pad_sec
                leo_pos = EarthLocation(x=Ldata[idx,0],y=Ldata[idx,1],z=Ldata[idx,2],unit='km')
                alllon, alllat, _ = leo_pos.to_geodetic()
                # Hour offset derived from longitude, wrapped into [-12, 12).
                allhr = (sec_ref[idx]/3600+alllon.value/15+12)%24-12
                # Epochs with |hour offset| < 4 and |latitude| > 30 degrees.
                # NOTE(review): tmask is not used further below; it is kept so
                # that orbit problems still abort the day exactly as before.
                tmask = np.full((n_epoch,), False)
                tmask[sec_ref[idx]+pad_sec] = np.logical_and(np.abs(allhr) < 4, np.abs(alllat.value) > 30)

                # TEC files of the current and the following day.
                tecfiles = list(
                    (working_dir / f'p05_podTec/{dt:%Y-%m-%d}').glob(f'ROSW_TEC_{leo}_*_A{ant:02d}_*.nc')) + list(
                    (working_dir / f'p05_podTec/{dt + timedelta(days=1):%Y-%m-%d}').glob(f'ROSW_TEC_{leo}_*_A{ant:02d}_*.nc'))
                for tecfile in tecfiles:
                    try:
                        # Transmitter id (letter + two digits) taken from the
                        # file name.
                        gnss = sub(r'.*_([A-Z]\d\d)_A\d\d_.*',r'\1',tecfile.name)
                        if gnss not in gnsses:
                            # First time this transmitter is seen: add a
                            # column with an unknown (NaN) DCB.
                            gnsses = np.append(gnsses, gnss)
                            dcb_sol = np.append(dcb_sol, np.nan)
                            alltec = np.hstack((alltec, np.full_like(alltec[:,:1], np.nan)))
                            allcoszn = np.hstack((allcoszn, np.full_like(allcoszn[:, :1], np.nan)))
                            allmeps = np.hstack((allmeps, np.full_like(allmeps[:, :1], np.nan)))
                        gidx = list(gnsses).index(gnss)
                        with Dataset(tecfile, 'r') as fid:
                            # Row index on the one-second grid.
                            # NOTE(review): negative values would wrap around
                            # to the end of the arrays -- confirm that file
                            # times never precede dt by more than pad_sec.
                            tidx = (fid['time'][...]-sec_rr).astype('int')+pad_sec
                            # Remove the DCB stored in the file so that it can
                            # be re-estimated here.
                            dcb = fid.getncattr('leodcb')
                            tec = fid['TEC'][...]-dcb
                            # Sine of the elevation angle (elevation read in
                            # degrees).
                            coszn = np.sin(fid['elevation'][...]*np.pi/180)
                            OL = np.column_stack([fid['x_LEO'][...],fid['y_LEO'][...],fid['z_LEO'][...]])
                        altl_km = EarthLocation.from_geocentric(*OL.T, unit='km').geodetic.height.value
                        rad_ratio = (alti_km + 6371) / (altl_km + 6371)
                        # Factor multiplied onto (TEC + DCB) to form modtec.
                        meps = (coszn + (rad_ratio ** 2 + coszn ** 2 - 1) ** 0.5) / (1 + rad_ratio)
                        alltec[tidx, gidx] = tec
                        allcoszn[tidx, gidx] = coszn
                        allmeps[tidx, gidx] = meps
                    except Exception as err:
                        # Best effort: an unreadable or out-of-window file is
                        # skipped, but recorded (KeyboardInterrupt and
                        # SystemExit are no longer swallowed).
                        log_message(f"-- Skipped {tecfile.name}: {type(err).__name__} >> {err}")
                        continue

                # Keep only samples whose elevation sine exceeds 0.1.
                mask = allcoszn > .1
                T = alltec.copy()
                T[~mask] = np.nan
                M = allmeps.copy()
                M[~mask] = np.nan
                # Columns (transmitters) that have at least one usable sample;
                # gidx maps the reduced columns back into gnsses / dcb_sol.
                gidx = np.where(np.any(np.isfinite(T), axis=0))[0]
                M = M[:, gidx]
                T = T[:, gidx]
                mask = mask[:, gidx]

                # Initial screening with the previous DCBs: discard samples
                # more than 5 away from the per-epoch median.
                modtec = (T + dcb_sol[gidx]) * M
                idx = np.abs(modtec - np.nanmedian(modtec, axis=1)[:, None]) > 5
                T[idx] = np.nan
                M[idx] = np.nan

                # Solve for the DCBs, reject outliers against the new
                # solution, and repeat until no sample is rejected.
                stage = 1
                while True:
                    A, B = _assemble_normal_equations(T, M)
                    # Transmitters left without any constraint (all-zero row
                    # or column) are removed; their dcb_sol entry is kept.
                    idx = np.logical_or(np.all(A==0,axis=0),np.all(A==0,axis=1))
                    if np.any(idx):
                        A = A[~idx][:, ~idx]
                        B = B[~idx]
                        gidx = gidx[~idx]
                        M = M[:, ~idx]
                        T = T[:, ~idx]
                        mask = mask[:, ~idx]
                    if np.linalg.matrix_rank(A) < A.shape[0]:
                        raise ValueError(f"Singular linear system ({stage}) ({tag})")
                    dcb_sol[gidx] = np.linalg.solve(A, B)
                    modtec = (T + dcb_sol[gidx]) * M
                    modtec[~mask] = np.nan
                    idx = np.abs(modtec - np.nanmedian(modtec, axis=1)[:, None]) > 5
                    if not np.any(idx):
                        break
                    T[idx] = np.nan
                    M[idx] = np.nan
                    stage = 2

                # Per-epoch mean over transmitters of the mapped TEC.
                tecz = np.nanmean(modtec, axis=1)
                # Store only transmitters with a finite DCB, sorted by id.
                gidx = np.isfinite(dcb_sol)
                dcb_sol = dcb_sol[gidx]
                gnsses = gnsses[gidx]
                gidx = np.argsort(gnsses)
                dcb_sol = dcb_sol[gidx]
                gnsses = gnsses[gidx]
                np.savez(outfile, sec_tec=np.arange(-pad_sec,n_epoch-pad_sec),gnsses=gnsses,tecz=tecz,dcb_sol=dcb_sol)

                log_message(f"- Done ({tag})")
            except Exception as e:
                # A failure for one day/LEO/antenna is logged with its
                # traceback locations; processing continues with the next.
                log_message(f'-- Error detected for ({tag})')
                log_message(f'---  {type(e).__name__} >> {str(e)}')
                for tb in extract_tb(e.__traceback__):
                    log_message(f"---  line {tb.lineno} of {tb.filename}")