import numpy as np
from datetime import datetime, timedelta
from matplotlib import pyplot as plt
from glob import glob
from re import sub
from tools import dt2float

def disp_dcb(start=datetime(2025, 12, 15), outlier_thresh=3.5, outfile='dcb_evolution.png'):
    """Plot the day-to-day evolution of per-satellite DCB solutions.

    For every antenna in ``ants`` and LEO in ``leos`` this does the following:

    - Reads the daily ``p04_teczdcb/teczdcb_{leo}.<YYYY-MM-DD>.A{ant:02d}_h4000.npz``
      files.
    - Builds a (day x satellite) matrix from each file's ``gnsses`` /
      ``dcb_sol`` arrays.
    - Rejects outlier days and fills the resulting gaps by linear
      interpolation in time.
    - Plots each satellite's series minus its median on a 2x2 grid
      (row = antenna, column = LEO).

    Parameters
    ----------
    start : datetime, optional
        Only files dated strictly after this date are used.
    outlier_thresh : float, optional
        A day is discarded for all satellites if any satellite's value
        deviates from that satellite's median by more than this amount.
    outfile : str, optional
        Path of the saved figure.
    """
    from matplotlib import dates as mdates
    gnsses = np.array(
        ['C06', 'C07', 'C08', 'C09', 'C10', 'C11', 'C12', 'C13', 'C14', 'C16', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24',
         'C25', 'C26', 'C27', 'C28', 'C29', 'C30', 'C32', 'C33', 'C34', 'C35', 'C36', 'C37', 'C38', 'C39', 'C40', 'C41',
         'C42', 'C43', 'C44', 'C45', 'C47', 'C48', 'C49', 'C50', 'E02', 'E03', 'E04', 'E05', 'E06', 'E07', 'E08', 'E09',
         'E10', 'E11', 'E12', 'E13', 'E14', 'E15', 'E16', 'E18', 'E19', 'E21', 'E23', 'E24', 'E25', 'E26', 'E27', 'E29',
         'E30', 'E31', 'E33', 'E34', 'E36', 'G01', 'G02', 'G03', 'G04', 'G05', 'G06', 'G07', 'G08', 'G09', 'G10', 'G11',
         'G12', 'G13', 'G14', 'G15', 'G16', 'G17', 'G18', 'G19', 'G20', 'G21', 'G22', 'G23', 'G24', 'G25', 'G26', 'G27',
         'G28', 'G29', 'G30', 'G31', 'G32', 'J02', 'J03', 'J04', 'R02', 'R03', 'R04', 'R05', 'R06', 'R07', 'R08', 'R09',
         'R11', 'R12', 'R14', 'R15', 'R16', 'R17', 'R18', 'R19', 'R20', 'R21', 'R22', 'R24', 'R25', 'R26', 'R27',
         'R28'])
    ants = [0, 1]
    leos = ['GN04', 'GN05']
    # gnsstypes = ['C','E']
    # Column of each satellite in the (day x satellite) matrix. Values are
    # placed by name, so a file may list its satellites in any order. Any
    # satellite missing from ``gnsses`` is ignored rather than shifting the
    # other columns.
    gnss_col = {gnss: col for col, gnss in enumerate(gnsses)}
    plt.close('all')
    fig, axs = plt.subplots(2, 2, figsize=(12, 9))
    alldcbs = [[[], []], [[], []]]
    allgnsses = [[[], []], [[], []]]
    alldts = [[[], []], [[], []]]
    # Smoothing kernel applied to the stored series; [1] means no smoothing.
    kernal = np.array([1])  # np.arange(10, 0, -1) / 55
    for aa, ant in enumerate(ants):
        for ll, leo in enumerate(leos):
            dcbfiles = sorted(glob(f'p04_teczdcb/teczdcb_{leo}.*.A{ant:02d}_h4000.npz'))
            dts = np.array([datetime.strptime(sub(r'.*\.(\d{4}-\d\d-\d\d)\..*', r'\1', x), '%Y-%m-%d')
                            for x in dcbfiles])
            keep = dts > start
            dcbfiles = np.array(dcbfiles)[keep]
            dts = dts[keep]
            alldcb = np.full((len(dts), len(gnsses)), np.nan)
            for dd, dcbfile in enumerate(dcbfiles):
                with np.load(dcbfile, allow_pickle=True) as fid:
                    # sec_tec = fid['sec_tec']
                    # tecz = fid['tecz']
                    dcb_gnss = fid['gnsses']
                    dcb_sol = fid['dcb_sol']
                for gnss, val in zip(dcb_gnss, dcb_sol):
                    col = gnss_col.get(gnss)
                    if col is not None:
                        alldcb[dd, col] = val
            # Keep only satellites that have at least one solution.
            seen = np.any(np.isfinite(alldcb), axis=0)
            allgnss = gnsses[seen]
            alldcb = alldcb[:, seen]
            # Reject whole days on which any satellite is an outlier.
            bad_day = np.any(np.abs(alldcb - np.nanmedian(alldcb, axis=0)) > outlier_thresh, axis=1)
            alldcb[bad_day] = np.nan
            for ii in range(alldcb.shape[1]):
                ok = np.isfinite(alldcb[:, ii])
                if not np.any(ok):
                    # Every solution of this satellite fell on a rejected day.
                    # np.interp cannot work from zero samples, so leave it NaN.
                    continue
                alldcb[:, ii] = np.interp(dt2float(dts, dts[0]), dt2float(dts[ok], dts[0]), alldcb[ok, ii])
            alldcbs[ll][aa] = np.array([np.convolve(kernal, alldcb[:, x], 'valid') for x in range(alldcb.shape[1])]).T
            # alldcbs[ll][aa] = np.nanmedian(alldcb, axis=0)
            allgnsses[ll][aa] = allgnss.copy()
            alldts[ll][aa] = dts[len(kernal) - 1:]
            # Disabled: overlay of UCAR leodcb attributes (needs netCDF4.Dataset
            # and a per-constellation axes layout).
            # ucargnsses = []
            # ucardts = []
            # ucardcb = []
            # for dt in dts:
            #     print(f'{datetime.now():%H:%M:%S} - {dt:%Y-%m-%d}')
            #     ucarfiles = glob(
            #         f'/data3/xshao/planetiq/SWx/UCAR/L1B/{dt:%Y%m%d}/podTec_{leo}.*.{ant:02d}_0001.0001_nc')
            #     ucarfiles.sort()
            #     for ucarfile in ucarfiles:
            #         ucargnss = sub(r'.*(...)\.0._0001.0001_nc',r'\1',ucarfile)
            #         if ucargnss.startswith('G'):
            #             continue
            #         if ucargnss not in ucargnsses:
            #             ucargnsses.append(ucargnss)
            #             ucardts.append([])
            #             ucardcb.append([])
            #         uidx = ucargnsses.index(ucargnss)
            #         ucardts[uidx].append(datetime.strptime(sub(r'.*\.(\d{4}\.\d{3}\.\d\d\.\d\d)\..*',r'\1',ucarfile),'%Y.%j.%H.%M'))
            #         with Dataset(ucarfile,'r') as fid:
            #             ucardcb[uidx].append(fid.getncattr('leodcb'))
            # for gg,gnsstype in enumerate(gnsstypes):
            #     ax = axs[gg,0]
            #     gidx = np.array([x[0] == gnsstype for x in allgnss])
            ax = axs[aa, ll]
            ax.plot(dts, alldcb - np.nanmedian(alldcb, axis=0), '.-')
            ax.grid(True)
            ax.set_ylim((-10, 10))
            ax.set_xlabel('Date')
            ax.set_ylabel('DCB difference')
            ax.set_title(f'{leo}.A{ant:02d}')
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
            # for gg,ucargnss in enumerate(ucargnsses):
            #     if ucargnss.startswith(gnsstypes[0]):
            #         ax = axs[0,1]
            #     else:
            #         ax = axs[1,1]
            #     ax.plot(ucardts[gg],ucardcb[gg]-np.nanmedian(ucardcb[gg]),'.-')
            # for gg in range(2):
            #     ax = axs[gg,1]
            #     ax.grid(True)
            #     ax.set_ylim((-10, 10))
            #     ax.set_xlabel('Date')
            #     ax.set_ylabel('DCB difference')
            #     ax.set_title(f'UCAR {leo}.A{ant:02d}.{gnsstypes[gg]}')
            #     date_format = mdates.DateFormatter('%m-%d')
            #     ax.xaxis.set_major_formatter(date_format)
    fig.savefig(outfile, dpi=300)
    return

