import numpy as np
from datetime import datetime, timedelta

def to_datetime(dtlist):
    """Build a ``datetime`` from ``[year, month, day, hour, minute, seconds]``.

    Fields may be strings or numbers (``read_GNSS`` passes the whitespace-split
    fields of an epoch line). ``seconds`` may be fractional and is rounded to
    the nearest microsecond. A value that rounds up to a whole second (e.g.
    ``59.9999999``) carries into the next second/minute/day instead of making
    ``datetime`` raise for ``microsecond=1000000``.
    """
    base = datetime(int(dtlist[0]), int(dtlist[1]), int(dtlist[2]),
                    int(dtlist[3]), int(dtlist[4]))
    # Adding the seconds as a timedelta lets rounding overflow carry naturally.
    return base + timedelta(microseconds=round(float(dtlist[5]) * 1E6))

def dt2float(dt, dt0):
    """Return the time elapsed from ``dt0`` to ``dt`` in seconds, as floats.

    ``dt`` may be a single ``datetime``, a list or object array of
    ``datetime`` objects, or a ``numpy.datetime64`` array. The result is a
    float ndarray with the same shape as ``dt`` (0-d for a scalar input).
    """
    # np.asarray lets plain lists and scalars take the same path as arrays.
    elapsed = np.asarray(dt) - dt0
    return np.asarray(elapsed / timedelta(seconds=1), dtype='float')

def gps_week(dt):
    """Return the number of whole weeks from 1980-01-06 00:00 to ``dt``.

    Floor division is used, so instants before that epoch give negative weeks.
    """
    since_epoch = dt - datetime(1980, 1, 6)
    return since_epoch // timedelta(weeks=1)

def disp_gnss(gfile, gsat, to_eci):
    """Plot position, velocity and clock bias of one satellite and save a PNG.

    Parameters
    ----------
    gfile : str
        Path to an ``.npz`` archive with ``Gdata``, ``Gdt`` and ``Gsats``
        arrays (presumably saved from ``read_GNSS`` output -- confirm). The
        name must contain a ``YYYY.DDD`` (year, day-of-year) tag. That date is
        the time origin of the x axis.
    gsat : str
        Satellite ID to plot. If it is not in ``Gsats`` the function returns
        without plotting.
    to_eci : bool
        If true, convert position/velocity with
        ``frame_conversions.ro_ecef2eci`` before plotting.

    The figure is saved as ``<file name without .npz>.png`` in the current
    working directory and then closed.
    """
    from pathlib import Path
    from re import sub
    from matplotlib import pyplot as plt

    dt0 = datetime.strptime(sub(r'.*(\d{4}\.\d{3}).*', r'\1', gfile), '%Y.%j')
    with np.load(gfile, allow_pickle=True) as fid:
        Gdata = fid['Gdata']
        Gdt = fid['Gdt']
        Gsats = fid['Gsats']
    if gsat not in Gsats:
        return
    idx = np.where(Gsats == gsat)[0][0]
    Gdata = Gdata[:, idx, :]
    if to_eci:
        from frame_conversions import ro_ecef2eci
        # NOTE(review): positions and velocities are passed in the stored units
        # (the non-ECI branch treats them as km and dm/s). The /1E3 below then
        # treats the output as m and m/s. Confirm the units ro_ecef2eci expects.
        pos, vel = ro_ecef2eci(Gdata[:, :3], Gdata[:, 4:7], Gdt)
        pos /= 1E3
        vel /= 1E3
    else:
        pos = Gdata[:, :3]
        vel = Gdata[:, 4:7] / 1E4  # dm/s -> km/s
    hours = dt2float(Gdt, dt0) / 3600
    fig, axs = plt.subplots(3, 1, figsize=(12, 9))
    axs[0].plot(hours, pos, '.', ms=0.5)
    axs[1].plot(hours, vel, '.', ms=0.5)
    axs[2].plot(hours, Gdata[:, 3], '.', ms=0.5)
    for ax in axs:
        ax.grid(True)
        # Fixed 72-hour span starting at dt0.
        ax.set_xlim((0, 72))
        ax.set_xticks(range(0, 75, 12))
    axs[0].set_title(gsat)
    axs[0].set_ylabel('Position (km)')
    axs[1].set_ylabel('Velocity (km/s)')
    axs[2].set_ylabel('Clock Bias (us)')
    axs[2].set_xlabel(f'Hours of {dt0:%Y-%m-%d}')
    # Path handles names with or without a directory part. The old regex left
    # bare names unchanged, which made savefig target the .npz name.
    fig.savefig(Path(gfile).with_suffix('.png').name, dpi=300)
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive

