import os
from datetime import datetime, timedelta
from glob import glob
from re import sub

import numpy as np
from scipy.interpolate import CubicSpline

from tools import dt2float

# Bidirectional lookup between the 4-character satellite codes and their numbers:
# lsat_names['GN04'] == 4 and lsat_names[4] == 'GN04'.
lsat_names = dict(GN04=4, GN05=5, YM08=8)
# Add the reverse (number -> code) entries; the comprehension is fully built before
# update() runs and, unlike a for-loop, leaks no loop variable into the module namespace.
lsat_names.update({num: code for code, num in lsat_names.items()})

def _span_from_name(path):
    """Return the (start, end) datetimes encoded in the file name *path*.

    The name must contain ``.<YYYYmmddHHMMSS>.<YYYYmmddHHMMSS>.``; the greedy
    regexes pick the last such occurrence. If it is absent, ``sub`` leaves the
    path unchanged and ``strptime`` raises ValueError.
    """
    start = datetime.strptime(sub(r'.*\.(\d{14})\.\d{14}\..*', r'\1', path), '%Y%m%d%H%M%S')
    end = datetime.strptime(sub(r'.*\.\d{14}\.(\d{14})\..*', r'\1', path), '%Y%m%d%H%M%S')
    return start, end


def _spans(paths):
    """Return 1-D object arrays ``(starts, ends)`` of the time spans encoded in *paths*."""
    spans = [_span_from_name(p) for p in paths]
    starts = np.array([s for s, _ in spans], dtype=object)
    ends = np.array([e for _, e in spans], dtype=object)
    return starts, ends


def _ref_snr1(reffile, when):
    """Interpolate the first 'S1' column of the reference file *reffile* at datetime *when*.

    The .npz file must hold ``data`` (2-D) and ``header`` (column labels).
    Raises IndexError if no header label contains 'S1'.
    """
    with np.load(reffile, allow_pickle=True) as fid:
        data = fid['data']
        header = fid['header']
    snr = [data[:, col] for col, label in enumerate(header) if 'S1' in label][0]
    # 1980-01-06 is the GPS time epoch; presumably data[:, 0] + data[:, 1] are
    # seconds since that epoch -- TODO confirm the column meaning.
    times = np.array([datetime(1980, 1, 6) + timedelta(seconds=s)
                      for s in data[:, 0] + data[:, 1]])
    # Assumes dt2float expresses `times` relative to `when`, so evaluating the
    # spline at 0 yields S1 at `when` -- confirm against tools.dt2float.
    return CubicSpline(dt2float(times, when), snr)(0)


def _ref_sort_order(refsnr1, reffile):
    """Return int indices ordering references by descending (name[-11], S1 value).

    For names ending in ``.XXX_txt.npz``, name[-11] is the first character of
    ``XXX`` (presumably the GNSS constellation letter). Ties keep their input
    order because ``sorted(..., reverse=True)`` is stable.
    """
    rows = np.array([range(len(refsnr1)), refsnr1, [ord(x[-11]) for x in reffile]]).transpose()
    ordered = sorted(rows, key=lambda row: (row[2], row[1]), reverse=True)
    return np.array(ordered)[:, 0].astype('int')


def pair_ro_ref(dt):
    """Pair each RO file with the reference files that bracket it, and save the pairing.

    Reads ``t01_RO/*.npz`` and ``t01_REF/*.npz``. The start and end times come
    from the two 14-digit stamps in each file name. A reference file is paired
    with an RO file when both conditions hold:

    * its time span strictly contains the RO span, and
    * its 3-character id (the ``XXX`` in ``.XXX_txt.npz``) is listed in
      ``Gsats`` of ``s02_zenith/zenith_<leosat>.<YYYY.DDD>_.npz``.

    For every RO file with at least one pair, this writes ``t03_pair/<RO file name>``
    with these arrays, ordered by ``_ref_sort_order``:

    * ``reffile`` -- paired reference paths;
    * ``refzenith`` -- peak ``zenith`` value of each reference's column within
      the RO window;
    * ``refdt`` -- the ``Tnew`` sample where that peak occurs;
    * ``refsnr1`` -- the reference's S1 value interpolated at ``refdt``.

    RO files are skipped when they have no paired reference, or when no
    ``Tnew`` sample lies strictly inside their time window.

    Parameters
    ----------
    dt : datetime or date
        Processing day; only used to build the zenith file name (``%Y.%j``).

    Returns
    -------
    None. Returns immediately if either input directory has no .npz files.
    """
    rofiles = glob('t01_RO/*.npz')
    reffiles = np.array(glob('t01_REF/*.npz'))
    if len(rofiles) == 0 or len(reffiles) == 0:
        return

    gnss_ref = np.array([sub(r'.*\.(...)_txt\.npz', r'\1', x) for x in reffiles])
    dts_ro, dte_ro = _spans(rofiles)
    dts_ref, dte_ref = _spans(reffiles)

    # The LEO id is taken from the first RO file name only; this assumes all
    # RO files belong to the same satellite.
    leosat = sub(r'.*_(....)\..*', r'\1', rofiles[0])
    with np.load(f's02_zenith/zenith_{leosat}.{dt:%Y.%j}_.npz', allow_pickle=True) as fid:
        zenith = fid['zenith']  # indexed [Tnew sample, Gsats entry] below
        Tnew = fid['Tnew']
        Gsats = fid['Gsats']

    # Column of each id in `zenith`; first occurrence wins, matching list.index.
    gsat_col = {}
    for col, gsat in enumerate(Gsats):
        gsat_col.setdefault(gsat, col)

    # pidx[i, j]: reference j strictly brackets RO file i and its id is in Gsats.
    pidx = ((dts_ro[:, None] > dts_ref[None, :])
            & (dte_ro[:, None] < dte_ref[None, :])
            & np.isin(gnss_ref, Gsats)[None, :])

    out_dir = 't03_pair'
    os.makedirs(out_dir, exist_ok=True)
    for rofile, dts, dte, ppidx in zip(rofiles, dts_ro, dte_ro, pidx):
        if not np.any(ppidx):
            continue
        tidx = np.where((Tnew > dts) * (Tnew < dte))[0]
        if tidx.size == 0:
            # No zenith sample inside the RO window: max/argmax over an empty
            # slice would raise, so skip this RO file like an unpaired one.
            continue
        reffile = reffiles[ppidx]
        lidx = [gsat_col[sub(r'.*\.(...)_.*', r'\1', x)] for x in reffile]
        window = zenith[tidx][:, lidx]
        refzenith = np.max(window, axis=0)
        refdt = Tnew[tidx[np.argmax(window, axis=0)]]
        refsnr1 = np.full(reffile.shape, np.nan)
        for pp, rr in enumerate(reffile):
            refsnr1[pp] = _ref_snr1(rr, refdt[pp])
        idx = _ref_sort_order(refsnr1, reffile)
        # basename() handles both '/' and the '\\' that glob yields on Windows,
        # so the output can never land back on the input RO file.
        np.savez(os.path.join(out_dir, os.path.basename(rofile)),
                 reffile=reffile[idx], refzenith=refzenith[idx],
                 refdt=refdt[idx], refsnr1=refsnr1[idx])