import sys
import numpy as np
import pandas as pd
from scipy.interpolate import CubicHermiteSpline, CubicSpline
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares
import matplotlib.pyplot as plt
from astropy.time import Time, TimeDelta
from astropy.coordinates import ITRS, GCRS, CartesianRepresentation, CartesianDifferential, get_body_barycentric_posvel
import astropy.units as u
from tools import dt2float
from datetime import datetime, timedelta

# ---------- User arguments ----------
# Command-line handling is currently disabled and the two inputs are hard-coded
# below. The original argument parsing is kept here for reference:
# if len(sys.argv) < 3:
#     print("Usage: python sp3_30s_compare.py <sp3_file> <sat_id (e.g. G01)>")
#     sys.exit(1)
#
# SP3_FILE = sys.argv[1]
# SAT_ID = sys.argv[2]  # e.g. 'G01'
SP3_FILE = 'temp/GRG0OPSULT_20253070000_02D_05M_ORB.SP3.txt'  # path of the SP3 orbit file to read
SAT_ID = 'G01'  # satellite id compared against characters 2-4 of each SP3 'P' record

# ---------- Parameters ----------
OUT_DT = 30.0            # output cadence in seconds
# The remaining parameters are referenced only by the commented-out
# reduced-dynamic fit (Method B) further down in this file.
SRP_INIT = 1e-3          # initial Cr*A/m (m^2/kg) guess
OBS_SIGMA_M = 0.05       # assumed SP3 position uncertainty (m) for residual weighting
INTEGRATOR_RK = 'DOP853' # integrator (method name handed to solve_ivp)
RTOL = 1e-9              # relative tolerance handed to solve_ivp
ATOL = 1e-11             # absolute tolerance handed to solve_ivp

# ---------- Utilities ----------
def _parse_sp3_position_record(line):
    """Parse one SP3 'P' record into ``(x_m, y_m, z_m, clk_s)``.

    The fixed-column layout is tried first (characters 4-18 X, 18-32 Y,
    32-46 Z, 46-60 clock). If any of those fields is not a valid float the
    record is re-read as whitespace-separated tokens. In both paths the
    coordinates are scaled from kilometres to metres and the clock from
    microseconds to seconds; a missing clock field yields 0.0.

    Raises:
        ValueError: if the record cannot be parsed by either strategy.
    """
    try:
        x_km = float(line[4:18])
        y_km = float(line[18:32])
        z_km = float(line[32:46])
        clk_field = line[46:60].strip()
        clk_us = float(clk_field) if clk_field else 0.0
    except ValueError:
        # The caller only passes lines that begin with 'P' + sat_id (no
        # separating blank), so the first token is the record tag and the
        # coordinates are tokens 1..3, followed by the clock.
        tokens = line.split()
        try:
            x_km, y_km, z_km = (float(tok) for tok in tokens[1:4])
            clk_us = float(tokens[4]) if len(tokens) > 4 else 0.0
        except ValueError as err:
            raise ValueError(f"Malformed SP3 position record: {line!r}") from err
    return x_km * 1000.0, y_km * 1000.0, z_km * 1000.0, clk_us * 1e-6


