import numpy as np

def precess(ttt):
    """Return the IAU-76/FK5 precession matrix for every epoch in ``ttt``.

    Parameters
    ----------
    ttt : ndarray
        Julian centuries since MJD 51544.5 (J2000) on the TT-like scale the
        callers in this module build. Must expose ``.size`` (a NumPy array or
        NumPy scalar).

    Returns
    -------
    ndarray, shape (3, 3, ttt.size)
        ``out[:, :, k]`` is the rotation built from the precession angles
        zeta, theta and z at epoch k. Each angle is a cubic polynomial in
        ``ttt`` with coefficients in arcsec, converted to radians.
    """
    to_rad = np.pi / 180 / 3600  # arcsec -> rad
    zeta = ((0.017998 * ttt + 0.30188) * ttt + 2306.2181) * ttt * to_rad
    theta = ((-0.041833 * ttt - 0.42665) * ttt + 2004.3109) * ttt * to_rad
    z = ((0.018203 * ttt + 1.09468) * ttt + 2306.2181) * ttt * to_rad

    c_zeta, s_zeta = np.cos(zeta), np.sin(zeta)
    c_th, s_th = np.cos(theta), np.sin(theta)
    c_z, s_z = np.cos(z), np.sin(z)

    rows = (
        (c_zeta * c_th * c_z - s_zeta * s_z,
         c_zeta * c_th * s_z + s_zeta * c_z,
         c_zeta * s_th),
        (-s_zeta * c_th * c_z - c_zeta * s_z,
         -s_zeta * c_th * s_z + c_zeta * c_z,
         -s_zeta * s_th),
        (-s_th * c_z,
         -s_th * s_z,
         c_th),
    )

    # Pre-fill with NaN so any entry that is not assigned is obvious.
    out = np.full((3, 3, ttt.size), np.nan)
    for i, row in enumerate(rows):
        for j, value in enumerate(row):
            out[i, j] = value
    return out

def fundarg(ttt):
    """Return the five IAU-1980 fundamental arguments, in radians.

    Each argument is ``((c3*t + c2)*t + c1)*t / 3600 + c0``. The c3..c1
    coefficients are in arcsec and c0 is in degrees. The sum is reduced
    modulo 360 degrees and then converted to radians. The names follow the
    usual IAU-1980 (Delaunay-argument) convention.

    Parameters
    ----------
    ttt : float or ndarray
        Julian centuries since J2000, as supplied by the callers in this module.

    Returns
    -------
    tuple
        ``(l, lp, f, d, o)``, each with the same shape as ``ttt``.
    """
    deg_to_rad = np.pi / 180
    # (c3, c2, c1) in arcsec and c0 in degrees, one row per argument.
    coefficients = (
        (0.064, 31.31, 1717915922.633, 134.96298139),     # l
        (-0.012, -0.577, 129596581.224, 357.52772333),    # lp
        (0.011, -13.257, 1739527263.137, 93.27191028),    # f
        (0.019, -6.891, 1602961601.328, 297.85036306),    # d
        (0.008, 7.455, -6962890.539, 125.04452222),       # o
    )
    return tuple(
        (((c3 * ttt + c2) * ttt + c1) * ttt / 3600 + c0) % 360 * deg_to_rad
        for c3, c2, c1, c0 in coefficients
    )

