import os
import pickle
from datetime import datetime, timedelta
from glob import glob
from re import sub

import numpy as np
# from tools import get_sat_freq
from netCDF4 import Dataset

def log_message(msg, working_leo):
    """Append a timestamped line to today's progress log for *working_leo*.

    The log file is ``log/progress_rt_<working_leo>_<YYYYMMDD>.txt`` relative to
    the current working directory; each call appends one line ``HH:MM:SS <msg>``.
    The ``log`` directory is created if it does not exist yet.
    """
    # Take a single timestamp so the file's date and the line's time always agree,
    # even when the call straddles midnight.
    now = datetime.now()
    os.makedirs('log', exist_ok=True)
    with open(f"log/progress_rt_{working_leo}_{now:%Y%m%d}.txt", 'a', encoding='utf-8') as fid:
        fid.write(f"{now:%H:%M:%S} {msg}\n")

_C0 = 299792458.0  # speed of light [m/s]
_T0_GPS = datetime(1980, 1, 6)  # GPS time origin; file times are ref_gps_week/ref_gps_sow offsets
# Origin of attObs timestamps: 1970-01-01 shifted by 18 s.
# NOTE(review): presumably the 18 s GPS-UTC leap-second offset — confirm.
_T0_ATT = datetime(1970, 1, 1, 0, 0, 18)
# navObs 'status' codes for which an epoch is kept (both signals of a pair must match).
# NOTE(review): the meaning of these codes is not documented here — confirm against the product spec.
_VALID_STATUS = [21, 22, 53, 54]


def prepare_file(working_leo, ztype='all'):
    """Collect extracted L0 netCDF products into daily per-LEO files.

    Reads files under ``./L0_extract_<working_leo>/`` and writes/updates, per LEO
    (``FMxxx``, taken from the file names) and day:

    * ``ztype='navObs'`` -> ``s01_podRx3/podRx3_<FMxxx>.<YYYY.DDD>.00.pkl``
    * ``ztype='navSol'`` -> ``./s01_leoOrb/leoOrb_<FMxxx>.<YYYY.DDD>.npz``
    * ``ztype='attObs'`` -> ``./s01_leoAtt/leoAtt_<FMxxx>.<YYYY.DDD>.npz``

    ``'all'`` runs every step; ``'roeRef'`` is accepted but does nothing yet.
    New data for a day is merged into an existing file for that day, or, if none
    exists, seeded from the tail of the previous day's file.

    Parameters
    ----------
    working_leo : str
        Tag used in the input directory name and the log file name.
    ztype : str
        One of ``'navObs'``, ``'navSol'``, ``'attObs'``, ``'roeRef'``, ``'all'``.
    """
    if ztype in ('navObs', 'all'):
        _prepare_nav_obs(working_leo)
    if ztype in ('navSol', 'all'):
        _prepare_nav_sol(working_leo)
    if ztype in ('attObs', 'all'):
        _prepare_att_obs(working_leo)
    if ztype in ('roeRef', 'all'):
        pass  # roeRef gathering is not implemented
    log_message('- Completed preparing files.', working_leo)


def _file_days(files):
    """Return the unique ``YYYY-MM-DD`` dates in *files* (before the ``T``) as sorted datetimes."""
    stamps = np.unique([sub(r'.*_([-\d]+)T.*', r'\1', f) for f in files])
    return np.array([datetime.strptime(s, '%Y-%m-%d') for s in stamps])


def _file_leos(files):
    """Return the unique ``FMxxx`` satellite tags found at the end of *files*."""
    return np.unique([sub(r'.*_(FM\d\d\d)\.nc', r'\1', f) for f in files])