def read_sp3_pclk(sp3_path, sat_id):
    """Extract position and clock samples of one satellite from an SP3 file.

    Epoch header lines (starting with ``*``) set the current epoch; every
    position record whose characters 1-4 equal ``sat_id`` contributes one
    sample tagged with that epoch.

    NOTE(review): epochs are built with ``scale='utc'`` exactly as written in
    the file. SP3 products are typically referenced to GPS time, so confirm
    the time system before relying on these epochs for frame transformations.

    Args:
        sp3_path: path of the SP3 text file.
        sat_id: satellite identifier as it appears after the leading 'P'
            (e.g. ``'G01'``).

    Returns:
        Tuple ``(times, pos, clk)``: an Astropy ``Time`` array, an ``(N, 3)``
        array of positions in metres and an ``(N,)`` array of clock values
        in seconds.

    Raises:
        ValueError: if no record for ``sat_id`` is found, if a matching
            record appears before the first epoch header, or if a matching
            record cannot be parsed.
    """
    times = []
    pos = []
    clk = []
    current_epoch = None
    with open(sp3_path, 'r') as fh:
        for line in fh:
            if line.startswith('*'):
                # epoch line: "*  yyyy mm dd hh mm ss.sssssssss"
                fields = line.split()
                year, month, day, hour, minute = (int(tok) for tok in fields[1:6])
                second = float(fields[6])
                current_epoch = Time(
                    f"{year}-{month:02d}-{day:02d}T{hour:02d}:{minute:02d}:{second:012.9f}",
                    format='isot', scale='utc')
                continue
            if line.startswith('P') and line[1:4] == sat_id:
                if current_epoch is None:
                    # Without an epoch the sample cannot be time-tagged.
                    raise ValueError(
                        f"Position record for {sat_id} found before any epoch header in {sp3_path}")
                xs, ys, zs, clk_s = _parse_sp3_position_record(line)
                times.append(current_epoch)
                pos.append([xs, ys, zs])
                clk.append(clk_s)
    if not times:
        raise ValueError("No P records found for satellite " + sat_id)
    return Time(times), np.array(pos), np.array(clk)

def times_to_seconds(time_astropy, epoch_ref=None):
    """Express Astropy times as elapsed seconds relative to a reference epoch.

    Args:
        time_astropy: Astropy Time array to convert.
        epoch_ref: reference epoch; when None the first element of
            ``time_astropy`` is used.

    Returns:
        Tuple ``(seconds, reference)`` with the elapsed seconds as plain
        values and the reference epoch that was actually used.
    """
    reference = time_astropy[0] if epoch_ref is None else epoch_ref
    elapsed = time_astropy - reference
    return elapsed.to(u.s).value, reference

# ---------- Step 1: read SP3 ----------
print("Reading SP3:", SP3_FILE, "sat:", SAT_ID)
times_astro, pos_itrs_m, clk_s = read_sp3_pclk(SP3_FILE, SAT_ID)
N = len(times_astro)
print(f"Loaded {N} epochs from SP3 for {SAT_ID}")

# A 48 h arc at 5 min sampling is expected (about 577 samples); the only hard
# requirement enforced here is a minimum number of epochs.
if N < 10:
    raise ValueError("Too few epochs found: " + str(N))

# ---------- Step 2: compute velocities from positions (finite differences) ----------
# Seconds since the first epoch; t0_ast is that first epoch.
tsec_full, t0_ast = times_to_seconds(times_astro)
# np.gradient differentiates along the time axis using the actual sample times,
# so uneven spacing is handled.
vel_itrs_mps = np.gradient(pos_itrs_m, tsec_full, axis=0)

# ---------- Step 3: Hermite interpolation (Method A) ----------
# One cubic Hermite spline per coordinate, built from the SP3 positions and
# the finite-difference velocities. (A plain CubicSpline through the positions
# was the alternative tried here.)
print("Building Hermite splines for positions (Method A)")
hermite_splines = [
    CubicHermiteSpline(tsec_full, pos_itrs_m[:, axis], vel_itrs_mps[:, axis])
    for axis in range(3)
]
# Clock values get a natural cubic spline.
clk_spline = CubicSpline(tsec_full, clk_s, bc_type='natural')

# Dense output grid over the whole arc; only the last 24 h are kept as output.
t_end = tsec_full[-1]
t_start_last24 = t_end - 24*3600.0
t_out_full = np.arange(tsec_full[0], t_end+1e-6, OUT_DT)
t_out_last24_mask = t_out_full >= t_start_last24

pos_herm_full = np.column_stack([spl(t_out_full) for spl in hermite_splines])
vel_herm_full = np.column_stack([spl.derivative()(t_out_full) for spl in hermite_splines])
clk_herm_full = clk_spline(t_out_full)

# Restrict every series to the last 24 h.
t_out, pos_herm, vel_herm, clk_herm = (
    series[t_out_last24_mask]
    for series in (t_out_full, pos_herm_full, vel_herm_full, clk_herm_full)
)