def nutataion(ttt, dpsi_r, depsilon_r, gst):
    """Build the IAU-1980 nutation matrix and the apparent-sidereal-time rotation.

    The (misspelled) name is kept because other code calls it by this name.

    Parameters
    ----------
    ttt : array_like
        Julian centuries since J2000 on the TT-like scale built by the callers.
        Scalars are promoted to shape (1,) so that the series sums below
        broadcast per epoch instead of across series terms.
    dpsi_r, depsilon_r : array_like
        Corrections (rad) added to the series values of the nutation in
        longitude (delta-psi) and in obliquity (delta-eps).
    gst : array_like
        Greenwich mean sidereal time, rad.

    Returns
    -------
    nut : ndarray, shape (3, 3, N)
        Nutation matrix per epoch, built from the mean obliquity, the true
        obliquity (mean + delta-eps) and delta-psi.
    st : ndarray, shape (3, 3, N)
        Rotation about the z-axis by the apparent sidereal time
        ``gst + deltapsi*cos(meaneps)`` plus two small ``sin(o)``/``sin(2o)``
        terms.

    Raises
    ------
    FileNotFoundError
        If ``../iers/nut80.npz`` does not exist relative to the current
        working directory.
    """
    # With a 0-d ttt the (terms,)-shaped sin/cos below would broadcast against
    # the (terms, 1) coefficients into a (terms, terms) matrix; keep one column per epoch.
    ttt = np.atleast_1d(ttt)
    deg2rad = np.pi / 180
    arcsec2rad = deg2rad / 3600

    l, lp, f, d, o = fundarg(ttt)
    meaneps = (((0.001813 * ttt - 0.00059) * ttt - 46.8150) * ttt + 84381.448) / 3600 % 360 * deg2rad

    # Series table: integer multipliers of the fundamental arguments (iar80)
    # and coefficients given in units of 0.0001 arcsec (rar80), converted to rad here.
    with np.load('../iers/nut80.npz') as fid:
        iar80 = fid['iar80']
        rar80 = fid['rar80'] * (0.0001 * np.pi / (180*3600))

    # Argument of every series term at every epoch (terms x N), shared by both sums.
    term_args = np.matmul(iar80, [l, lp, f, d, o])
    deltapsi = np.sum(np.sin(term_args) * (rar80[:, [0]] + rar80[:, [1]] * ttt), axis=0) + dpsi_r
    deltaeps = np.sum(np.cos(term_args) * (rar80[:, [2]] + rar80[:, [3]] * ttt), axis=0) + depsilon_r
    trueeps = meaneps + deltaeps

    cospsi = np.cos(deltapsi)
    sinpsi = np.sin(deltapsi)
    coseps = np.cos(meaneps)
    sineps = np.sin(meaneps)
    costrueeps = np.cos(trueeps)
    sintrueeps = np.sin(trueeps)

    nut = np.full((3, 3, ttt.size), np.nan)
    nut[0, 0] = cospsi
    nut[0, 1] = costrueeps * sinpsi
    nut[0, 2] = sintrueeps * sinpsi
    nut[1, 0] = -coseps * sinpsi
    nut[1, 1] = costrueeps * coseps * cospsi + sintrueeps * sineps
    nut[1, 2] = sintrueeps * coseps * cospsi - sineps * costrueeps
    nut[2, 0] = -sineps * sinpsi
    nut[2, 1] = costrueeps * sineps * cospsi - sintrueeps * coseps
    nut[2, 2] = sintrueeps * sineps * cospsi + costrueeps * coseps

    # Apparent sidereal time; the two small terms have coefficients in arcsec.
    ast = gst + deltapsi * coseps + 0.00264 * arcsec2rad * np.sin(o) + 0.000063 * arcsec2rad * np.sin(2 * o)
    cosast = np.cos(ast)
    sinast = np.sin(ast)

    st = np.full((3, 3, ttt.size), np.nan)
    st[0, 0] = cosast
    st[0, 1] = -sinast
    st[0, 2] = 0
    st[1, 0] = sinast
    st[1, 1] = cosast
    st[1, 2] = 0
    st[2, 0] = 0
    st[2, 1] = 0
    st[2, 2] = 1
    return nut, st

def polarm(xx, yy, ttt):
    """Build the IAU-1980 polar-motion matrix for each epoch.

    Parameters
    ----------
    xx, yy : array_like
        Pole coordinates in radians. The callers pass IERS arcsec values
        already converted to rad.
    ttt : ndarray
        Only ``ttt.size`` is used, to size the output.

    Returns
    -------
    ndarray, shape (3, 3, ttt.size)
        Per-epoch rotation matrix built from cos/sin of ``xx`` and ``yy``.
    """
    cos_x, sin_x = np.cos(xx), np.sin(xx)
    cos_y, sin_y = np.cos(yy), np.sin(yy)

    entries = {
        (0, 0): cos_x,
        (0, 1): 0,
        (0, 2): -sin_x,
        (1, 0): sin_x * sin_y,
        (1, 1): cos_y,
        (1, 2): cos_x * sin_y,
        (2, 0): sin_x * cos_y,
        (2, 1): -sin_y,
        (2, 2): cos_x * cos_y,
    }

    matrix = np.full((3, 3, ttt.size), np.nan)
    for (row, col), value in entries.items():
        matrix[row, col] = value
    return matrix

