import os.path
from glob import glob
from re import sub
from datetime import datetime, timedelta
import numpy as np
from scipy.interpolate import CubicSpline
from spire.tools import dt2float, read_gnss_v1

def cal_zenith(dt, leosat):
    """Angle between a LEO's position vector and its line of sight to each GNSS satellite.

    Loads ``s01_leoOrb/leoOrb_{leosat}.{dt:%Y.%j}.npz`` (arrays ``data`` and
    ``dts``), fetches GNSS data for the span between the first and last LEO
    epoch via ``read_gnss_v1``, resamples both onto a common 10-second grid
    with cubic splines, and returns the angle (in degrees) between the LEO
    vector and the LEO->GNSS vector at every grid epoch.

    Parameters
    ----------
    dt : datetime
        Day to process; only used to format the orbit file name (year and
        day-of-year).
    leosat : str
        LEO satellite identifier as it appears in the orbit file name.

    Returns
    -------
    zenith : ndarray, shape (len(Tnew), len(Gsats))
        Angle in degrees, one column per retained GNSS satellite.
    Tnew : ndarray of datetime (object dtype)
        Grid epochs, spaced 10 seconds apart on whole multiples of 10 s.
    Gsats : ndarray
        GNSS satellite identifiers that survived the finite-data filter.

    Raises
    ------
    FileNotFoundError
        If the orbit file for ``dt``/``leosat`` does not exist (from ``np.load``).
    """
    leofile = f's01_leoOrb/leoOrb_{leosat}.{dt:%Y.%j}.npz'
    # allow_pickle is needed for object arrays (presumably `dts` holds datetime
    # objects); only load files produced by a trusted pipeline step.
    with np.load(leofile, allow_pickle=True) as fid:
        Ldata = fid['data']
        Ldt = fid['dts']

    Gdata, Gdt, Gsats = read_gnss_v1(Ldt[[0, -1]])

    # Keep only GNSS satellites whose first three columns are finite at every
    # epoch except the last one (the last epoch is excluded from the splines too).
    idx = np.all(np.isfinite(Gdata[:-1, :, :3]), axis=(0, 2))
    Gdata = Gdata[:, idx]
    Gsats = Gsats[idx]

    # Common time span of both data sets, snapped inward onto whole 10-second marks:
    # the start is rounded up, the end is rounded down.
    step = timedelta(seconds=10)
    Tbdry = [max(Gdt[0], Ldt[0]), min(Gdt[-1], Ldt[-1])]
    Tbdry_i = [x.replace(second=10 * (x.second // 10), microsecond=0) for x in Tbdry]
    if Tbdry_i[0] != Tbdry[0]:
        Tbdry_i[0] = Tbdry_i[0] + step
    Tnew = np.arange(Tbdry_i[0], Tbdry_i[1] + step, step).astype(datetime)

    # NOTE(review): the grid is bounded by Gdt[-1] but the GNSS spline is built
    # from Gdt[:-1], so grid epochs after Gdt[-2] are extrapolated - confirm
    # this is intended.
    t_new = dt2float(Tnew, Tnew[0])
    Gnew = CubicSpline(dt2float(Gdt[:-1], Tnew[0]), Gdata[:-1, :, :3])(t_new)
    # Insert a satellite axis so the LEO vector broadcasts against every GNSS satellite.
    Lnew = CubicSpline(dt2float(Ldt, Tnew[0]), Ldata[:, :3])(t_new)[:, None, :]

    LG = Gnew - Lnew  # LEO -> GNSS vector
    cos_angle = (np.sum(LG * Lnew, axis=2) /
                 np.sum(LG * LG, axis=2) ** 0.5 / np.sum(Lnew * Lnew, axis=2) ** 0.5)
    # Round-off can push the cosine marginally outside [-1, 1] when the two
    # vectors are (anti)parallel, which would make arccos return NaN; clamp it.
    # np.clip leaves genuine NaN inputs as NaN.
    zenith = np.arccos(np.clip(cos_angle, -1.0, 1.0)) / np.pi * 180
    return zenith, Tnew, Gsats

def _occ_stamp(path, pattern):
    """Parse the 14-digit ``YYYYmmddHHMMSS`` group that *pattern* captures in *path*."""
    return datetime.strptime(sub(pattern, r'\1', path), '%Y%m%d%H%M%S')


def pair_ro_ref(working_leo):
    """Pair every RO table with the REF tables that enclose it in time and save the pairing.

    For each LEO satellite found in ``t01_RO_{working_leo}/*.npz`` and each
    calendar day on which an RO table starts, the zenith grid from
    ``cal_zenith`` is used to find, per matching REF file, the maximum zenith
    angle inside the RO window, the epoch at which it occurs, and the value of
    the REF file's first ``S1`` column spline-interpolated to that epoch.

    A REF file matches an RO file when the REF window strictly contains the RO
    window, the REF file's GNSS id is among the satellites returned by
    ``cal_zenith``, and the windows lie inside the zenith time grid.

    Results are written to ``t03_pair_{working_leo}/<RO file name>`` with the
    arrays ``reffile``, ``refzenith``, ``refdt`` and ``refsnr1``. If that file
    already exists, previously stored entries for REF files not re-evaluated in
    this call are kept. Entries are ordered descending by the character at
    position ``-7`` of the REF path (presumably the constellation letter of the
    3-character GNSS id) and then by descending ``refsnr1``.

    Parameters
    ----------
    working_leo : str
        Suffix selecting the ``t01_RO_*``, ``t01_REF_*`` and ``t03_pair_*``
        directories.

    Returns
    -------
    None
    """
    # File names carry ".<start 14 digits>.<end 14 digits>." and end in ".<3-char id>.npz".
    start_pat = r'.*\.(\d{14})\.\d{14}\..*'
    end_pat = r'.*\.\d{14}\.(\d{14})\..*'
    gnss_pat = r'.*\.(...)\.npz'
    outdir = f't03_pair_{working_leo}'

    all_ro = glob(f't01_RO_{working_leo}/*.npz')
    leosats = np.unique([sub(r'.*_(.....)\..*', r'\1', x) for x in all_ro])
    # Distinct calendar days (midnight datetimes) on which any RO table starts.
    dt_rng = np.unique([datetime(*_occ_stamp(x, start_pat).timetuple()[:3]) for x in all_ro])

    for leosat in leosats:
        rofiles = glob(f't01_RO_{working_leo}/occTab_{leosat}*.npz')
        reffiles = np.array(glob(f't01_REF_{working_leo}/occTab_{leosat}*.npz'))
        if len(rofiles) == 0 or len(reffiles) == 0:
            continue

        # RO quantities are column vectors and REF quantities row vectors so
        # that comparisons broadcast to an (n_ro, n_ref) matrix.
        gnss_ref = np.expand_dims([sub(gnss_pat, r'\1', x) for x in reffiles], axis=0)
        dts_ro = np.expand_dims([_occ_stamp(x, start_pat) for x in rofiles], axis=1)
        dte_ro = np.expand_dims([_occ_stamp(x, end_pat) for x in rofiles], axis=1)
        dts_ref = np.expand_dims([_occ_stamp(x, start_pat) for x in reffiles], axis=0)
        dte_ref = np.expand_dims([_occ_stamp(x, end_pat) for x in reffiles], axis=0)

        for dt in dt_rng:
            # Skip days without a LEO orbit file; cal_zenith would fail loading it.
            if not os.path.exists(f's01_leoOrb/leoOrb_{leosat}.{dt:%Y.%j}.npz'):
                continue
            zenith, Tnew, Gsats = cal_zenith(dt, leosat)
            gsat_list = list(Gsats)

            # NOTE(review): the last term is implied by the others
            # (Tnew[-1] > dte_ro > dts_ro > dts_ref); `dte_ref` may have been
            # intended - left unchanged pending confirmation.
            pidx = ((dts_ro > dts_ref) & (dte_ro < dte_ref)
                    & np.isin(gnss_ref, Gsats)
                    & (Tnew[0] < dts_ro) & (Tnew[-1] > dte_ro)
                    & (Tnew[0] < dts_ref) & (Tnew[-1] > dts_ref))

            for rofile, dts, dte, ppidx in zip(rofiles, dts_ro[:, 0], dte_ro[:, 0], pidx):
                if not np.any(ppidx):
                    continue
                tidx = np.where((Tnew > dts) & (Tnew < dte))[0]
                if tidx.size == 0:
                    # No grid epoch lies strictly inside this RO window (window
                    # shorter than the grid spacing); a maximum is undefined.
                    continue

                reffile = reffiles[ppidx]
                lidx = [gsat_list.index(sub(gnss_pat, r'\1', x)) for x in reffile]
                # Rows: grid epochs inside the RO window; columns: matched REF files.
                zen = zenith[tidx][:, lidx]
                refzenith = np.max(zen, axis=0)
                refdt = Tnew[tidx[np.argmax(zen, axis=0)]]

                refsnr1 = np.full(reffile.shape, np.nan)
                for pp, rr in enumerate(reffile):
                    with np.load(rr, allow_pickle=True) as fid:
                        data = fid['data']
                        header = fid['header']
                    # First column whose header entry contains 'S1'.
                    sdata = [data[:, x] for x, y in enumerate(header) if 'S1' in y][0]
                    # Columns 0 and 1 summed are seconds since 1980-01-06
                    # (presumably GPS time split into two parts).
                    tdata = np.array([datetime(1980, 1, 6) + timedelta(seconds=x)
                                      for x in data[:, 0] + data[:, 1]])
                    # Time axis is relative to refdt[pp], so evaluating at 0 gives the value at refdt[pp].
                    refsnr1[pp] = CubicSpline(dt2float(tdata, refdt[pp]), sdata)(0)

                # Build the output path from the file name itself so it cannot
                # collapse onto the input path when the separator is not '/'.
                outfile = os.path.join(outdir, os.path.basename(rofile))
                if os.path.exists(outfile):
                    with np.load(outfile, allow_pickle=True) as fid:
                        reffile0 = fid['reffile']
                        refzenith0 = fid['refzenith']
                        refdt0 = fid['refdt']
                        refsnr10 = fid['refsnr1']
                    # Keep stored entries for REF files not recomputed in this pass.
                    idx = ~np.isin(reffile0, reffile)
                    reffile = np.hstack((reffile, reffile0[idx]))
                    refzenith = np.hstack((refzenith, refzenith0[idx]))
                    refdt = np.hstack((refdt, refdt0[idx]))
                    refsnr1 = np.hstack((refsnr1, refsnr10[idx]))

                # Descending by the character at [-7] of the path, then by SNR.
                order = sorted(range(len(reffile)),
                               key=lambda i: (ord(reffile[i][-7]), refsnr1[i]),
                               reverse=True)
                idx = np.array(order, dtype=int)
                reffile = reffile[idx]
                refzenith = refzenith[idx]
                refdt = refdt[idx]
                refsnr1 = refsnr1[idx]

                os.makedirs(outdir, exist_ok=True)
                np.savez(outfile, reffile=reffile, refzenith=refzenith, refdt=refdt, refsnr1=refsnr1)