print(f"Hermite output: {len(t_out)} samples for last 24h at {OUT_DT}s cadence.")

# ---------- Helper: transforms ITRS <-> GCRS using astropy ----------
def itrs_to_gcrs_positions(pos_itrs_m, times_astropy):
    """Transform ITRS positions to GCRS, one epoch at a time.

    Args:
        pos_itrs_m: array indexed as ``pos_itrs_m[i, :]`` giving the ITRS
            cartesian position (metres) for epoch ``i``.
        times_astropy: iterable of Astropy Time epochs; each is used as the
            ``obstime`` of both the source and the target frame.

    Returns:
        Array built from one 3-vector (metres, GCRS) per epoch.
    """
    def _single_epoch(row, epoch):
        earth_fixed = ITRS(CartesianRepresentation(row * u.m), obstime=epoch)
        inertial = earth_fixed.transform_to(GCRS(obstime=epoch))
        return inertial.cartesian.xyz.to(u.m).value.flatten()

    return np.array([
        _single_epoch(pos_itrs_m[k, :], epoch)
        for k, epoch in enumerate(times_astropy)
    ])

def gcrs_to_itrs_positions(pos_gcrs_m, times_astropy):
    """Transform GCRS positions to ITRS, one epoch at a time.

    Args:
        pos_gcrs_m: array indexed as ``pos_gcrs_m[i, :]`` giving the GCRS
            cartesian position (metres) for epoch ``i``.
        times_astropy: iterable of Astropy Time epochs; each is used as the
            ``obstime`` of both the source and the target frame.

    Returns:
        Array built from one 3-vector (metres, ITRS) per epoch.
    """
    def _single_epoch(row, epoch):
        inertial = GCRS(CartesianRepresentation(row * u.m), obstime=epoch)
        earth_fixed = inertial.transform_to(ITRS(obstime=epoch))
        return earth_fixed.cartesian.xyz.to(u.m).value.flatten()

    return np.array([
        _single_epoch(pos_gcrs_m[k, :], epoch)
        for k, epoch in enumerate(times_astropy)
    ])

