import os
from glob import glob
from re import sub
import numpy as np
from datetime import datetime, timedelta
from spire.tools import get_sat_freq, to_datetime # Jun modify
from netCDF4 import Dataset
from traceback import extract_tb

def log_message(msg, working_leo):
    """Append a timestamped message to today's per-LEO progress log.

    The log file is ``log/progress_rt_{working_leo}_{YYYYMMDD}.txt``; the
    ``log`` directory is created on demand so the first call of a run does
    not fail with FileNotFoundError.  ``datetime.now()`` is captured once so
    the filename date and the line timestamp cannot disagree at a midnight
    rollover.

    Parameters
    ----------
    msg : str
        Message text to append (an HH:MM:SS timestamp is prefixed).
    working_leo : str
        LEO identifier embedded in the log file name.
    """
    os.makedirs('log', exist_ok=True)
    now = datetime.now()
    with open(f"log/progress_rt_{working_leo}_{now:%Y%m%d}.txt", 'a') as fid:
        fid.write(f"{now:%H:%M:%S} {msg}\n")

def gather_data(working_leo):
    """Gather RO and REF profiles for one LEO into per-occultation .npz tables.

    Pass 1 (RO): pairs every L1 file in ./L0_extract_{working_leo}/rocObs/
    with its unique L2-or-higher counterpart, converts the I/Q samples to
    phase [m] and SNR, and writes one table per occultation into
    t01_RO_{working_leo}/.  Failures on a single file pair are logged and
    skipped.

    Pass 2 (REF): loads daily podRx3 pickles covering the RO time span,
    merges and deduplicates the per-GNSS-satellite arcs, splits them at gaps
    longer than 3 s, and writes the segments into t01_REF_{working_leo}/.

    Parameters
    ----------
    working_leo : str
        LEO identifier embedded in the input/output directory names.
    """
    c0 = 299792458.0  # speed of light [m/s]
    T0 = datetime(1980,1,6)  # GPS time epoch (week 0, second 0)
    ant_map = {'PRIMARY': 0, 'SETTING': 1, 'RISING': 2}  # antenna id -> code in output filename

    log_message('- Start gathering RO profiles.', working_leo)
    # Overall time span of the processed profiles; initialized to impossible
    # extremes so the first profile always updates them.
    tStart_min = datetime.now() + timedelta(days=1)
    tEnd_max = T0
    l1files = np.array(glob(f'./L0_extract_{working_leo}/rocObs/*_L1*.nc'))
    l1files.sort()
    # Find the matching L2..L9 band file for each L1 file by rewriting the
    # band digit in the filename into a glob pattern; '' marks entries with
    # no (or an ambiguous) match.
    l2files = l1files.copy()
    for nn, l1file in enumerate(l1files):
        l2file = glob(sub(r'(.*_L)1.(_.*nc)',r'\g<1>[2-9]*\g<2>',l1file))
        if len(l2file)==1:
            l2files[nn] = l2file[0]
        else:
            l2files[nn] = ''
    # Keep only pairs where exactly one partner file was found.
    idx = l2files != ''
    l1files = l1files[idx]
    l2files = l2files[idx]
    sat_leos = []  # LEO ids seen in pass 1, reused below to select REF pickles
    for l1file, l2file in zip(l1files, l2files):
        try:
            # Parse timestamp, GNSS satellite, LEO id and band labels from the filenames.
            dt_now = datetime.strptime(sub(r'.*_([-\d]+)T([-\d]+)_.*',r'\1 \2',l1file),'%Y-%m-%d %H-%M-%S')
            sat_gnss = sub(r'.*_([A-Z]\d\d)_L.*',r'\1',l1file)
            sat_leo = sub(r'.*_(FM\d\d\d)_ant.*',r'\1',l1file)
            sat_leos.append(sat_leo)
            band_l1 = sub(r'.*_(L1.)_.*',r'\1',l1file)
            band_l2 = sub(r'.*_(L[2-9].)_.*',r'\1',l2file)
            with Dataset(l1file,'r') as fid1, Dataset(l2file,'r') as fid2:
                rate = fid1.getncattr('sampling_rate_hz')
                ant = ant_map[fid1.getncattr('virtual_antenna_id')]
                # Reference epoch of the time axis: GPS week / second-of-week,
                # plus 'fos' (presumably a fractional-second offset -- TODO confirm).
                ref_wk = fid1['time'].getncattr('ref_gps_week')
                ref_sow = fid1['time'].getncattr('ref_gps_sow')
                ref_fos = fid1['time'].getncattr('ref_gps_fos')
                ref_sec = ref_wk * 7 * 86400 + ref_sow  # GPS seconds since T0
                time1 = fid1['time'][...]
                time2 = fid2['time'][...]
                # Both bands must share the same time axis (within 0.1 ms).
                if len(time1) == len(time2) and np.all(np.abs(time1 - time2) < 1e-4):
                    time = time1
                else:
                    raise ValueError('L1 L2 times do not align')
                nf1 = fid1.getncattr('noise_floor')
                mdl_ph1 = fid1['model_phase'][...]
                i_L1 = fid1['i'][:,0].astype('float64')
                q_L1 = fid1['q'][:,0].astype('float64')
                nf2 = fid2.getncattr('noise_floor')
                mdl_ph2 = fid2['model_phase'][...]
                i_L2 = fid2['i'][:,0].astype('float64')
                q_L2 = fid2['q'][:,0].astype('float64')
            # SNR from the I/Q amplitude, scaled by the sampling rate and
            # normalized by the per-band noise floor.
            snr1 = np.sqrt(rate * (i_L1 ** 2 + q_L1 ** 2)) / nf1
            snr2 = np.sqrt(rate * (i_L2 ** 2 + q_L2 ** 2)) / nf2
            # Measured carrier phase angle of each sample [rad].
            ang1 = np.arctan2(q_L1, i_L1)
            ang2 = np.arctan2(q_L2, i_L2)
            f1, f2, _ = get_sat_freq(sat_gnss)
            # Convert model phase (presumably cycles -- TODO confirm) to meters
            # via c/f, then subtract the measured angle (converted the same way)
            # to form the observed phase.
            mdl_ph1_m = mdl_ph1 / f1 * c0
            ph1 = mdl_ph1_m - ang1 / (2 * np.pi) / f1 * c0
            mdl_ph2_m = mdl_ph2 / f2 * c0
            ph2 = mdl_ph2_m - ang2 / (2 * np.pi) / f2 * c0
            # Column labels; SNR column names are the L band names with S.
            header = ['GPS_seconds', 'time_offset', band_l1, band_l2, band_l1 + '(M)',
                      sub('L', 'S', band_l1), sub('L', 'S', band_l2)]
            data = np.column_stack([np.full_like(time1, ref_sec), time1 + ref_fos, ph1, ph2, mdl_ph1_m, snr1, snr2])
            # Absolute start/end of this occultation from the first/last rows.
            tStart, tEnd = [T0 + timedelta(seconds=data[0, 0] + data[x, 1]) for x in [0, -1]]
            if tStart < tStart_min:
                tStart_min = tStart
            if tEnd > tEnd_max:
                tEnd_max = tEnd
            occFileName = f"t01_RO_{working_leo}/occTab_{sat_leo}.{dt_now:%Y.%j}.{tStart:%Y%m%d%H%M%S}.{tEnd:%Y%m%d%H%M%S}.A{ant:02d}.{sat_gnss}.npz"
            np.savez(occFileName, data=data, header=header)
        except Exception as e:
            # Best-effort processing: log the error plus traceback locations
            # and continue with the next file pair.
            log_message(f'-- Error detected for {os.path.basename(l1file)}<>{os.path.basename(l2file)}', working_leo)
            log_message(f'---  {type(e).__name__} >> {str(e)}', working_leo)
            for tb in extract_tb(e.__traceback__):
                log_message(f"---  line {tb.lineno} of {tb.filename}", working_leo)
    sat_leos = np.unique(sat_leos)
    log_message(f"-- Processed {len(l1files)} profiles", working_leo)

    log_message('- Start gathering REF profiles.', working_leo)
    import pickle  # NOTE: local import; only this pass uses pickle
    # Daily grid covering the RO span, padded by one day on each side
    # (timetuple()[:3] truncates each bound to midnight).
    sat_dts = np.arange(datetime(*tStart_min.timetuple()[:3]) - timedelta(days=1),
                        datetime(*tEnd_max.timetuple()[:3]) + timedelta(days=2), timedelta(days=1)).astype(datetime)
    header = ['GPS_seconds', 'time_offset', 'L1C', 'L2X', 'S1C', 'S2X', 'C1C', 'C2X']
    # navfiles = glob(f'./L0_extract_{working_leo}/navObs/*nc')
    # dts = np.unique([sub(r'.*_([-\d]+)T.*',r'\1',navfile) for navfile in navfiles])
    # dts = np.array([datetime.strptime(dt, '%Y-%m-%d') for dt in dts])
    # leos = np.unique([sub(r'.*_(FM\d\d\d)\.nc',r'\1',navfile) for navfile in navfiles])

    for leo in sat_leos:
        # Accumulate per-GNSS-satellite arcs across all daily pickles.
        dsat = []     # GNSS satellite ids
        dtime = []    # per-satellite epoch arrays
        dataall = []  # per-satellite observation arrays
        for dt in sat_dts:
            crxfileout = f's01_podRx3/podRx3_{leo}.{dt:%Y.%j}.00.pkl'
            if not os.path.exists(crxfileout):
                continue
            with open(crxfileout, 'rb') as fid:
                pkldata = pickle.load(fid)
                dsat0 = pkldata['dsat']
                dtime0 = pkldata['dtime']
                dataall0 = pkldata['dataall']
            for nn0, ss0 in enumerate(dsat0):
                if ss0 in dsat:
                    # Append to the satellite's existing arc.
                    gnum = dsat.index(ss0)
                    dtime[gnum] = np.hstack([dtime[gnum], dtime0[nn0]])
                    dataall[gnum] = np.vstack([dataall[gnum], dataall0[nn0]])
                else:
                    dsat.append(ss0)
                    dtime.append(dtime0[nn0])
                    dataall.append(dataall0[nn0])
        # Reorder all three parallel lists by satellite id.
        didx = np.argsort(dsat)
        dsat = [dsat[x] for x in didx]
        dtime = [dtime[x] for x in didx]
        dataall = [dataall[x] for x in didx]
        for dnum in range(len(dsat)):
            # Deduplicate epochs (np.unique also sorts), then drop rows
            # containing any non-finite value.
            dtime[dnum], idx = np.unique(dtime[dnum], return_index=True)
            dataall[dnum] = dataall[dnum][idx]
            f1,f2,_ = get_sat_freq(dsat[dnum])
            idx = np.all(np.isfinite(dataall[dnum]), axis=1)
            dtime[dnum] = dtime[dnum][idx]
            dataall[dnum] = dataall[dnum][idx]
            # Segment boundaries wherever the time gap exceeds 3 s.
            idx = [0] + list(np.where(np.diff(dtime[dnum]) > timedelta(seconds=3))[0] + 1) + [len(dtime[dnum])]
            for ss in range(len(idx) - 1):
                # Skip short segments (< 60 samples) and segments that lie
                # entirely outside the RO time span found in pass 1.
                if idx[ss + 1] - idx[ss] < 60 or dtime[dnum][idx[ss + 1] - 1] < tStart_min or dtime[dnum][idx[ss]] > tEnd_max:
                    continue
                gps_sec = (dtime[dnum][idx[ss]] - T0) / timedelta(seconds=1)
                tdiff = (dtime[dnum][idx[ss]:idx[ss + 1]] - dtime[dnum][idx[ss]]) / timedelta(seconds=1)
                # First two observation columns are scaled by c/f (presumably
                # carrier phase, cycles -> meters -- TODO confirm); the
                # remaining four columns pass through unchanged.
                data = np.column_stack((
                    gps_sec + np.zeros_like(tdiff),
                    tdiff,
                    dataall[dnum][idx[ss]:idx[ss+1]] * np.array([c0 / f1, c0 / f2, 1, 1, 1, 1])
                )).astype('float64')
                tStart, tEnd = [T0 + timedelta(seconds=x[0] + x[1]) for x in data[[0, -1], :2]]
                occFileName = f't01_REF_{working_leo}/occTab_{leo}.{tStart:%Y.%j.%Y%m%d%H%M%S}.{tEnd:%Y%m%d%H%M%S}.A00.{dsat[dnum]}.npz'
                np.savez(occFileName, data=data, header=header)

    log_message('- Completed gathering profiles.', working_leo)

