import os
import numpy as np
from datetime import datetime, timedelta, timezone
# TC: updated 2026/04/28 (1 line, import zoneinfo)
from zoneinfo import ZoneInfo
from glob import glob
from re import sub
from traceback import extract_tb
import tarfile
from sys import argv
from pathlib import Path

# Jun modified-------------
from spire.step0 import prepare_file
from spire.step1 import gather_data
from spire.step3 import pair_ro_ref
from spire.step4 import cal_exph
from configuration_spire import hostname

current_path = Path.cwd()
# Working directory = <cwd>/<package name>.  __package__ is None when this file
# is executed by path (``python file.py``) and '' when it is a top-level
# module; in both cases fall back to the current directory instead of failing
# on ``Path / None``.
working_dir = (current_path / (__package__ or '')).resolve()

#    data_dirs = ['/data3/xinjiaz/data_GPSRO/spire_scdr/LEMUR_L0_NAVOBS/',
#                 '/data3/xinjiaz/data_GPSRO/spire_scdr/LEMUR_L0_NAVSOL/',
#                 '/data3/xinjiaz/data_GPSRO/spire_scdr/LEMUR_L0_ATTOBS/',
#                 '/data3/xinjiaz/data_GPSRO/spire_scdr/LEMUR_L0_RO/']

# Host-specific setup: the suffix of `hostname` decides which names are
# imported here; the main section below branches on the same suffixes.
if hostname.endswith('.umd.edu'):
    # Input directories (searched with glob in the main section).
    from configuration_spire import data_dirs

elif hostname.endswith('ec2.internal'):
    # Jun modified2: added by Xinjia, for s3 bucket
    from configuration_spire import data_dirs
    import s3fs
    # Shared filesystem handle used by fs.glob()/fs.open() in the main section.
    # NOTE(review): anon=False presumably means credentialed access - confirm.
    fs = s3fs.S3FileSystem(anon=False)

elif hostname.endswith('.star1.nesdis.noaa.gov'):
    # Path of the SCDR query executable that is run through subprocess.
    from configuration_spire import scdr_executable
    # for SCDR command, by Xinjia
    import subprocess
    import shlex
# Jun modified-------------