# # ---------- Step 4: Reduced-dynamic fit (Method B) ----------
# # We'll transform SP3 ITRS positions at epochs into GCRS (inertial) frame and fit initial state there.
# print("Transforming SP3 positions to GCRS (inertial) for dynamic fit...")
# pos_gcrs_m = itrs_to_gcrs_positions(pos_itrs_m, times_astro)
#
# # Time array for integrator: seconds since t0
# t_eval = tsec_full  # seconds since t0
# t_span = (t_eval[0], t_eval[-1])
#
# # get Sun and Moon ephemerides using astropy helper per time, done inside dynamics
# # dynamics in GCRS frame (geocentric); we will use sun/moon positions in GCRS (geocentric)
# MU_E = 3.986005e14
# R_E = 6378136.3
# J2 = 1.0826359e-3
# MU_SUN = 1.32712440018e20
# MU_MOON = 4.902800066e12
# P_SRP_1AU = 4.56e-6
# AU = 1.495978707e11
#
# def get_sun_moon_gcrs(t_unix):
#     """Return r_sun_gcrs, r_moon_gcrs (meters) at astropy Time given by unix seconds since epoch_ref"""
#     t_ast = t0_ast + (t_unix * u.s)
#     # barycentric posvel for Sun and Moon
#     psun, vsun = get_body_barycentric_posvel('sun', t_ast)
#     pmoon, vmoon = get_body_barycentric_posvel('moon', t_ast)
#     # transform barycentric to GCRS by subtracting Earth's barycentric position at same epoch
#     # get Earth's barycentric:
#     pearth, vearth = get_body_barycentric_posvel('earth', t_ast)
#     # geocentric vector = body - earth
#     r_sun = (psun.xyz - pearth.xyz).to(u.m).value.flatten()
#     r_moon = (pmoon.xyz - pearth.xyz).to(u.m).value.flatten()
#     return r_sun, r_moon
#
# # store epoch reference
# t0_ast = times_astro[0]
#
# def accel_model(t_unix, r_vec, cram):
#     # r_vec in GCRS meters (geocentric)
#     rnorm = np.linalg.norm(r_vec)
#     a_central = -MU_E * r_vec / (rnorm**3)
#     # J2 perturbation
#     z2 = r_vec[2]**2
#     r2 = rnorm**2
#     fac = 1.5 * J2 * MU_E * (R_E**2) / (rnorm**5)
#     a_j2 = fac * r_vec * (5*z2/r2 - 1)
#     # third bodies
#     r_sun, r_moon = get_sun_moon_gcrs(t_unix)
#     a_sun = MU_SUN * ((r_vec - r_sun) / np.linalg.norm(r_vec - r_sun)**3 - (- r_sun) / np.linalg.norm(-r_sun)**3)
#     a_moon = MU_MOON * ((r_vec - r_moon) / np.linalg.norm(r_vec - r_moon)**3 - (- r_moon) / np.linalg.norm(-r_moon)**3)
#     # SRP: cannonball, acceleration away from sun direction
#     r_sat_to_sun = r_sun - r_vec
#     d = np.linalg.norm(r_sat_to_sun)
#     u = r_sat_to_sun / d
#     P = P_SRP_1AU * (AU / d)**2
#     a_srp = P * cram * u  # CR*A/m param in m^2/kg
#     return a_central + a_j2 + a_sun + a_moon + a_srp
#
# def dyn_ode(t, state, cram):
#     # t in seconds since t0_ast
#     r = state[0:3]; v = state[3:6]
#     a = accel_model(t, r, cram)
#     return np.hstack((v, a))
#
# # Residual function for least squares: parameters p = [x0(3), v0(3), cram]
# def residuals_dyn(p, t_obs, pos_obs_gcrs):
#     x0 = p[0:3]; v0 = p[3:6]; cram = p[6]
#     state0 = np.hstack((x0, v0))
#     try:
#         sol = solve_ivp(lambda tt, yy: dyn_ode(tt, yy, cram),
#                         (t_obs[0], t_obs[-1]),
#                         state0, t_eval=t_obs,
#                         method=INTEGRATOR_RK, rtol=RTOL, atol=ATOL)
#     except Exception as e:
#         print("Integrator error:", e)
#         return np.ones(pos_obs_gcrs.size) * 1e6
#     if not sol.success:
#         print("Integrator failed:", sol.message)
#         return np.ones(pos_obs_gcrs.size) * 1e6
#     pred_pos = sol.y.T[:, 0:3]
#     res = (pred_pos - pos_obs_gcrs).ravel() / OBS_SIGMA_M
#     return res
#
# # initial guess from SP3: convert first ITRS pos -> GCRS
# print("Preparing initial guess for dynamics (convert first epoch into GCRS)")
# first_itrs = pos_itrs_m[0]
# # transform to GCRS using astropy:
# itrs0 = ITRS(CartesianRepresentation(first_itrs * u.m), obstime=t0_ast)
# gcrs0 = itrs0.transform_to(GCRS(obstime=t0_ast))
# x0_gcrs = gcrs0.cartesian.xyz.to(u.m).value.flatten()
# # initial velocity from finite diff in ITRS -> convert approx to GCRS using differential transform
# # For simplicity, convert a nearby epoch positions to GCRS and use finite diff there:
# if len(times_astro) >= 3:
#     gcrs_all = pos_gcrs_m  # already computed above
#     # get velocity initial guess in GCRS by central diff at index 0 (forward diff)
#     if len(gcrs_all) >= 3:
#         v0_gcrs = (gcrs_all[1] - gcrs_all[0]) / (tsec_full[1] - tsec_full[0])
#     else:
#         v0_gcrs = np.array([0.0, 0.0, 0.0])
# else:
#     v0_gcrs = np.array([0.0, 0.0, 0.0])
#
# p0 = np.hstack((x0_gcrs, v0_gcrs, SRP_INIT))
# print("Initial parameter guess p0:", p0)
#
# # Run least squares fit
# print("Running least-squares fit of initial state + SRP scale (this may take a few minutes)...")
# res = least_squares(residuals_dyn, p0, args=(t_eval, pos_gcrs_m), verbose=2, xtol=1e-9, ftol=1e-9, gtol=1e-9, max_nfev=200)
# p_fit = res.x
# print("Fit completed. result summary:")
# print(res.message)
# print("Fitted params:", p_fit)
#
# # Integrate with fitted params to full t_out_full (we need 30s grid across last 24h; but integrate through entire span)
# # build t_out_full seconds since t0
# t_out_full = np.arange(tsec_full[0], tsec_full[-1]+1e-6, OUT_DT)
# print("Propagating fitted dynamical model to full output grid (this may take a minute)...")
# sol_fit = solve_ivp(lambda tt, yy: dyn_ode(tt, yy, p_fit[6]),
#                     (t_out_full[0], t_out_full[-1]),
#                     p_fit[0:6],
#                     t_eval=t_out_full,
#                     method=INTEGRATOR_RK, rtol=RTOL, atol=ATOL)
# if not sol_fit.success:
#     raise RuntimeError("Propagation failed: " + sol_fit.message)
# pos_dyn_gcrs_full = sol_fit.y.T[:, 0:3]
# vel_dyn_gcrs_full = sol_fit.y.T[:, 3:6]
#
# # Transform dynamic results back to ITRS for comparison with SP3 (at corresponding times)
# print("Transforming dynamic results (GCRS) back to ITRS for comparison...")
# times_out_ast = t0_ast + t_out_full * u.s
# pos_dyn_itrs_full = gcrs_to_itrs_positions(pos_dyn_gcrs_full, times_out_ast)
#
# # slice last 24h
# mask_last24 = t_out_full >= (t_out_full[-1] - 24*3600.0)
# t_out_24 = t_out_full[mask_last24]
# pos_dyn_24 = pos_dyn_itrs_full[mask_last24]
# vel_dyn_24 = vel_dyn_gcrs_full[mask_last24]  # velocities are in GCRS; approximate transform to ITRS derivative omitted for simplicity