def _prepare_nav_obs(working_leo):
    """Gather navObs L1/L2 signal pairs into daily podRx3 pickles, one per LEO and day."""
    log_message('- Start gathering navObs>podRx3 profiles.', working_leo)
    navfiles = glob(f'./L0_extract_{working_leo}/navObs/*nc')
    days = _file_days(navfiles)
    leos = _file_leos(navfiles)
    # Days ascend in the outer loop, so a day processed in this run can be seeded
    # from the previous day's pickle written earlier in the same run.
    for dt in days:
        for leo in leos:
            crxfileout = f's01_podRx3/podRx3_{leo}.{dt:%Y.%j}.00.pkl'
            crxfileprev = f's01_podRx3/podRx3_{leo}.{dt - timedelta(days=1):%Y.%j}.00.pkl'
            dsat, dtime, dataall = _seed_podrx3(crxfileout, crxfileprev, dt, working_leo)

            day_files = sorted(glob(f'./L0_extract_{working_leo}/navObs/*_{dt:%Y-%m-%d}T*_{leo}.nc'))
            log_message(f'--- Processing {leo} @ {dt:%Y-%m-%d}: {len(day_files)} Files', working_leo)
            for navfile in day_files:
                for gsat, gtime, gtmp in _read_navobs_pairs(navfile):
                    if gsat in dsat:
                        gnum = dsat.index(gsat)
                        dtime[gnum] = np.hstack([dtime[gnum], gtime])
                        dataall[gnum] = np.vstack([dataall[gnum], gtmp])
                    else:
                        dsat.append(gsat)
                        dtime.append(gtime)
                        dataall.append(gtmp)

            # Sort satellites by label, then sort/deduplicate each satellite's epochs.
            order = np.argsort(dsat)
            dsat = [dsat[x] for x in order]
            dtime = [dtime[x] for x in order]
            dataall = [dataall[x] for x in order]
            for gnum in range(len(dsat)):
                dtime[gnum], first = np.unique(dtime[gnum], return_index=True)
                dataall[gnum] = dataall[gnum][first]

            log_message(f'--- Saving {crxfileout}', working_leo)
            os.makedirs(os.path.dirname(crxfileout), exist_ok=True)
            with open(crxfileout, 'wb') as fid:
                pickle.dump({'dsat': dsat, 'dtime': dtime, 'dataall': dataall}, fid)


def _seed_podrx3(crxfileout, crxfileprev, dt, working_leo):
    """Return the initial ``(dsat, dtime, dataall)`` lists for the podRx3 pickle of day *dt*.

    Uses the day's existing pickle if present. Otherwise it uses the previous day's
    pickle trimmed to the arcs that end after ``dt - 1 h``, or empty lists if neither exists.
    NOTE(review): pickle.load is only safe on trusted files; these are the
    pipeline's own outputs.
    """
    if os.path.exists(crxfileout):
        log_message(f'-- Combining {crxfileout}', working_leo)
        with open(crxfileout, 'rb') as fid:
            pkldata = pickle.load(fid)
        return pkldata['dsat'], pkldata['dtime'], pkldata['dataall']

    if not os.path.exists(crxfileprev):
        log_message(f'-- Creating {crxfileout}', working_leo)
        return [], [], []

    log_message(f'-- Retrieving {crxfileprev}', working_leo)
    with open(crxfileprev, 'rb') as fid:
        pkldata = pickle.load(fid)
    dsat, dtime, dataall = pkldata['dsat'], pkldata['dtime'], pkldata['dataall']
    cutoff = dt - timedelta(hours=1)
    for dnum in range(len(dsat)):
        # Arc starts: samples more than 3 s after their predecessor; _T0_GPS is
        # prepended so the first sample always starts an arc.
        gaps = np.diff(np.hstack([[_T0_GPS], dtime[dnum]]))
        tsidx = np.where(gaps > timedelta(seconds=3))[0].astype('int')
        # Arc ends: one before the next start; for the last arc 0 - 1 = -1 (final sample).
        teidx = np.hstack([tsidx[1:], tsidx[:1]]) - 1
        # Keep everything from the first arc that ends after the cutoff.
        tidx = tsidx[dtime[dnum][teidx] > cutoff]
        if tidx.size > 0:
            dtime[dnum] = dtime[dnum][tidx[0]:]
            dataall[dnum] = dataall[dnum][tidx[0]:, :]
        else:  # empty placeholder, removed below
            dtime[dnum] = dtime[dnum][:0]
            dataall[dnum] = dataall[dnum][:0]
    # Drop satellites whose data was trimmed away entirely.
    didx = np.where([tt.size > 0 for tt in dtime])[0]
    return ([dsat[x] for x in didx],
            [np.array(dtime[x]) for x in didx],
            [np.array(dataall[x]) for x in didx])