if __name__ == '__main__':
    # Command line: [nProcess [working_leo]].  With one or two arguments the
    # first is parsed as an integer; the second, when present, is the LEO tag
    # used in file and directory names.  Any other argument count falls back
    # to nProcess = 0 and an empty tag.
    cli_args = argv[1:]
    if len(cli_args) in (1, 2):
        working_leo = cli_args[1] if len(cli_args) == 2 else ''
        nProcess = int(cli_args[0])
    else:
        working_leo = ''
        nProcess = 0
    
    # Product types; the last entry is handled separately from the others in
    # the processing loop (ztypes[:-1] vs ztypes[-1]).
    ztypes = ['navObs', 'navSol', 'attObs', 'rocObs']

    # All relative paths used below are resolved against working_dir.
    os.chdir(working_dir)
    working_dir = str(working_dir)

    # Scratch directory that receives the extracted archives.
    root_out = f'{working_dir}/L0_extract_{working_leo}'
    os.makedirs(root_out, exist_ok=True)
    # Jun modified ---------------------------
    for folder in ('s01_podRx3', 's01_leoAtt', 's01_leoOrb', 's04_atmPhs', 'log'):
        os.makedirs(folder, exist_ok=True)
    #Jun modified ----------------------------
    

    # One log file per LEO tag and calendar month (local time).
    logfile = f'log/progress_rt_{working_leo}_{datetime.now():%Y%m}.txt'

    def log_message(msg):
        """Append *msg* to ``logfile`` as one line prefixed with the local HH:MM:SS."""
        with open(logfile, 'a') as fid:
            print(f"{datetime.now():%H:%M:%S} {msg}", file=fid)

    log_message('- checking progress...')
    # Persistent queue state, stored as parallel arrays in one .npz file:
    #   tarfiles   - archive paths
    #   progress   - int8 state code per archive
    #   timetag    - datetime of the last state change per archive
    #   lastaccess - one-element array; time of the most recent update by a
    #                run, or datetime(2000,1,1) when no run holds the queue
    progfile = f'progress_rt_{working_leo}.npz'
    if os.path.exists(progfile):
        # allow_pickle is required because timetag/lastaccess are object
        # arrays of datetime; only load progress files written by this script.
        with np.load(progfile, allow_pickle=True) as fid:
            tarfiles = fid['tarfiles']
            progress = fid['progress']
            timetag = fid['timetag']
            # TC: updated 2026/04/29 (add 4 lines, add a new variable to indicate processing status)
            if 'lastaccess' in fid:
                lastaccess = fid['lastaccess']
            else:
                # Progress file written before lastaccess existed: treat the
                # queue as released.
                lastaccess = np.array([datetime(2000,1,1)])
    else:
        tarfiles = np.array([],dtype='<U128')
        progress = np.array([],dtype='int8')
        timetag = np.array([],dtype=datetime)
        # No progress file yet: start in the released state.  Initialising
        # this to datetime.now() made the "still progressing" guard below
        # trigger on every fresh start, so the first progress file was never
        # written and the queue could never be bootstrapped.
        lastaccess = np.array([datetime(2000,1,1)])

    def update_progress(pnum,state):
        """Set the state of the selected queue entries and stamp the change.

        pnum  -- index, index array or boolean mask applied to ``progress``
                 and ``timetag`` (callers in this file pass masks and index
                 arrays).
        state -- value written to ``progress[pnum]``.

        The arrays are modified in place; nothing is written to disk until
        save_progress() is called.
        """
        progress[pnum] = state
        timetag[pnum] = datetime.now()
        # TC: updated 2026/04/29 (add 1 line)
        # Every update also refreshes the run-level timestamp that the
        # 60-minute "still progressing" check compares against.
        lastaccess[0] = datetime.now()

    # TC: updated 2026/04/29 (remove 8 lines, avoid updating lastaccess before checking, move the whole block below 'else' statement)
    def save_progress():
        """Write the current queue arrays to the progress .npz file, replacing it."""
        snapshot = {
            'tarfiles': tarfiles,
            'progress': progress,
            'timetag': timetag,
            # TC: updated 2026/04/29 (add 1 line)
            'lastaccess': lastaccess,
        }
        np.savez(f'progress_rt_{working_leo}.npz', **snapshot)

    # TC: updated 2026/04/29 (1 line)
    # Run guard.  lastaccess[0] is refreshed by every update_progress() call
    # and reset to datetime(2000,1,1) when a run finishes, so a value younger
    # than 60 minutes means another run updated the queue within the last hour
    # and has not released it.
    #
    # State codes stored in `progress`, as used below:
    #   -1  queued (new files, and anything re-queued)
    #    0  currently being processed
    #    1  finished
    #    2  cal_exph reported failure (may be re-queued by the retry rule)
    #   -2  aborted by an exception (re-queued at the start of the next run)
    if datetime.now() - lastaccess[0] < timedelta(minutes=60):
        log_message("-- Still progressing.")
    else:
        def _extract_tar(tar_path, dest_dir):
            """Unpack one .tar.gz (opened through `fs` on EC2, locally otherwise) into *dest_dir*."""
            # filter="data" is applied to every extraction so that archive
            # members cannot escape dest_dir.
            if hostname.endswith('ec2.internal'):
                with fs.open(tar_path, 'rb') as f:
                    with tarfile.open(fileobj=f, mode='r:gz') as tar:
                        tar.extractall(path=dest_dir, filter="data")
            else:
                with tarfile.open(tar_path, 'r:gz') as tar:
                    tar.extractall(path=dest_dir, filter="data")

        # TC: updated 2026/04/29 (9 lines, claim processing status via calling update_progress as list matching is time-consuming)
        # Entries still marked 0 or -2 were left behind by an earlier run:
        # put them back in the queue.  update_progress() also stamps
        # lastaccess, which claims the queue for this run.
        idx = progress == 0
        update_progress(idx, -1)
        if np.sum(idx) > 0:
            log_message(f"-- {np.sum(idx)} timeout processes were terminated.")
        idx = progress == -2
        update_progress(idx, -1)
        if np.sum(idx) > 0:
            log_message(f"-- {np.sum(idx)} error processes were terminated.")
        save_progress()

        # False for the cron job; set True to process the hard-coded dates below.
        manual_date = False
        nloop = 0
        while True:
            new_proc = False
            nloop += 1
            if manual_date: # manually assign dates to process
                process_dates = [datetime(2026, 4, 15) + timedelta(days=x) for x in range(1)]
            else:
                dt_today = datetime.now(timezone.utc)  ## safer than datetime.utcnow
                process_dates = [dt_today + timedelta(days=x) for x in range(-1,1)] # 2 recent days, yesterday and today

            # ---- list the archives currently available on this host ----
            # The result is collected as a flat Python list and converted with
            # dtype=str, so an empty listing yields an empty *string* array
            # rather than a float64 one.
            if hostname.endswith('.umd.edu'):
                found = [path
                         for data_dir in data_dirs
                         for day in process_dates
                         for path in glob(data_dir + f'{day:%Y-%m-%d}/*.tar.gz')]

            elif hostname.endswith('ec2.internal'):
                # Jun modified2: added by Xinjia, for s3 bucket
                found = [path
                         for data_dir in data_dirs
                         for day in process_dates
                         for path in fs.glob(data_dir + f'{day:%Y/%m/%d}/*/*.tar.gz')]

            elif hostname.endswith('.star1.nesdis.noaa.gov'):
                ## --- for SCDR command, by Xinjia
                first_date = process_dates[0].strftime("%Y-%m-%d")
                last_date  = process_dates[-1].strftime("%Y-%m-%d")
                spire_types = [
                    "SPIRE_LEMUR_L0_ATTOBS",
                    "SPIRE_LEMUR_L0_NAVOBS",
                    "SPIRE_LEMUR_L0_NAVSOL",
                    "SPIRE_LEMUR_L0_ROCREF",
                    "SPIRE_LEMUR_L0_RO" ]

                found = []
                for dtype in spire_types:
                    scdr_command = f"{scdr_executable} -t {dtype} -stime {first_date} -etime {last_date}T23:59:59Z"
                    # Split the command string into a list of arguments for subprocess.run
                    command_args = shlex.split(scdr_command)
                    result = subprocess.run(
                                        command_args,
                                        capture_output=True,
                                        text=True,  # Decode stdout/stderr as text
                                        check=True,  # Raise CalledProcessError if the command returns a non-zero exit code
                                        shell=False # Set to False for security and clarity
                                        )
                    # One file path per output line.
                    found.extend(result.stdout.strip().splitlines())

            else:
                # Previously this fell through and failed later with a
                # NameError on root_files.
                raise RuntimeError(f"Unsupported host '{hostname}': no data source is configured for it.")

            root_files = np.array(found, dtype=str)

            # ---- reconcile the queue with the listing ----
            # Drop queue entries whose archive is no longer listed ...
            available = set(root_files.tolist())
            idx = np.array([name in available for name in tarfiles.tolist()], dtype=bool)
            tarfiles = tarfiles[idx]
            progress = progress[idx]
            timetag = timetag[idx]
            if not np.all(idx):
                log_message(f'- {np.sum(~idx):d} files are removed from the queue.')

            # ... and append listed archives that are not queued yet (state -1).
            queued = set(tarfiles.tolist())
            idx = np.array([name not in queued for name in root_files.tolist()], dtype=bool)
            tarfiles = np.concatenate((tarfiles, root_files[idx]))
            progress = np.concatenate((progress, np.full((np.sum(idx),),-1,dtype='int8')))
            timetag = np.concatenate((timetag, np.full((np.sum(idx),),datetime.now(),dtype=datetime)))
            if np.any(idx):
                log_message(f'- {np.sum(idx):d} files are added to the queue.')

            # Sort by the basename, by Xinjia.  A list comprehension is used
            # instead of np.vectorize, which raises ValueError on an empty
            # queue (size-0 input without otypes) before the try block below.
            base_tarfiles = np.array([os.path.basename(name) for name in tarfiles.tolist()], dtype=str)
            idx = np.argsort(base_tarfiles)

            tarfiles = tarfiles[idx]
            print(f"tarfiles: {tarfiles}")
            progress = progress[idx]
            timetag = timetag[idx]
            save_progress()
            print("after save_progress (1)")

            try:
                # ---- step 0: auxiliary products (every ztype except the last) ----
                for ztype in ztypes[:-1]:
                    print(f"ztype: {ztype}")
                    print(len(progress), len(tarfiles))
                    print("progress == -1 count:", np.sum(progress == -1))
                    print("example tarfiles:", tarfiles[:5])

                    idx = np.where(np.logical_and(progress==-1, [f'_{ztype}_' in tt for tt in tarfiles]))[0]
                    log_message(f"-- Processing {len(idx)} {ztype} files")

                    update_progress(idx, 0)
                    save_progress()
                    # Start from an empty scratch directory for this type.
                    os.makedirs(root_out, exist_ok=True)
                    os.system('rm -r ' + root_out + '/*')
                    os.makedirs(root_out+'/'+ztype, exist_ok=True)
                    for ii in idx:
                        _extract_tar(tarfiles[ii], root_out+'/'+ztype)
                    prepare_file(working_leo, ztype)
                    update_progress(idx, 1)
                    save_progress()

                # ---- retry rule for failed RO archives (state 2) ----
                # TC: updated 2026/04/28 (inserted 7 lines, reprocess failed cases within 3 hours)
                # TC: updated 2026/04/29 (edited 2 lines and inserted 1 additional line, add 25-minute cool-down for revisit)
                # NOTE(review): this filter matches only '_RO/' while the queue
                # filter below uses '/spire_gnss-ro_' on star1 hosts - confirm
                # whether failed archives are meant to be retried there too.
                idx1 = np.where(np.logical_and(progress == 2, ['_RO/' in tt for tt in tarfiles]))[0]
                if len(idx1) > 0:
                    # Time stamp embedded in the file name, interpreted as UTC.
                    rotime = np.array([datetime.strptime(sub(r'.*_gnss-ro_(.*)_FM.*', r'\1', tt),
                                       '%Y-%m-%dT%H-%M-%S').replace(tzinfo=ZoneInfo('UTC')) for tt in tarfiles[idx1]])
                    recent = datetime.now().astimezone(ZoneInfo('UTC')) - rotime < timedelta(hours=3)
                    cooled = datetime.now() - timetag[idx1] > timedelta(minutes=25)
                    # After the 25-minute cool-down: re-queue archives younger
                    # than 3 hours, mark older ones as finished.
                    update_progress(idx1[np.logical_and(cooled, recent)], -1)
                    update_progress(idx1[np.logical_and(cooled, ~recent)], 1)

                # ---- RO archives, processed one LEO (FMxxx) at a time ----
                if hostname.endswith('.star1.nesdis.noaa.gov'):
                    idx = np.where(np.logical_and(progress==-1, ['/spire_gnss-ro_' in tt for tt in tarfiles]))[0]
                else:
                    idx = np.where(np.logical_and(progress==-1, ['_RO/' in tt for tt in tarfiles]))[0]
                all_leos = np.array([sub(r'.*_(FM...)_.*', r'\1', x) for x in tarfiles[idx]])
                print(f"all_leos: {all_leos}")
                uniq_leos = np.unique(all_leos)
                print(f"uniq_leos: {uniq_leos}")
                for leo in uniq_leos:
                    idx_now = idx[all_leos==leo]
                    log_message(f"-- Processing {leo} ({len(idx_now):d}/{np.sum(progress==-1):d})")
                    update_progress(idx_now,0)
                    save_progress()
                    # Start from empty scratch/intermediate directories.
                    os.makedirs(root_out, exist_ok=True)
                    os.makedirs(f't01_RO_{working_leo}', exist_ok=True)
                    os.makedirs(f't01_REF_{working_leo}', exist_ok=True)
                    os.makedirs(f't03_pair_{working_leo}', exist_ok=True)
                    os.system('rm -r ' + root_out + '/*')
                    os.system(f'rm t01_RO_{working_leo}/*.npz')
                    os.system(f'rm t01_REF_{working_leo}/*.npz')
                    os.system(f'rm t03_pair_{working_leo}/*.npz')
                    os.makedirs(root_out + '/' + ztypes[-1], exist_ok=True)
                    print("root_out/ztypes", root_out + '/' + ztypes[-1])
                    for ii in idx_now:
                        _extract_tar(tarfiles[ii], root_out+'/'+ztypes[-1])

                    gather_data(working_leo)
                    print("after gather_data")
                    pair_ro_ref(working_leo)
                    print("after pair_ro_ref")
                    # TC: updated 2026/04/28 (3 lines, take results and switch status accordingly)
                    # NOTE(review): `results` is used as a boolean mask over
                    # idx_now (True -> state 1, False -> state 2); presumably a
                    # NumPy bool array with one entry per archive - confirm
                    # against cal_exph.
                    results = cal_exph(working_leo, nProcess, '')
                    print("after cal_exph")
                    update_progress(idx_now[results], 1)
                    update_progress(idx_now[~results], 2)
                    print("after update_progress")
                    save_progress()
                    print("after save_progress")
                if len(idx)>0:
                    save_progress()
                    # Something was processed: loop again to pick up archives
                    # that arrived in the meantime.
                    new_proc = True
                    log_message('- Progress completed.')
            except Exception as e:
                log_message('- Unexpected error occurs.')
                log_message(f"-- {type(e).__name__} >> {str(e)}")
                for tb in extract_tb(e.__traceback__):
                    log_message(f"--  line {tb.lineno} of {tb.filename}")
                # Whatever was in flight is marked as aborted; the remaining
                # queue is tried again on the next pass of the loop.
                update_progress(progress == 0, -2)
                save_progress()
                new_proc = True
            if not new_proc:
                if nloop > 1:
                    log_message('- No new item, exit program.')
                break
        # TC: updated 2026/04/29 (add 2 lines, release processing status)
        lastaccess[0] = datetime(2000,1,1)
        save_progress()