# ---------- Clocks: 30s fit (apply same spline for both methods) ----------
# Natural cubic spline through the SP3 clock samples, evaluated on the dense grid.
clk_spline_full = CubicSpline(tsec_full, clk_s, bc_type='natural')
clk_out_full = clk_spline_full(t_out_full)
# clk_out_24 = clk_out_full[mask_last24]

# ---------- Comparison diagnostics ----------
# Residuals are evaluated at the SP3 epochs that fall inside the last 24 h.
# NOTE(review): the Hermite splines were built through these same SP3 samples,
# so this residual cannot expose interpolation error between epochs - confirm
# that this is the intended check.
idx_last24 = np.flatnonzero(tsec_full >= (tsec_full[-1] - 24*3600.0))
t_sp3_last24 = tsec_full[idx_last24]

# Hermite evaluated at those SP3 epochs, one column per coordinate.
pos_herm_at_sp3 = np.column_stack([spl(t_sp3_last24) for spl in hermite_splines])
# Dynamic (Method B, disabled): sol_fit was integrated on t_out_full, which includes the SP3 times
# idx_in_out = np.searchsorted(t_out_full, t_sp3_last24)
# pos_dyn_at_sp3 = pos_dyn_itrs_full[idx_in_out]

# Reference ("truth") positions straight from the SP3 file.
pos_truth_sp3 = pos_itrs_m[idx_last24]

# Component-wise residuals.
res_herm = pos_herm_at_sp3 - pos_truth_sp3
# res_dyn = pos_dyn_at_sp3 - pos_truth_sp3

# RMS of the 3D residual magnitude.
rms_herm = np.sqrt((res_herm**2).sum(axis=1).mean())
# rms_dyn = np.sqrt(np.mean(np.sum(res_dyn**2, axis=1)))