def match_epochs(sec_a, sec_b):
    """Return index arrays ``(ia, ib)`` pairing the epochs common to two time-tag arrays.

    ``sec_a[ia]`` and ``sec_b[ib]`` are equal element-by-element and sorted
    ascending. Each common value appears once (first occurrence in each
    input), so the pair is safe for differencing even when an input has
    duplicate or unsorted time tags.
    """
    _, ia, ib = np.intersect1d(np.asarray(sec_a), np.asarray(sec_b), return_indices=True)
    return ia, ib

def cmpr_ucar_tec(tecfile):
    """Plot a UMD TEC arc against overlapping UCAR podTec arcs and save the figure.

    Antenna (``.A##.``), LEO (``podTec_XX##``), arc time range (two
    ``YYYYmmddHHMMSS`` stamps), day (``YYYY.DDD``) and GNSS PRN (last field
    before ``.nc``) are parsed from the name of *tecfile*.

    UCAR files for the same day/LEO/PRN/antenna are globbed under
    ``/data3/xshao/planetiq/SWx/UCAR/L1B/<YYYYmmdd>/``. A UCAR arc is kept if
    it starts more than a minute before the UMD arc ends and ends more than a
    minute after it starts. If no UCAR file is found or none overlaps, the
    function returns without plotting.

    Top panel: TEC time series (UMD in black, each UCAR arc in colour).
    Bottom panel: UCAR minus UMD TEC on the epochs both share.
    The figure is saved to ``Figures/<tecfile basename without '.nc'>.png``.
    """
    from netCDF4 import Dataset
    from matplotlib import pyplot as plt
    from matplotlib import dates as mdates
    # Default matplotlib colour cycle as RGB triples in [0, 1].
    colors = np.array([[int(color[ii:ii + 2], 16) / 255 for ii in range(1, 6, 2)] for color in
                       plt.rcParams['axes.prop_cycle'].by_key()['color']])
    dt0_gps = datetime(1980, 1, 6)
    ant = int(sub(r'.*\.A(\d\d)\..\d\d\.nc', r'\1', tecfile))
    # Directory prefix is optional so bare filenames also parse.
    leo = sub(r'^(?:.*/)?podTec_(..\d\d)\..*', r'\1', tecfile)
    dt_rng = [datetime.strptime(x, '%Y%m%d%H%M%S') for x in sub(r'.*(\d{14}\.\d{14}).*', r'\1', tecfile).split('.')]
    dt_rr = datetime.strptime(sub(r'.*\.(\d{4}\.\d{3})\..*', r'\1', tecfile), '%Y.%j')
    # The 'time' variables are treated as seconds since the GPS epoch;
    # subtracting this offset gives seconds of day dt_rr.
    sec_rr = (dt_rr - dt0_gps) / timedelta(seconds=1)
    gnss = sub(r'.*\.(...)\.nc', r'\1', tecfile)

    # Sorted so arcs (and their colours/legend entries) come out in filename order.
    ucarfiles = np.array(sorted(glob(
        f'/data3/xshao/planetiq/SWx/UCAR/L1B/{dt_rr:%Y%m%d}/podTec_{leo}.*.{gnss}.{ant:02d}*_nc')))
    if ucarfiles.size == 0:
        return
    dts_ucar = np.array([datetime.strptime(sub(r'.*_..\d\d\.(.*)\.\d{4}\.[A-Z].*', r'\1', ucarfile), '%Y.%j.%H.%M')
                         for ucarfile in ucarfiles])
    dur_ucar = np.array([timedelta(minutes=int(sub(r'.*\.(\d{4})\.[A-Z].*', r'\1', ucarfile)))
                         for ucarfile in ucarfiles])
    fidx = np.logical_and(dts_ucar < dt_rng[1] - timedelta(minutes=1),
                          dts_ucar + dur_ucar > dt_rng[0] + timedelta(minutes=1))
    if not np.any(fidx):
        return
    ucarfiles = ucarfiles[fidx]
    dts_ucar = dts_ucar[fidx]
    dur_ucar = dur_ucar[fidx]
    titlestr = [sub(r'.*/', r'', tecfile)]

    plt.close('all')
    fig, axs = plt.subplots(2, 1, figsize=(12, 9))
    with Dataset(tecfile, mode='r') as fid:
        umd_tec = fid['TEC'][...]
        umd_sec = (fid['time'][...] - sec_rr).astype('int')
    axs[0].plot(np.array([dt_rr + timedelta(seconds=int(x)) for x in umd_sec]), umd_tec, '.-', markersize=5,
                markeredgewidth=0, linewidth=0.5, label=f'UMD ({dt_rng[0]:%H:%M}-{dt_rng[1]:%H:%M})', color='k')
    for nn, ucarfile in enumerate(ucarfiles):
        with Dataset(ucarfile, mode='r') as fid:
            ucar_tec = fid['TEC'][...]
            ucar_sec = (fid['time'][...] - sec_rr).astype('int')
            # ucar_leodcb = fid.getncattr('leodcb')
            # ucar_gpsdcb = fid.getncattr('gpsdcb')
            # ucar_offset = fid.getncattr('leveling_offset')
        # Explicit epoch pairing: independent isin() masks only line up when
        # both time arrays are sorted and free of duplicates, and truncation
        # to whole seconds can create duplicates.
        ridx, uidx = match_epochs(umd_sec, ucar_sec)
        # Cycle the colours so more arcs than colours do not raise IndexError.
        color = colors[nn % len(colors)]
        label = f'UCAR ({dts_ucar[nn]:%H:%M}-{dts_ucar[nn] + dur_ucar[nn]:%H:%M})'
        titlestr.append(sub(r'.*/', r'', ucarfile))
        axs[0].plot(np.array([dt_rr + timedelta(seconds=int(x)) for x in ucar_sec]), ucar_tec, '.-', markersize=5,
                    markeredgewidth=0, linewidth=0.5, color=color, label=label)
        axs[1].plot(np.array([dt_rr + timedelta(seconds=int(x)) for x in ucar_sec[uidx]]),
                    ucar_tec[uidx] - umd_tec[ridx], '.-', markersize=5,
                    markeredgewidth=0, linewidth=0.5, color=color, label=label)
    for ax in axs.ravel():
        ax.grid(True)
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
    axs[0].legend(loc='best')
    axs[0].set_title('\n'.join(titlestr))
    axs[1].set_xlabel(f'Time of {dt_rr:%Y-%m-%d}')
    axs[1].set_xlim(axs[0].get_xlim())
    # Centre the difference panel on its data with a minimum half-range of 0.1.
    yl = axs[1].get_ylim()
    axs[1].set_ylim(np.mean(yl) + np.array([-1, 1]) * max(0.1, (yl[1] - yl[0]) / 2))
    # titlestr[0] is the basename of tecfile; [:-3] drops '.nc'.
    fig.savefig(f'Figures/{titlestr[0][:-3]}.png', dpi=300)
    plt.close('all')
    return

