from graph import Graph
import matplotlib.pyplot as plt
import os

class OverTime(Graph):
    def __init__(self):
        """Set up empty time-bucketed download counters.

        The identifier ``'overTime'`` is handed to the ``Graph`` base class.
        """
        super().__init__('overTime')
        # Parallel lists: data['count'][i] and data['size'][i] are
        # {country name: total} dicts for the timestamp data['date'][i].
        # 'count' and 'size' start with one bucket already present because
        # process() stamps the very first timestamp into that existing bucket
        # and only appends new buckets when the timestamp changes.
        self.data = dict(
            date=[],
            count=[{}],
            size=[{}],
        )
    
    def process(self, ip, time, request, size, location, log):
        """Add one log entry to the per-country totals of its time bucket.

        Entries are assumed to arrive in ascending time order. When ``time``
        differs from the most recently recorded timestamp, a fresh bucket is
        opened for both the count and size series; otherwise the entry is
        folded into the latest bucket.

        Only ``time``, ``size`` and ``location`` are used; ``ip``, ``request``
        and ``log`` are ignored here. Always returns an empty list.
        """
        dates = self.data['date']
        if not dates:
            # First entry ever: the initial buckets already exist, just stamp them.
            dates.append(time)
        elif dates[-1] != time:
            dates.append(time)
            self.data['count'].append({})
            self.data['size'].append({})

        # Entries without a location are grouped under 'Unknown'.
        country = 'Unknown' if not location else location.country.name

        self.inc(self.data['count'][-1], country)
        self.inc(self.data['size'][-1], country, size)
        return []
    
    def draw(self, path):
        """Render the download time-series charts into the directory ``path``.

        Produces ``download_over_time_count.png`` (per-country download counts)
        and ``download_over_time_size.png`` (per-country accumulated sizes).

        The enlarged font is applied through ``plt.rc_context`` so matplotlib's
        global rcParams are restored even if plotting raises; the previous
        manual save/restore left ``font.size`` at 15 on any exception.
        """
        print('Generating time series...')
        with plt.rc_context({'font.size': 15}):
            self.plot(path, self.data['count'], 'Downloads by Count Over Time', 'Number of Downloads', 'download_over_time_count.png')
            # NOTE(review): the label says GiB, but the size totals are plotted
            # as accumulated, with no unit conversion here — confirm the unit
            # of the ``size`` values passed to process().
            self.plot(path, self.data['size'], 'Downloads by File Size Over Time', 'Download Size (GiB)', 'download_over_time_size.png')

    def plot(self, path, data, title, xaxis, file, numPlot=5):
        """Save a stacked, log-scaled per-country time series chart.

        Args:
            path: directory the image is written into.
            data: one ``{country: value}`` dict per entry of ``self.data['date']``.
            title: chart title.
            xaxis: label for the value axis (it is applied to the y axis; the
                parameter name is historical).
            file: output file name, joined onto ``path``.
            numPlot: number of countries, ranked by their total over all time
                steps, that get their own band; all remaining countries are
                summed into a single 'Other' band.
        """
        regions, totals = self._seriesByRegion(data)
        regionPlot = self._topRegions(regions, totals, numPlot)

        fig, ax = plt.subplots(figsize=(12.8, 6))
        try:
            # stackplot stacks its series with numpy, which treats a dict view
            # as one opaque object rather than a sequence of rows, so hand it
            # real lists.
            ax.stackplot(self.data['date'], list(regionPlot.values()), labels=list(regionPlot.keys()), alpha=0.75, colors=self.getColors(regionPlot.keys()))

            ax.set_title(title)
            ax.set_ylabel(xaxis)
            ax.set_xlabel('Date of Month')
            ax.set_yscale('log')
            ax.legend(loc='upper left', fontsize=12)

            fig.tight_layout()
            fig.savefig(os.path.join(path, file), bbox_inches='tight', dpi=100)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)

    @staticmethod
    def _seriesByRegion(data):
        """Convert per-step ``{country: value}`` dicts into aligned per-country series.

        Returns ``(regions, totals)``: ``regions[country]`` holds one value per
        entry of ``data`` (0 for steps where the country is absent), and
        ``totals[country]`` is the sum of that series.
        """
        regions = {}
        for i, downloads in enumerate(data):
            for country, value in downloads.items():
                # A country first seen at step i is back-filled with i zeros.
                regions.setdefault(country, [0] * i).append(value)
            # Pad countries missing from this step: all series share one x axis.
            for series in regions.values():
                if len(series) != i + 1:
                    series.append(0)
        totals = {country: sum(series) for country, series in regions.items()}
        return regions, totals

    @staticmethod
    def _topRegions(regions, totals, numPlot):
        """Keep the ``numPlot`` largest series and sum the rest into 'Other'."""
        # sorted() is stable, so ties keep first-seen order.
        ranked = sorted(regions, key=totals.get, reverse=True)
        regionPlot = {name: regions[name] for name in ranked[:numPlot]}
        rest = ranked[numPlot:]
        if rest:
            regionPlot['Other'] = [sum(column) for column in zip(*(regions[name] for name in rest))]
        return regionPlot