def _read_navobs_pairs(navfile):
    """Return ``(sat, times, data)`` tuples for each antenna-0 L1/L2 signal pair in *navfile*.

    ``sat`` is the first character of the signal-type flag meaning followed by the
    2-digit ``sv_id``. ``times`` is an array of datetimes. ``data`` is an (n, 6) array
    with columns [phase L1, phase L2, cn0 L1, cn0 L2, pseudorange L1, pseudorange L2].
    Only epochs where both signals have a status in ``_VALID_STATUS`` are kept, and
    pairs with no such epoch are skipped.
    """
    with Dataset(navfile, 'r') as fid:
        sig_var = fid['signal_type']
        sig_dict = dict(zip(sig_var.getncattr('flag_values'),
                            sig_var.getncattr('flag_meanings').split(' ')))
        sv_id = fid['sv_id'][...]
        # Keep only signal slots whose sv_id is not masked. sv_id itself is filtered too,
        # so the labels built below stay aligned with the data columns.
        idx = ~np.ma.getmaskarray(sv_id)
        sv_id = sv_id[idx]
        sgn_typ = sig_var[...][idx]
        ant_id = fid['virtual_antenna_id'][...][idx]
        status = fid['status'][...][:, idx]
        ref_sec = fid['time'].getncattr('ref_gps_week') * 7 * 86400 + fid['time'].getncattr('ref_gps_sow')
        data_c = fid['pseudorange'][...][:, idx]
        data_l = fid['phase'][...][:, idx]
        data_s = fid['cn0'][...][:, idx]
        ftime = fid['time'][...]

    gnss_band = np.array([f"{sig_dict[x][0]}{y:02d}-L{'1' if sig_dict[x][1] == '1' else '2'}"
                          for x, y in zip(sgn_typ, sv_id)])
    # Pair only signals on virtual antenna 0 (both members). A masked antenna id
    # counts as "not 0". After sorting, a satellite's L1 and L2 labels are adjacent.
    on_ant0 = np.flatnonzero(np.ma.filled(ant_id == 0, False))
    order = on_ant0[np.argsort(gnss_band[on_ant0])]
    gnss_band = gnss_band[order]
    status = status[:, order]
    data_c = data_c[:, order]
    data_l = data_l[:, order]
    data_s = data_s[:, order]

    pairs = []
    bidx = 0
    while bidx + 1 < len(gnss_band):
        first, second = gnss_band[bidx], gnss_band[bidx + 1]
        if first[-1] == '1' and second[-1] == '2' and first[:-1] == second[:-1]:
            idx_val = np.all(np.isin(status[:, bidx:bidx + 2], _VALID_STATUS), axis=1)
            if np.any(idx_val):
                gtime = np.array([_T0_GPS + timedelta(seconds=ref_sec + x) for x in ftime[idx_val]])
                gtmp = np.column_stack(
                    [data_l[idx_val, bidx], data_l[idx_val, bidx + 1],
                     data_s[idx_val, bidx], data_s[idx_val, bidx + 1],
                     data_c[idx_val, bidx], data_c[idx_val, bidx + 1]])
                pairs.append((first[:3], gtime, gtmp))
            bidx += 2
        else:
            bidx += 1
    return pairs