def bslut_labels(lutfile, band):
    """Return ``(title, png_path)`` for 1-based *band* of a boresight LUT file.

    ``dir/bslut_NAME.npz`` -> ``('NAME.L<band>', 'Figures/bslut_NAME.L<band>.png')``.
    The directory part of *lutfile* is optional.
    """
    stem = sub(r'\.npz$', '', sub(r'.*/', '', lutfile))
    name = stem[len('bslut_'):] if stem.startswith('bslut_') else stem
    return f'{name}.L{band:d}', f'Figures/{stem}.L{band:d}.png'

def disp_bore_sight_lut(lutfile):
    """Plot every band of a boresight LUT as a 3x3 grid of maps and save one PNG per band.

    *lutfile* is an ``.npz`` with ``el_bins``, ``az_bins``, ``sad_bins`` and
    ``mpe_grid``. Panel ``nn`` of figure ``ll`` shows ``mpe_grid[:, :, nn, ll]``
    on the ``az_bins`` x ``el_bins`` grid, colour range fixed to [-1.5, 1.5].
    Each panel is labelled with ``sad_bins[nn]``. At most 9 panels are drawn
    per band; one figure is written per entry of ``mpe_grid``'s last axis.
    """
    with np.load(lutfile, allow_pickle=True) as fid:
        el_bins = fid['el_bins']
        az_bins = fid['az_bins']
        sad_bins = fid['sad_bins']
        mpe_grid = fid['mpe_grid']

    for ll in range(mpe_grid.shape[3]):
        title, outfile = bslut_labels(lutfile, ll + 1)
        plt.close('all')
        fig, axs = plt.subplots(3, 3, figsize=(12, 9))
        flat_axs = axs.ravel()
        # Never index past the available SAD bins or the 9 panels.
        n_panel = min(flat_axs.size, mpe_grid.shape[2])
        im = None
        for nn in range(n_panel):
            ax = flat_axs[nn]
            im = ax.pcolor(az_bins, el_bins, mpe_grid[:, :, nn, ll], vmin=-1.5, vmax=1.5, cmap='jet')
            ax.grid(True)
            ax.set_xticks(np.linspace(-60, 60, 5))
            ax.set_xlim((-90, 90))
            ax.set_ylim((-45, 90))
            ax.text(88, 80, f'{sad_bins[nn]:.1f}', horizontalalignment='right', verticalalignment='center')
            if nn > 5:  # bottom row of the 3x3 grid
                ax.set_xlabel('Azimuth from boresight [deg]')
            if nn % 3 == 0:  # left column
                ax.set_ylabel('Elevation from boresight [deg]')
            if nn == 1:  # top-centre panel carries the figure title
                ax.set_title(title)
        plt.tight_layout()
        if im is not None:
            fig.colorbar(im, ax=flat_axs.tolist(), fraction=0.05, pad=0.01, aspect=50)
        fig.savefig(outfile, dpi=300)
    return