def _rotate(m, v):
    """Return ``m[:, :, k] @ v[k]`` for every k as an (N, 3) array.

    ``m`` is a (3, 3, N) stack of matrices (as built by ``precess``,
    ``nutataion`` and ``polarm``) and ``v`` is an (N, 3) array of row vectors.
    """
    return np.einsum('ijk,kj->ki', m, v)


def _rotate_t(m, v):
    """Return ``m[:, :, k].T @ v[k]`` for every k, i.e. apply the transposed matrices."""
    return np.einsum('jik,kj->ki', m, v)


def _fk5_rotations(dts_mjd):
    """Return the FK5 rotation stacks and Earth-rotation vectors at ``dts_mjd``.

    Earth-orientation parameters are cubic-spline interpolated from the local
    IERS files. ``../iers/finals.daily`` is tried first. If the first epoch
    precedes that table, ``../iers/finals.daily.extended`` is used instead.
    Both paths are relative to the current working directory.

    Parameters
    ----------
    dts_mjd : ndarray, shape (N,)
        Epochs as MJD on the (UTC) scale prepared by the callers.

    Returns
    -------
    prec, nut, st, pm : ndarray, shape (3, 3, N)
        Precession, nutation, sidereal-time and polar-motion matrices.
    omegaearth : ndarray, shape (N, 3)
        Earth angular-velocity vector ``[0, 0, thetasa]`` in rad/s.
    """
    from astropy.utils.iers import IERS_A as iers_a
    from scipy.interpolate import CubicSpline

    sec2rad = np.pi / 180 / 3600
    msec2rad = sec2rad / 1000

    iers_table = iers_a.open('../iers/finals.daily')
    if dts_mjd[0] < iers_table['MJD'][0].value:
        iers_table = iers_a.open('../iers/finals.daily.extended')
    mjd = iers_table['MJD']
    mjd_ok = np.isfinite(mjd)

    def interp(column):
        # Spline only over rows where both MJD and the column value are finite.
        ok = np.logical_and(mjd_ok, np.isfinite(iers_table[column]))
        return CubicSpline(mjd[ok], iers_table[column][ok])(dts_mjd)

    ut1utc = interp('UT1_UTC')              # s
    xp = interp('PM_x_A')                   # arcsec
    yp = interp('PM_y_A')                   # arcsec
    # NOTE(review): these feed the IAU-1980 dPsi/dEps corrections; that is only
    # right if the local "finals.daily" is the IAU-1980 variant (astropy labels
    # those columns dX/dY). Confirm which file variant is deployed.
    dpsi = interp('dX_2000A_A')             # mas
    depsilon = interp('dY_2000A_A')         # mas
    lod = interp('LOD_A') / 1000            # excess length of day, ms -> s

    dts_tut1 = (dts_mjd + ut1utc / 86400 - 51544.5) / 36525
    dts_gst = (((-6.2E-6 * dts_tut1 + 0.093104) * dts_tut1 + 3164400184.812866) * dts_tut1 + 67310.54841) * (
                np.pi / 43200) % (2 * np.pi)
    # NOTE(review): TT - UTC = 69.184 s (37 leap s + 32.184 s) is hard-coded;
    # only valid for epochs after the 2017-01-01 leap second.
    dts_ttt = (dts_mjd + 69.184 / 86400 - 51544.5) / 36525

    prec = precess(dts_ttt)
    nut, st = nutataion(dts_ttt, dpsi * msec2rad, depsilon * msec2rad, dts_gst)
    pm = polarm(xp * sec2rad, yp * sec2rad, dts_ttt)

    # Nominal Earth rotation rate scaled by the excess length of day.
    # The previous code used (UT1-UTC - 37 s) here, which is UT1-TAI and not LOD.
    thetasa = 7.29211514670698E-5 * (1 - lod / 86400)
    omegaearth = np.outer(thetasa, [0.0, 0.0, 1.0])
    return prec, nut, st, pm, omegaearth