print("RMS residuals (last 24h at SP3 epochs):")
print(f" Hermite interpolation RMS (m): {rms_herm:.4f}")
# print(f" Dynamic fit RMS         RMS (m): {rms_dyn:.4f}")

# Also compute mean and std of component differences
def comp_stats(res):
    """Return the per-column mean and standard deviation of ``res``.

    Both statistics are taken along axis 0, so for an (N, 3) residual array
    each result has one entry per coordinate.
    """
    column_mean = np.mean(res, axis=0)
    column_std = np.std(res, axis=0)
    return column_mean, column_std

mean_herm, std_herm = comp_stats(res_herm)
# mean_dyn, std_dyn = comp_stats(res_dyn)
print("Hermite mean (XYZ m):", mean_herm, "std:", std_herm)
# print("Dynamic mean  (XYZ m):", mean_dyn, "std:", std_dyn)

# ---------- Plots ----------
# Three stacked panels saved to test.png:
#   1) Hermite vs. dynamic difference (empty while Method B is disabled)
#   2) Hermite residual magnitude at the SP3 epochs of the last 24 h
#   3) interpolated clock over the last 24 h
plt.figure(figsize=(12,8))

plt.subplot(311)
# plt.plot(t_out_24/3600.0, np.linalg.norm(pos_herm - pos_dyn_24, axis=1), label='|Hermite - Dynamic| (m)')
plt.ylabel('Difference (m)')
# No legend here: nothing labelled is drawn on this panel while the dynamic
# fit is disabled, and legend() on an empty axes only emits a warning.

plt.subplot(312)
plt.plot(t_sp3_last24/3600.0, np.linalg.norm(res_herm, axis=1), label='Hermite - SP3 (m)')
# plt.plot(t_sp3_last24/3600.0, np.linalg.norm(res_dyn, axis=1), label='Dynamic - SP3 (m)')
plt.legend(); plt.ylabel('Residual (m)')

plt.subplot(313)
# The clock series originally plotted here (t_out_24 / clk_out_24) is only
# defined by the disabled Method B code; the natural-spline clock on the
# last-24h grid (t_out / clk_herm) comes from the same kind of spline through
# the same SP3 clock samples, so it is plotted instead.
plt.plot(t_out/3600.0, clk_herm*1e6, label='Clock (us)')  # seconds -> microsec
plt.ylabel('Clock (us)')
plt.xlabel('Hours since t0')
plt.legend()
plt.tight_layout()
plt.savefig('test.png',dpi=300)

# ---------- Save outputs (CSV) ----------
# out_df = pd.DataFrame({
#     't_s_since_t0': t_out_24,
#     'x_herm_m': pos_herm[:,0], 'y_herm_m': pos_herm[:,1], 'z_herm_m': pos_herm[:,2],
#     'vx_herm_m_s': vel_herm[:,0], 'vy_herm_m_s': vel_herm[:,1], 'vz_herm_m_s': vel_herm[:,2],
#     'x_dyn_m': pos_dyn_24[:,0], 'y_dyn_m': pos_dyn_24[:,1], 'z_dyn_m': pos_dyn_24[:,2],
#     'vx_dyn_approx_m_s': vel_dyn_24[:,0], 'vy_dyn_approx_m_s': vel_dyn_24[:,1], 'vz_dyn_approx_m_s': vel_dyn_24[:,2],
#     'clk_s': clk_out_24
# })
# out_csv = f"{SAT_ID}_30s_compare.csv"
# out_df.to_csv(out_csv, index=False)
# print("Saved comparison CSV to", out_csv)

# herm = np.hstack([pos_herm,clk_out_24[:,None],vel_herm])
# dyn = np.hstack([pos_dyn_24,clk_out_24[:,None],vel_dyn_24])
# np.savez('G01_30s_compare.npz',t_out_24=t_out_24,herm=herm,dyn=dyn)