def _prepare_nav_sol(working_leo):
    """Gather navSol position/clock/velocity solutions into daily leoOrb npz files."""
    log_message('- Start gathering navSol>leoOrb profiles.', working_leo)
    orbfiles = glob(f'./L0_extract_{working_leo}/navSol/*nc')
    leo_sats = _file_leos(orbfiles)
    leo_dts = _file_days(orbfiles)
    for ff in leo_sats:
        for tt in leo_dts:
            log_message(f"-- Processing {ff} @ {tt:%Y-%m-%d}...", working_leo)
            day_files = sorted(glob(f'./L0_extract_{working_leo}/navSol/*_{tt:%Y-%m-%d}T*_{ff}.nc'))
            if not day_files:
                continue
            dts0 = np.full((0,), np.nan)
            data0 = np.full((0, 8), np.nan)  # columns: pos(3), clk, vel(3), clk rate
            for orbfile in day_files:
                with Dataset(orbfile, 'r') as fid:
                    ref_sec = (fid['time'].getncattr('ref_gps_week') * 7 * 86400
                               + fid['time'].getncattr('ref_gps_sow'))
                    dts0 = np.hstack([dts0, [_T0_GPS + timedelta(seconds=ref_sec + x) for x in fid['time'][...]]])
                    # NOTE(review): scalings look like m -> km, m -> us (via /c),
                    # m/s -> dm/s and m/s -> 1e-4 us/s (SP3-style units) — confirm input units.
                    pos = fid['position'][...] / 1E3
                    clk = fid['clock_error'][...] / _C0 * 1E6
                    vel = fid['velocity'][...] * 10
                    clkr = fid['clock_error_rate'][...] / _C0 * 1E10
                    data0 = np.vstack([data0, np.hstack([pos, clk[:, None], vel, clkr[:, None]])])
            _merge_daily_npz(f'./s01_leoOrb/leoOrb_{ff}.{tt:%Y.%j}.npz',
                             f'./s01_leoOrb/leoOrb_{ff}.{tt - timedelta(days=1):%Y.%j}.npz',
                             tt, dts0, data0)


def _prepare_att_obs(working_leo):
    """Gather attObs attitude quaternions ('qbo') into daily leoAtt npz files."""
    log_message('- Start gathering attObs>leoAtt profiles.', working_leo)
    attfiles = glob(f'./L0_extract_{working_leo}/attObs/*.nc')
    leo_sats = _file_leos(attfiles)
    leo_dts = _file_days(attfiles)
    vars_list = ['sca_x', 'sca_y', 'sca_z', 'sca_w']  # column names stored alongside the data
    for ff in leo_sats:
        for tt in leo_dts:
            log_message(f"-- Processing {ff} @ {tt:%Y-%m-%d}...", working_leo)
            day_files = sorted(glob(f'./L0_extract_{working_leo}/attObs/*_{tt:%Y-%m-%d}T*_{ff}.nc'))
            if not day_files:
                continue
            dts0 = np.full((0,), np.nan)
            data0 = np.full((0, 4), np.nan)
            for attfile in day_files:
                with Dataset(attfile, 'r') as fid:
                    dts0 = np.hstack([dts0, [_T0_ATT + timedelta(seconds=x)
                                             for x in fid['time'][...].astype('float64')]])
                    data0 = np.vstack([data0, fid['qbo'][...]])
            _merge_daily_npz(f'./s01_leoAtt/leoAtt_{ff}.{tt:%Y.%j}.npz',
                             f'./s01_leoAtt/leoAtt_{ff}.{tt-timedelta(days=1):%Y.%j}.npz',
                             tt, dts0, data0, vars=vars_list)


def _merge_daily_npz(fileout, fileprev, day, dts0, data0, **extra):
    """Merge new samples (*dts0*, *data0*) into the daily npz *fileout* and save it.

    Existing samples come from *fileout* if it exists. Otherwise they come from
    *fileprev*, limited to samples after ``day - 3 h``. If neither file exists,
    only the new samples are used. Old samples whose timestamps reappear in *dts0*
    are replaced by the new ones. The result is sorted by time and deduplicated,
    then written with keys ``data``, ``dts`` and any *extra* arrays.
    """
    if os.path.isfile(fileout):
        source = fileout
    elif os.path.isfile(fileprev):
        source = fileprev
    else:
        source = None

    if source is None:
        dts, data = dts0, data0
    else:
        with np.load(source, allow_pickle=True) as fid:
            data = fid['data']
            dts = fid['dts']
        keep = ~np.isin(dts, dts0)
        if source == fileprev:
            keep = np.logical_and(keep, dts > day - timedelta(hours=3))
        dts = np.hstack([dts[keep], dts0])
        data = np.vstack([data[keep], data0])

    dts, first = np.unique(dts, return_index=True)
    data = data[first]
    os.makedirs(os.path.dirname(fileout), exist_ok=True)
    np.savez(fileout, data=data, dts=dts, **extra)