def ro_ecef2eci(pos, vel, dts):
    """Rotate Earth-fixed states into the FK5 inertial frame.

    Applies ``prec @ nut @ st @ pm`` per epoch. The Earth-rotation term
    ``omega x r`` is added to the velocity in the pseudo-Earth-fixed frame.

    Parameters
    ----------
    pos : ndarray, shape (N, 3) or (N, 4)
        Positions; presumably km, since outputs are scaled by 1E3 and must
        match ``vel`` after its dm/s -> km/s conversion (TODO confirm). An
        optional 4th column, in microseconds, is subtracted from the epochs.
        It is presumably a receiver clock offset (TODO confirm).
    vel : ndarray, shape (N, >=3)
        Velocities in dm/s.
    dts : array-like
        Epochs accepted by ``astropy.time.Time``. A fixed 18 s is removed to
        reach UTC; presumably the inputs are GPS time (TODO confirm).

    Returns
    -------
    pos_eci, vel_eci : ndarray, shape (N, 3)
        Inertial position and velocity, scaled by 1E3 (presumably m and m/s).
    """
    from astropy.time import Time, TimeDelta

    vel = vel/1E4  # dm/s -> km/s
    clock_days = pos[:, 3] / 1E6 / 86400 if pos.shape[-1] > 3 else 0
    dts_mjd = (Time(dts) - TimeDelta(18, format='sec')).mjd - clock_days  # leap seconds
    prec, nut, st, pm, omegaearth = _fk5_rotations(dts_mjd)

    pos_pef = _rotate(pm, pos[:, :3])
    pos_eci = _rotate(prec, _rotate(nut, _rotate(st, pos_pef))) * 1E3
    vel_pef = _rotate(pm, vel[:, :3])
    vel_eci = _rotate(prec, _rotate(nut, _rotate(st, vel_pef + np.cross(omegaearth, pos_pef)))) * 1E3
    return pos_eci, vel_eci


def ro_eci2ecef(pos, vel, dts):
    """Rotate FK5 inertial states into the Earth-fixed frame.

    Inverse of ``ro_ecef2eci`` without its unit conversions. Applies
    ``pm' @ st' @ nut' @ prec'`` per epoch, and removes ``omega x r_pef`` in the
    pseudo-Earth-fixed frame, after the inertial velocity has been rotated there.

    Parameters
    ----------
    pos, vel : ndarray, shape (N, 3)
        Inertial position and velocity, in mutually consistent units
        (velocity = position unit per second).
    dts : array-like
        Epochs accepted by ``astropy.time.Time``. A fixed 18 s is removed to
        reach UTC; presumably the inputs are GPS time (TODO confirm).

    Returns
    -------
    pos_ecef, vel_ecef : ndarray, shape (N, 3)
        Earth-fixed position and velocity, in the input units.
    """
    from astropy.time import Time, TimeDelta

    dts_mjd = (Time(dts) - TimeDelta(18, format='sec')).mjd  # leap seconds
    prec, nut, st, pm, omegaearth = _fk5_rotations(dts_mjd)

    pos_pef = _rotate_t(st, _rotate_t(nut, _rotate_t(prec, pos)))
    pos_ecef = _rotate_t(pm, pos_pef)
    # The transport term lives in the PEF frame: rotate the inertial velocity
    # first, then subtract omega x r_pef (previously it was subtracted before
    # the nut'/st' rotations, i.e. in the wrong frame).
    vel_pef = _rotate_t(st, _rotate_t(nut, _rotate_t(prec, vel))) - np.cross(omegaearth, pos_pef)
    vel_ecef = _rotate_t(pm, vel_pef)
    return pos_ecef, vel_ecef