def gather_ucarsp3(working_leo):
    """Parse UCAR/Spire sp3 orbit files into per-LEO leoOrb .npz archives.

    Scans /data3/tcliu/gps_pod/spire_sp3 for sp3 files, infers the LEO id
    (FMxxx) from each filename, parses epoch ('*'), position ('P') and
    velocity ('V') records, and merges the rows into
    ./u01_leoOrb/leoOrb_{FM}.{year}.{doy}.npz.  Epochs already present in
    an existing archive are replaced by the newly parsed values.

    Fix vs. original: each sp3 file is opened with a context manager so the
    handle is closed deterministically (the original relied on the garbage
    collector to close ``open(orbfile)``).

    Parameters
    ----------
    working_leo : str
        LEO identifier; used here only for the progress log.
    """
    orbfiles = glob('/data3/tcliu/gps_pod/spire_sp3/*sp3')
    # LEO ids from the 3-digit field in the sp3 filenames -> 'FMxxx'.
    leo_sats = np.unique([sub(r'.*_\d{4}\.\d{3}\.(\d\d\d)\.\d\d_.*', r'FM\1', x) for x in orbfiles])
    for ff in leo_sats:
        log_message(f"-- Processing {ff}...", working_leo)
        orbfiles = glob(f'/data3/tcliu/gps_pod/spire_sp3/*.{ff[2:]}.*_sp3')
        orbfiles.sort()
        # The earliest file's date names the output archive.
        dt = datetime.strptime(sub(r'.*_(\d{4}\.\d{3})\....\.\d\d_.*', r'\1', orbfiles[0]), '%Y.%j')
        orbfileout = f'./u01_leoOrb/leoOrb_{ff}.{dt:%Y.%j}.npz'
        dts0 = []
        data0 = []
        # One row per epoch: 4 values from the P record and 4 from the V
        # record (presumably x, y, z, clk, vx, vy, vz, clk_rate -- TODO confirm).
        tmp0 = np.full((8,), np.nan)
        for orbfile in orbfiles:
            newfile = True
            with open(orbfile, 'r') as fin:
                for line in fin:
                    if line.startswith("*"):
                        # New epoch header: flush the previous epoch's row.
                        # NOTE(review): tmp0 is reset to NaN only at the first
                        # epoch of each file, so a later epoch missing its P/V
                        # record silently inherits the previous values --
                        # confirm this carry-over is intended.
                        if not newfile:
                            data0.append(tmp0.copy())
                        else:
                            newfile = False
                            tmp0[:] = np.nan
                        dts0.append(to_datetime(line[1:].split()))
                    elif line.startswith("P"):
                        # Position record: four 14-character fixed-width fields.
                        tmp0[:4] = np.array([line[x:x+14] for x in range(4,60,14)]).astype('float')
                    elif line.startswith("V"):
                        # Velocity record: four 14-character fixed-width fields.
                        tmp0[4:] = np.array([line[x:x+14] for x in range(4,60,14)]).astype('float')
            # Flush the final epoch of this file.
            data0.append(tmp0.copy())
        # Deduplicate epochs; np.unique also sorts them chronologically.
        dts0, idx = np.unique(np.array(dts0), return_index=True)
        data0 = np.array(data0)[idx]
        if os.path.isfile(orbfileout):
            # Merge into the existing archive: keep old rows only for epochs
            # not re-parsed this run, then re-sort everything by time.
            with np.load(orbfileout, allow_pickle=True) as fid:
                data = fid['data']
                dts = fid['dts']
            idx = ~np.isin(dts, dts0)
            dts = np.hstack([dts[idx], dts0])
            data = np.vstack([data[idx], data0])
            idx = np.argsort(dts)
            dts = dts[idx]
            data = data[idx]
        else:
            dts = dts0
            data = data0
        np.savez(orbfileout, data=data, dts=dts)
