from datetime import datetime, timedelta
import numpy as np
from re import sub
from tools import dt2float

def disp_progress(working_leo):
    """Print the processing-status table stored in ``progress_rt_<working_leo>.npz``.

    The file is looked up relative to the current working directory and must
    contain the arrays ``tarfiles``, ``progress`` and ``timetag`` (one entry
    per row).  ``timetag`` entries must accept a ``strftime``-style format
    spec (e.g. ``datetime`` objects).

    Args:
        working_leo: Identifier inserted into the progress file name.

    Returns:
        None.  The table (or a "not found" message) is written to stdout.
    """
    from os.path import exists

    progfile = f'progress_rt_{working_leo}.npz'
    status = {-2: 'Aborted', -1: 'Queued', 0: 'Processing', 1: 'Completed'}
    if not exists(progfile):
        print(f'{progfile} not found!')
        return

    print(' #      Last Modified       TAR File Token      Status')
    print('===  ===================  ==================  ==========')
    # NOTE: allow_pickle=True unpickles object arrays; only open progress
    # files that were produced by a trusted process.
    with np.load(progfile, allow_pickle=True) as fid:
        tarfiles = fid['tarfiles']
        progress = fid['progress']
        timetags = fid['timetag']
    for nn, (tarfile, code, timetag) in enumerate(zip(tarfiles, progress, timetags)):
        # Keep only the "<8 digits>_<4 digits>_<rest>" part of the TAR name.
        tarseg = sub(r'.*_(\d{8}_\d{4}_.*)\.tar', r'\1', tarfile)
        # An unrecognised code must not abort the listing halfway through.
        label = status.get(code, f'Unknown ({code})')
        # Pad the token to the 18-character column announced by the header.
        print(f"{nn:3d}  {timetag:%Y-%m-%d %H:%M:%S}  {tarseg:<18}  {label}")

def _split_runs(values, max_gap):
    """Split a 1-D sequence into runs of closely spaced samples.

    A new run starts wherever the step between consecutive samples exceeds
    ``max_gap``.

    Returns:
        A list of ``(first, last)`` inclusive index pairs, one per run.  The
        list is empty when ``values`` is empty.
    """
    if len(values) == 0:
        return []
    breaks = np.flatnonzero(np.diff(values) > max_gap) + 1
    edges = np.concatenate(([0], breaks, [len(values)]))
    return [(int(lo), int(hi) - 1) for lo, hi in zip(edges[:-1], edges[1:])]