def read_GNSS(file):
    """Parse an SP3-style orbit/clock text file into per-epoch arrays.

    Before the first epoch, lines starting with ``'+ '`` list satellite IDs in
    3-character slots from column 9. Slots whose first character is blank are
    unused. Each line starting with ``'*'`` opens a new epoch whose time is
    parsed by ``to_datetime``. ``P`` and ``V`` records contain four 14-character
    numeric fields from column 4. They fill columns 0-3 and 4-7 of the row for
    the satellite named in characters 1-3. Parsing stops at an ``EOF`` line
    or at the end of the file.

    Returns
    -------
    Gdata : ndarray, shape (n_epochs, n_sats, 8)
        Values from ``P`` (columns 0-3) and ``V`` (columns 4-7) records.
        Entries with no record in an epoch are NaN. Values within 1e-3 of 1e6
        (e.g. SP3's 999999.999999 "no value" sentinel) are also set to NaN.
    Gdt : ndarray of datetime, shape (n_epochs,)
        Epoch times.
    Gsats : ndarray of str, shape (n_sats,)
        Satellite IDs, in header order.
    """
    Gsats = []
    Gdata = []
    Gdt = []
    Gtmp = None  # current epoch's (n_sats, 8) buffer; None until the first '*'
    with open(file, 'r') as fid:
        for line in fid:
            if Gtmp is None:
                if line.startswith('+ '):
                    # Slicing (not indexing) tolerates short/unpadded lines.
                    Gsats += [line[x:x + 3] for x in range(9, 60, 3)
                              if line[x:x + 1].strip()]
                elif line.startswith('*'):
                    Gsats = np.array(Gsats)
                    Gtmp = np.full((len(Gsats), 8), np.nan)
                    Gdt.append(to_datetime(line[1:].split()))
            elif line.startswith('*'):
                # Close the previous epoch and start a fresh all-NaN buffer, so
                # satellites absent from this epoch do not inherit stale values.
                Gdata.append(Gtmp)
                Gtmp = np.full((len(Gsats), 8), np.nan)
                Gdt.append(to_datetime(line[1:].split()))
            elif line.startswith('P'):
                Gtmp[Gsats == line[1:4], :4] = np.array(
                    [line[x:x + 14] for x in range(4, 60, 14)]).astype('float')
            elif line.startswith('V'):
                Gtmp[Gsats == line[1:4], 4:] = np.array(
                    [line[x:x + 14] for x in range(4, 60, 14)]).astype('float')
            elif line.startswith('EOF'):
                break
    if Gtmp is not None:
        # Flush the final epoch whether or not an 'EOF' line was present.
        Gdata.append(Gtmp)
    Gsats = np.asarray(Gsats)
    Gdt = np.array(Gdt)
    Gdata = np.array(Gdata) if Gdata else np.empty((0, len(Gsats), 8))
    Gdata[np.abs(Gdata - 1E6) < 1E-3] = np.nan
    return Gdata, Gdt, Gsats

def lagrange(X, Y, x, order):
    """Local Lagrange interpolation of ``Y(X)`` and its first derivative at ``x``.

    For each query point ``xi``:

    * If ``xi`` equals a node, ``y`` is that node's value. ``dy`` is the
      derivative of the polynomial through up to ``2*order + 1`` nodes
      centred on it.
    * Otherwise ``y`` and ``dy`` come from the polynomial through ``2*order``
      nodes, ``order`` on each side of ``xi``.

    Near the ends of ``X`` the window is shifted inward to keep its width.
    Points outside ``[X[0], X[-1]]`` are therefore extrapolated from the end
    window. If ``len(X) < 2*order``, all nodes are used.

    Parameters
    ----------
    X : array_like
        Node abscissae. Flattened; must be sorted ascending for
        ``np.searchsorted``.
    Y : array_like
        Node values, shape ``(len(X),)`` or ``(len(X), m)``.
    x : array_like or scalar
        Query abscissae. A scalar is treated as a length-1 array.
    order : int
        Half-width of the interpolation window.

    Returns
    -------
    y, dy : ndarray
        Interpolated values and derivatives dY/dX, shape
        ``(len(x),) + Y.shape[1:]``.
    """
    X = np.ravel(X)
    Y = np.asarray(Y)
    x = np.atleast_1d(x)
    nX = X.size
    y = np.full((len(x),) + Y.shape[1:], np.nan)
    dy = np.full_like(y, np.nan)

    for ii, xi in enumerate(x):
        idx = np.searchsorted(X, xi)
        exact = idx < nX and X[idx] == xi
        # Window [idxL, idxR) is clamped against the node count nX so it keeps
        # its width at the ends. An exact hit gets one extra node on the right
        # so the window is centred on that node.
        idxL = max(min(idx - order, nX - 2 * order), 0)
        idxR = min(max(idx + order + (1 if exact else 0), 2 * order), nX)
        XJ = X[idxL:idxR]
        YJ = Y[idxL:idxR]
        # Basis denominators R_j = prod_{m != j} (X_j - X_m). They are built
        # for the local window only, not from an nX-by-nX difference matrix.
        D = XJ[:, None] - XJ[None, :]
        np.fill_diagonal(D, 1)
        R = np.prod(D, axis=1)
        pX0 = xi - XJ
        if exact:
            # At node k: L_k'(X_k) = sum_{m != k} 1 / (X_k - X_m), and for
            # j != k: L_j'(X_k) = prod_{m != j, k} (X_k - X_m) / R_j.
            y[ii] = Y[idx]
            off = pX0 != 0
            pX1 = pX0[off]
            # Float buffer so the in-place ops below also work for integer X.
            vX = np.full(pX0.shape, np.prod(pX1),
                         dtype=np.result_type(pX0.dtype, np.float64))
            vX[~off] *= np.sum(1 / pX1)
            vX[off] /= pX1
        else:
            # Numerators prod_{m != j} (xi - X_m), taken as one full product
            # divided by each factor, and their derivatives with respect to xi.
            pX = np.prod(pX0) / pX0
            vX = pX * (np.sum(1 / pX0) - 1 / pX0)
            y[ii] = (pX / R) @ YJ
        dy[ii] = (vX / R) @ YJ
    return y, dy