# ---------- Comparison against the reference orbit archive ----------
# SECURITY NOTE: allow_pickle=True makes np.load unpickle object arrays, which
# can execute arbitrary code. Only load archives from a trusted source.
with np.load('gnssOrb/GNSS_2025.308_NRT.npz', allow_pickle=True) as fid:
    Gsats = fid['Gsats']
    if SAT_ID not in Gsats:
        # Previously Gdt/Gdata were simply left undefined in this case and the
        # plotting code below failed with a NameError.
        raise ValueError(f"Satellite {SAT_ID} not found in reference orbit archive")
    gidx = np.where(Gsats == SAT_ID)[0][0]
    Gdt = fid['Gdt']
    Gdata = fid['Gdata'][:, gidx, :]

# Time origin of the plots: one day after the first SP3 epoch.
T0 = t0_ast.to_datetime() + timedelta(days=1)
fig, axs = plt.subplots(3, 2, figsize=(12, 9))

# Top row: reference series (columns 0-2 and column 3 of Gdata).
# NOTE(review): columns 0-2 are presumably positions in km and column 3 a
# clock in microseconds, judging by the scalings applied below - confirm.
ref_minutes = dt2float(Gdt,T0)/60
axs[0,0].plot(ref_minutes,Gdata[:,:3])
axs[0,1].plot(ref_minutes,Gdata[:,3])

# Differences against the reference. The arrays this section originally used
# (t_out_24, clk_out_24, pos_dyn_24) are only created by the disabled Method B
# code, so the Hermite last-24h grid (t_out) and its clock (clk_herm) are used.
# NOTE(review): the [10:] / [:-10] slices presumably align the two time grids;
# confirm the offset against the reference product.
cmp_minutes = t_out[10:]/60-1440
axs[1,0].plot(cmp_minutes,pos_herm[10:,:]/1E3-Gdata[:-10,:3],'.')
axs[1,0].set_ylim((-0.1,0.1))
# The dynamic-fit panel is drawn only when Method B has been re-enabled.
pos_dyn_cmp = globals().get('pos_dyn_24')
if pos_dyn_cmp is not None:
    axs[2,0].plot(cmp_minutes,pos_dyn_cmp[10:,:]/1E3-Gdata[:-10,:3],'.')
axs[1,1].plot(cmp_minutes,clk_herm[10:]*1E6-Gdata[:-10,3],'.')
for ax in (axs[0,0], axs[0,1], axs[1,0], axs[2,0], axs[1,1]):
    ax.set_xlim((0,1435))
plt.savefig(f'{SAT_ID}_Cmpr.png',dpi=300)

# ---------- Per-source orbit comparison ----------
# This section needs `dt_range`, `data` and `gsat`, none of which is defined
# in this file. Rather than ending the script with a NameError, skip the plot
# and say so when they are absent.
# NOTE(review): presumably carried over from another script - supply these
# inputs or remove the section.
_orb_missing = [name for name in ('dt_range', 'data', 'gsat') if name not in globals()]
if _orb_missing:
    print("Skipping per-source orbit plot; undefined inputs:", ", ".join(_orb_missing))
else:
    # Clear each axes in place. fig.clf() would detach the axes from the
    # figure, so anything drawn on `axs` afterwards would not be saved.
    for ax in axs.ravel():
        ax.cla()
    span_minutes = (dt_range[1] - dt_range[0]) / timedelta(minutes=1)
    # Last panel: reference series over the requested range.
    axs[-1,-1].plot(dt2float(Gdt,dt_range[0])/60,Gdata[:,:3])
    axs[-1,-1].set_xlim((0,span_minutes))
    # Remaining panels, filled column by column: one per key of `data`.
    panels = axs.T.ravel()
    for ii, gtoken in enumerate(sorted(data)):
        series = data[gtoken]
        # Column 0 of each series indexes into Gdata (times two); columns 1-3
        # are differenced against the reference columns 0-2.
        panels[ii].plot(series[:,0],series[:,1:4]-Gdata[series[:,0].astype('int')*2,:3],'.-')
        panels[ii].set_xlim((0, span_minutes))
        panels[ii].set_ylim((-9E-5, 9E-5))
        panels[ii].set_title(gtoken)
    fig.savefig(f'{gsat}_Orb.png',dpi=300)