def disp_pod(podfile):
    """Plot a data-availability chart for the pickle ``podfile`` and save it as PNG.

    The pickle must hold a dict with the keys ``dsat`` (row labels),
    ``dtime`` (one time array per row) and ``dataall`` (one 2-D array per
    row, with at least four columns).  The file name must contain a
    ``.YYYY.DDD.`` token (year and day of year), which defines the time
    origin of the plot.

    Lower panel: one row per entry of ``dsat``.  Samples whose data row is
    entirely finite are drawn in the row colour, all others in a lightened
    colour; small dots just below/above the row mark samples where only
    column 0 / only column 3 is finite.  Grey bands mark the seconds that no
    row covers.  Upper panel: the number of rows covering each second, where
    a row covers the interior of each of its runs of finite samples (steps
    of at most 3 s), trimmed by 30 s at the start and 149 s at the end.

    Args:
        podfile: Path of the pickle file.  The figure is written next to it
            with the trailing ``.pkl`` replaced by ``.png``.

    Returns:
        None.

    Note:
        All open matplotlib figures are closed before and after plotting.
    """
    import pickle
    from matplotlib import pyplot as plt
    from matplotlib.colors import to_rgb

    dt0 = datetime.strptime(sub(r'.*\.(\d{4}\.\d{3})\..*', r'\1', podfile), '%Y.%j')
    # Anchor the pattern so that only the file extension is replaced.
    figfile = sub(r'\.pkl$', r'.png', podfile)
    # to_rgb accepts every matplotlib colour spec, not just '#RRGGBB'.
    colors = np.array([to_rgb(color) for color in plt.rcParams['axes.prop_cycle'].by_key()['color']])
    ncolors = len(colors)

    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(podfile, 'rb') as fid:
        pkldata = pickle.load(fid)
    dsat = pkldata['dsat']
    dtime = pkldata['dtime']
    dataall = pkldata['dataall']

    plt.close('all')
    fig, (ax1, ax) = plt.subplots(2, 1, figsize=(12, 9), gridspec_kw={'height_ratios': [1, 8]})
    plt.subplots_adjust(hspace=0.01)

    # Overall time span (seconds relative to dt0) over all rows.
    time_range = [np.inf, -np.inf]
    for time in dtime:
        if len(time) == 0:
            continue
        time_range = [min(time_range[0], dt2float(time[:1], dt0)[0]),
                      max(time_range[1], dt2float(time[-1:], dt0)[0])]
    if np.all(np.isfinite(time_range)):
        time_match = np.arange(time_range[0] + 30, time_range[1] - 149).astype('int32')
    else:
        # No usable time stamps at all: nothing to count or shade.
        time_match = np.zeros(0, 'int32')
    idx_match = np.zeros_like(time_match, 'int')

    for dnum, (sat, time, data) in enumerate(zip(dsat, dtime, dataall)):
        color = colors[dnum % ncolors]
        full = np.all(np.isfinite(data), axis=1)
        only_col0 = np.logical_and(np.isfinite(data[:, 0]), ~np.isfinite(data[:, 3]))
        only_col3 = np.logical_and(~np.isfinite(data[:, 0]), np.isfinite(data[:, 3]))
        sec_full = dt2float(time[full], dt0)

        # Incomplete samples first, in a lightened shade of the row colour.
        ax.plot(dt2float(time[~full], dt0) / 3600, np.full(time[~full].shape, dnum), '.',
                color=(color + 2) / 3, zorder=2)
        ax.plot(dt2float(time[only_col0], dt0) / 3600, np.full(time[only_col0].shape, dnum - 0.1), '.',
                ms=0.5, color=color, zorder=2)
        ax.plot(dt2float(time[only_col3], dt0) / 3600, np.full(time[only_col3].shape, dnum + 0.1), '.',
                ms=0.5, color=color, zorder=2)
        ax.plot(sec_full / 3600, np.full(time[full].shape, dnum), '.', color=color, zorder=2)
        # Stagger the labels of odd/even rows so neighbours do not overlap.
        ax.text(-2.2 - (0 if dnum % 2 else 1), dnum, sat, horizontalalignment='right',
                verticalalignment='center', color=color, fontsize=12)

        time_nice = sec_full.astype('int32')
        for first, last in _split_runs(time_nice, 3):
            covered = np.logical_and(time_match >= time_nice[first] + 30,
                                     time_match <= time_nice[last] - 149)
            idx_match[covered] += 1

    # Shade every contiguous stretch of seconds that no row covers.
    bad_time = time_match[idx_match == 0]
    bad_hours = bad_time.astype('float') / 3600
    for first, last in _split_runs(bad_time, 1):
        ax.add_patch(
            plt.Rectangle((bad_hours[first], -1), bad_hours[last] - bad_hours[first], len(dsat) + 1,
                          facecolor=[0.9, 0.9, 0.9], zorder=1))

    ax.grid(True)
    ax.set_xlabel(f'Hours of {dt0:%Y-%m-%d}')
    ax.set_xticks(range(-2, 25))
    ax.set_xlim((-2, 24))
    ax.set_ylim((-1, len(dsat)))
    ax.set_yticks(range(len(dsat)))
    ax.set_yticklabels([])
    ax.tick_params(direction="in")

    ax1.plot(time_match.astype('float') / 3600, idx_match, 'k', linewidth=0.5)
    ax1.grid(True)
    ax1.set_xticks(range(-2, 25))
    ax1.set_xlim((-2, 24))
    ax1.set_xticklabels([])
    ax1.set_yticks(range(0, 30, 5))
    ax1.set_ylim((0, 25))
    ax1.set_title(sub(r'.*/(.*)\.pkl', r'\1', podfile))
    ax1.tick_params(direction="in")

    fig.savefig(figfile, dpi=300)
    plt.close('all')

def _gap_spans(seconds, time_range, max_gap=15):
    """Return an ``(n, 2)`` array of ``[start, end]`` hours for each data gap.

    ``seconds`` is padded with the two ``time_range`` bounds so that data
    missing at either end of the range is reported too.  A gap is any step
    between consecutive padded samples larger than ``max_gap`` seconds.
    """
    padded = np.hstack([time_range[:1], seconds, time_range[1:]])
    gap_at = np.where(np.diff(padded) > max_gap)[0]
    return np.array([padded[gap_at], padded[gap_at + 1]]).T / 3600


