from graph import Graph
import matplotlib.pyplot as plt
import re
import os

class Popularity(Graph):
    """Bar chart of the number of downloads per product/mission.

    ``process`` classifies each request path against a list of URL patterns
    and tallies it in ``self.data``; ``draw`` flattens those tallies, folds
    old mission names into their current names via ``self.names``, and saves
    the result as ``popularity.png``.
    """

    def __init__(self):
        super().__init__('popularity')

        # Tallies. Plain int entries count a single product; dict entries
        # count per mission under that top-level folder.
        self.data = {
            'mmcGrd': 0,
            'task_icvs_gpsro': 0,
            'star_gnssro': {},  # mission: num
            'starro': {},  # mission: num
        }

        self.names = {  # new name: [old names...]
            'cosmic1_v2013': ['cosmic', 'cosmic1'],
            'cosmic1_v2021': ['cosmic2021'],
            'kompsat5': ['kompsat5', 'kompsat5rt'],
            'paz': ['paz', 'pazrt'],
            'geoopt': ['geoopt', 'geooptrt'],
            'spire': ['spire', 'spirert'],
            'planetiq': ['planetiq', 'planetiqrt'],
            'DIAG and QC monitoring': ['task_icvs_gpsro']
        }

        # http
        self.STAR_GNSSRO_REGEX = re.compile(r'/star_gnssro/data/([^/]+)/.+')
        self.MMCGRD_REGEX = re.compile(r'/star_gnssro/data/(mmcGrd)/.+')
        self.TASK_ICVS_GNSSRO_REGEX = re.compile(r'/(task_icvs_gpsro)/.*')

        # ftp
        self.STARRO_REGEX = re.compile(r'/starro/data/([^/]+)/.+')

        # (regex, matching key in self.data) -- the key is None when the
        # captured group itself names an int counter in self.data.
        # Order matters: process() stops at the first match, and every
        # mmcGrd path also matches the general star_gnssro pattern, so the
        # more specific mmcGrd pattern has to be tried first or it would
        # never be reached.
        self.patterns = [
            (self.MMCGRD_REGEX, None),
            (self.STAR_GNSSRO_REGEX, 'star_gnssro'),
            (self.TASK_ICVS_GNSSRO_REGEX, None),
            (self.STARRO_REGEX, 'starro')
        ]

    def process(self, ip, time, request, size, location, log):
        """Tally one request against the first URL pattern it matches.

        Only ``request`` is examined here; the remaining parameters are part
        of the call signature but are not used by this graph.

        Returns:
            An empty list when the request was counted, otherwise a
            one-element list holding a message that the request could not
            be read.
        """
        for pattern, folder in self.patterns:
            # match() anchors at the start of the request path.
            found = pattern.match(request)
            if not found:
                continue

            [mission] = found.groups()
            if folder:
                # Per-mission counter nested under the folder's dict.
                self.inc(self.data[folder], mission)
            else:
                # The captured name is itself an int counter in self.data.
                self.data[mission] += 1
            return []

        return [f'Could not read request {request}']

    def draw(self, path):
        """Render the tallies as a bar chart saved to ``<path>/popularity.png``.

        Args:
            path: Directory in which ``popularity.png`` is written.
        """
        print('Generating popularity graph...')

        # Flatten self.data into { mission: count }.
        formatted = {}
        for key, value in self.data.items():
            if isinstance(value, int):
                formatted[key] = value
            else:
                for mission, count in value.items():
                    self.inc(formatted, mission, count)

        # Invert self.names into { old name: new name } and merge the counts
        # of missions that share a new name; unknown names are kept as-is.
        renamed = {
            old: new
            for new, olds in self.names.items()
            for old in olds
        }
        combined = {}
        for mission, count in formatted.items():
            self.inc(combined, renamed.get(mission, mission), count)

        # Sort in descending order of count (ties broken by label, also
        # descending). Sorting the items directly also copes with an empty
        # mapping, where unpacking a zip(*...) would raise.
        ranked = sorted(
            combined.items(),
            key=lambda item: (item[1], item[0]),
            reverse=True,
        )
        labels = [name for name, _ in ranked]
        sizes = [count for _, count in ranked]

        # rc_context restores the global rcParams on exit even if plotting
        # or saving raises, so the larger font never leaks to other graphs.
        with plt.rc_context({'font.size': 12}):
            fig, ax = plt.subplots(figsize=(19.2, 8))
            try:
                bars = ax.bar(labels, sizes)

                for rect in bars:  # annotate each bar with its value
                    h = rect.get_height()
                    ax.annotate(
                        f'{h}',
                        xy=(rect.get_x() + rect.get_width() / 2, h),
                        xytext=(0, 2),
                        textcoords='offset points',
                        ha='center',
                        va='bottom',
                        size=14,
                    )

                ax.set_xticks(range(len(labels)))
                ax.set_xticklabels(labels, rotation=-45, ha='left')
                ax.set_title('Number of Downloads Per Product')
                ax.set_ylabel('Number of Downloads')

                fig.tight_layout()
                fig.savefig(
                    os.path.join(path, 'popularity.png'),
                    bbox_inches='tight',
                    dpi=100,
                )
            finally:
                # Release the figure so it does not stay registered with
                # pyplot after the image has been written.
                plt.close(fig)
