from glob import glob
from re import sub
from datetime import datetime, timedelta
import numpy as np
from scipy.interpolate import CubicSpline
from tools import dt2float, read_gnss_v1

# Bidirectional lookup between LEO satellite names and their numeric ids,
# e.g. lsat_names['GN04'] == 4 and lsat_names[4] == 'GN04'.
lsat_names = {'GN04': 4, 'GN05': 5, 'YM08': 8}
# The comprehension is fully built before update() runs, so the dict is not
# mutated while it is being iterated.
lsat_names.update({number: name for name, number in lsat_names.items()})

def cal_zenith(dt, leosat):
    """Compute the angle of each GNSS satellite from the LEO position vector.

    Loads the LEO orbit of ``leosat`` for the day of ``dt`` from
    ``s01_leoOrb/leoOrb_{leosat}.{YYYY.DDD}.npz``. Reads GNSS orbits over the
    same span via ``read_gnss_v1``. Both are cubic-spline interpolated onto a
    common 10-second grid inside their time overlap. The function returns,
    per epoch and GNSS satellite, the angle between the LEO position vector
    and the LEO->GNSS line of sight (a zenith angle if positions are
    geocentric -- assumption, confirm frame).

    Parameters
    ----------
    dt : datetime
        Day selector; only ``%Y.%j`` of it is used to build the file name.
    leosat : str
        LEO satellite name used in the orbit file name.

    Returns
    -------
    zenith : ndarray, shape (len(Tnew), len(Gsats))
        Angle in degrees.
    Tnew : ndarray of datetime
        The common 10-s epochs.
    Gsats : ndarray
        GNSS satellite identifiers kept (finite position at every epoch
        used for the GNSS spline fit).
    """
    leofile = f's01_leoOrb/leoOrb_{leosat}.{dt:%Y.%j}.npz'
    with np.load(leofile, allow_pickle=True) as fid:
        Ldata = fid['data']
        Ldt = fid['dts']
    Gdata, Gdt, Gsats = read_gnss_v1(Ldt[[0, -1]])
    # Keep only GNSS satellites whose first three columns are finite at every
    # epoch except the last one, which is also excluded from the spline fit.
    idx = np.all(np.isfinite(Gdata[:-1, :, :3]), axis=(0, 2))
    Gdata = Gdata[:, idx]
    Gsats = Gsats[idx]
    # Time overlap of the LEO and GNSS series.
    # NOTE(review): the GNSS spline below is fitted on Gdt[:-1], but this
    # bound uses Gdt[-1]; epochs after Gdt[-2] are spline-extrapolated.
    Tbdry = [max(Gdt[0], Ldt[0]), min(Gdt[-1], Ldt[-1])]
    step = timedelta(seconds=10)
    # Snap to the 10-s grid: floor both ends, then bump the start up one step
    # if it was not already on the grid, so the grid stays inside the overlap.
    Tbdry_i = [x.replace(second=10 * (x.second // 10), microsecond=0) for x in Tbdry]
    if Tbdry_i[0] != Tbdry[0]:
        Tbdry_i[0] = Tbdry_i[0] + step
    Tnew = np.arange(Tbdry_i[0], Tbdry_i[1] + step, step).astype(datetime)
    tnew = dt2float(Tnew, Tnew[0])
    Gnew = CubicSpline(dt2float(Gdt[:-1], Tnew[0]), Gdata[:-1, :, :3])(tnew)
    # Add a satellite axis so the single LEO broadcasts against all GNSS sats.
    Lnew = CubicSpline(dt2float(Ldt, Tnew[0]), Ldata[:, :3])(tnew)[:, None, :]
    LG = Gnew - Lnew
    cosang = (np.sum(LG * Lnew, axis=2)
              / np.sum(LG * LG, axis=2) ** 0.5
              / np.sum(Lnew * Lnew, axis=2) ** 0.5)
    # Rounding can push |cos| marginally above 1, which arccos turns into NaN;
    # np.clip leaves genuine NaNs (e.g. missing data) untouched.
    zenith = np.arccos(np.clip(cosang, -1.0, 1.0)) / np.pi * 180
    return zenith, Tnew, Gsats

def _ref_gnss(path):
    """Return the 3-character GNSS satellite id immediately before ``.npz`` in *path*."""
    return sub(r'.*\.(...)\.npz', r'\1', path)


def _occ_span(path):
    """Return the (start, end) datetimes from the ``.<14 digits>.<14 digits>.`` part of *path*."""
    start = datetime.strptime(sub(r'.*\.(\d{14})\.\d{14}\..*', r'\1', path), '%Y%m%d%H%M%S')
    end = datetime.strptime(sub(r'.*\.\d{14}\.(\d{14})\..*', r'\1', path), '%Y%m%d%H%M%S')
    return start, end


def pair_ro_ref(dt, working_leo):
    """Pair RO occultation files with reference-link files and save the pairing.

    For every LEO satellite name found among ``t01_RO_{working_leo}/*.npz``,
    each RO file ``occTab_{leosat}*.npz`` is paired with every file in
    ``t01_REF_{working_leo}/occTab_{leosat}*.npz`` that meets two conditions:
    its time span strictly contains the RO span, and its GNSS satellite
    appears in the table returned by ``cal_zenith(dt, leosat)``.

    For each paired reference the function computes three values:

    * ``refzenith``: the maximum zenith angle over the RO span.
    * ``refdt``: the epoch at which that maximum occurs.
    * ``refsnr1``: the cubic-spline value at ``refdt`` of the first data
      column whose header contains ``'S1'``.

    References are ordered by the file-name character at ``[-7]``, then by
    ``refsnr1``, both descending. The arrays are saved with ``np.savez`` to
    ``t03_pair_{working_leo}/<RO file name>``. That directory must exist.

    RO files with no candidate reference, or whose span contains no epoch of
    the zenith grid, are skipped and produce no output file.

    Parameters
    ----------
    dt : datetime
        Day passed to ``cal_zenith`` to select the LEO orbit file.
    working_leo : str
        Suffix of the ``t01_RO_*``, ``t01_REF_*`` and ``t03_pair_*`` directories.
    """
    rofiles = glob(f't01_RO_{working_leo}/*.npz')
    leosats = np.unique([sub(r'.*_(....)\..*', r'\1', rofile) for rofile in rofiles])
    # Reference sample times are offsets in seconds from this date
    # (presumably the GPS time origin -- confirm against the REF file writer).
    gps_epoch = datetime(1980, 1, 6)
    for leosat in leosats:
        rofiles = glob(f't01_RO_{working_leo}/occTab_{leosat}*.npz')
        reffiles = np.array(glob(f't01_REF_{working_leo}/occTab_{leosat}*.npz'))
        if len(rofiles) == 0 or len(reffiles) == 0:
            continue
        # RO quantities are columns (nro, 1), reference quantities are rows
        # (1, nref); they broadcast into an (nro, nref) pairing matrix.
        gnss_ref = np.expand_dims([_ref_gnss(x) for x in reffiles], axis=0)
        ro_start, ro_end = zip(*(_occ_span(x) for x in rofiles))
        ref_start, ref_end = zip(*(_occ_span(x) for x in reffiles))
        dts_ro = np.expand_dims(ro_start, axis=1)
        dte_ro = np.expand_dims(ro_end, axis=1)
        dts_ref = np.expand_dims(ref_start, axis=0)
        dte_ref = np.expand_dims(ref_end, axis=0)
        zenith, Tnew, Gsats = cal_zenith(dt, leosat)
        gsat_list = list(Gsats)
        # pidx[i, j]: reference j strictly brackets RO i in time and its GNSS
        # satellite has a zenith series.
        pidx = (dts_ro > dts_ref) * (dte_ro < dte_ref) * np.isin(gnss_ref, Gsats)
        for rofile, dts, dte, ppidx in zip(rofiles, dts_ro[:, 0], dte_ro[:, 0], pidx):
            if not np.any(ppidx):
                continue
            tidx = np.where((Tnew > dts) * (Tnew < dte))[0]
            if tidx.size == 0:
                # The RO span lies outside the zenith grid (e.g. an RO file
                # from another day); np.max over an empty axis would raise.
                continue
            reffile = reffiles[ppidx]
            lidx = [gsat_list.index(_ref_gnss(x)) for x in reffile]
            zsub = zenith[tidx][:, lidx]
            refzenith = np.max(zsub, axis=0)
            refdt = Tnew[tidx[np.argmax(zsub, axis=0)]]
            refsnr1 = np.full(reffile.shape, np.nan)
            for pp, rr in enumerate(reffile):
                with np.load(rr, allow_pickle=True) as fid:
                    data = fid['data']
                    header = fid['header']
                sdata = [data[:, x] for x, y in enumerate(header) if 'S1' in y][0]
                tdata = np.array([gps_epoch + timedelta(seconds=x) for x in data[:, 0] + data[:, 1]])
                refsnr1[pp] = CubicSpline(dt2float(tdata, refdt[pp]), sdata)(0)
            # Descending by file-name character [-7], then by SNR; sorted() is
            # stable, so ties keep their original relative order.
            order = sorted(range(len(reffile)),
                           key=lambda i: (ord(reffile[i][-7]), refsnr1[i]),
                           reverse=True)
            idx = np.array(order, dtype=int)
            reffile = reffile[idx]
            refzenith = refzenith[idx]
            refdt = refdt[idx]
            refsnr1 = refsnr1[idx]
            np.savez(sub(r'.*/', f't03_pair_{working_leo}/', rofile),
                     reffile=reffile, refzenith=refzenith, refdt=refdt, refsnr1=refsnr1)