def disp_orb(orbfile):
    """Plot the contents of the ``.npz`` file ``orbfile`` and save the figure as PNG.

    ``orbfile`` must contain the arrays ``data`` and ``dts`` and its name
    must contain a ``.YYYY.DDD.`` token (year and day of year), which
    defines the time origin.  Columns 0-2, 3 and 4-6 of ``data`` are drawn
    in the panels labelled position, clock bias and velocity respectively.

    If a companion file exists whose name is obtained by replacing
    ``leoOrb`` with ``leoAtt``, its columns whose ``vars`` name contains
    ``'att'`` are drawn in a fourth panel; otherwise that panel is hidden.

    Gaps longer than 15 s are shaded on every panel: blue for ``orbfile``,
    red for the companion file.

    Args:
        orbfile: Path of the ``.npz`` file.  The figure is written next to
            it with the trailing ``.npz`` replaced by ``.png``.

    Returns:
        None.

    Note:
        All open matplotlib figures are closed before and after plotting.
    """
    from matplotlib import pyplot as plt
    from os.path import isfile

    dt0 = datetime.strptime(sub(r'.*\.(\d{4}\.\d{3})\..*', r'\1', orbfile), '%Y.%j')
    # Anchor the pattern so that only the file extension is replaced.
    figfile = sub(r'\.npz$', r'.png', orbfile)

    # NOTE: allow_pickle=True unpickles object arrays; only open trusted files.
    with np.load(orbfile, allow_pickle=True) as fid:
        data_orb = fid['data']
        dts_orb = fid['dts']
    sec_orb = dt2float(dts_orb, dt0)
    hours_orb = sec_orb / 3600

    attfile = sub('leoOrb', 'leoAtt', orbfile)
    # If the name holds no 'leoOrb', attfile == orbfile; do not read the
    # orbit file a second time as if it were attitude data.
    isatt = attfile != orbfile and isfile(attfile)
    if isatt:
        with np.load(attfile, allow_pickle=True) as fid:
            data_att = fid['data']
            dts_att = fid['dts']
            vars_att = fid['vars']
        data_att = data_att[:, ['att' in x for x in vars_att]]
        sec_att = dt2float(dts_att, dt0)
        # Any component changing faster than 0.01 per second is treated as a
        # sign flip; rows after an odd number of flips are negated so the
        # plotted curves stay continuous.
        rate = np.abs(np.diff(data_att, axis=0) / np.diff(sec_att)[:, None])
        flipped = np.cumsum(np.hstack(([False], np.any(rate > 0.01, axis=1)))) % 2 == 1
        data_att[flipped] *= -1

    plt.close('all')
    fig, axs = plt.subplots(4, 1, figsize=(12, 9))
    plt.subplots_adjust(hspace=0.01)

    time_range = dt2float(dts_orb[[0, -1]], dt0)
    if isatt:
        time_range = [min(time_range[0], dt2float(dts_att[:1], dt0)[0]),
                      max(time_range[1], dt2float(dts_att[-1:], dt0)[0])]
    t1 = _gap_spans(sec_orb, time_range)
    # Without attitude data there are no attitude gaps to shade.
    t2 = _gap_spans(sec_att, time_range) if isatt else np.empty((0, 2))

    axs[0].plot(hours_orb, data_orb[:, :3], '.', ms=1, markeredgewidth=0, zorder=2)
    axs[0].set_ylabel('Position (km)')
    axs[0].set_title(sub(r'.*/leo..._(.*)\.npz', r'\1 ', orbfile) + f'({dt0:%Y-%m-%d})')

    if max(np.abs(data_orb[:, 3])) > 0.002:
        # Large excursions: keep values within +/-0.001 linear and compress
        # everything beyond logarithmically (one 0.001 step per decade); the
        # tick labels give the uncompressed values.
        yt_list = np.array([-0.005, -0.004, -0.003, -0.002, -0.001, 0, 0.001, 0.002, 0.003, 0.004, 0.005])
        ytl_list = np.array(['-10', '-1', '-0.1', '-1E-2', '-1E-3', '0', '1E-3', '1E-2', '0.1', '1', '10'])
        idx = data_orb[:, 3] > 0.001
        if np.any(idx):
            data_orb[idx, 3] = 0.001 * (np.log10(data_orb[idx, 3]) + 4)
        idx = data_orb[:, 3] < -0.001
        if np.any(idx):
            data_orb[idx, 3] = -0.001 * (np.log10(-data_orb[idx, 3]) + 4)
        axs[1].plot(hours_orb, data_orb[:, 3], '.', ms=1, markeredgewidth=0, zorder=2)
        yl = axs[1].get_ylim()
        idx = np.logical_and(yt_list >= yl[0], yt_list < yl[1])
        axs[1].set_yticks(yt_list[idx])
        axs[1].set_yticklabels(ytl_list[idx])
    else:
        axs[1].plot(hours_orb, data_orb[:, 3], '.', ms=1, markeredgewidth=0, zorder=2)
    axs[1].set_ylabel('Clock Bias (us)')

    axs[2].plot(hours_orb, data_orb[:, 4:7], '.', ms=1, markeredgewidth=0, zorder=2)
    axs[2].set_ylabel('Velocity (dm/s)')

    if isatt:
        axs[3].plot(sec_att / 3600, data_att, '.', ms=1, markeredgewidth=0, zorder=2)
        axs[3].set_ylabel('Attitude Quaternion')
    else:
        # Hide the unused panel so it does not sit under the x-axis label.
        axs[3].set_visible(False)
        axs = axs[:3]

    for ax in axs:
        yl = ax.get_ylim()
        for start, end in t1:
            ax.add_patch(
                plt.Rectangle((start, yl[0]), end - start, yl[1] - yl[0],
                              facecolor=[0.8, 0.8, 1.0], zorder=1, alpha=0.5))
        for start, end in t2:
            ax.add_patch(
                plt.Rectangle((start, yl[0]), end - start, yl[1] - yl[0],
                              facecolor=[1.0, 0.8, 0.8], zorder=1, alpha=0.5))
        ax.grid(True)
        ax.set_xlim([-3, 24])
        ax.set_xticks(range(-3, 25))
        # Only the bottom panel carries the time axis labelling.
        if ax is axs[-1]:
            ax.set_xlabel(f'Hours of {dt0:%Y-%m-%d}')
        else:
            ax.set_xticklabels([])

    fig.savefig(figfile, dpi=300)
    plt.close('all')