"""Generator danych
To jest wielowątkowy, skalowalny i wydajny sposób na wczytanie ogromnych obrazów 
z systemu plików jako zbioru danych 
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from tensorflow.keras.utils import Sequence

import numpy as np
import os
import skimage
from skimage.io import imread
from model_utils import parser

class DataGenerator(Sequence):
    """Wielowątkowy generator danych.
        Każdy wątek odczytuje partię plików i etykiety ich obiektów.
        Etykieta jest maską semantyczną na poziomie piksela.

    Argumenty:
        args : Konfiguracja użytkownika
        shuffle (Bool): Jeśli zbiór danych powinien 
        być wymieszany przed próbkowaniem.
    """
    def __init__(self,
                 args,
                 shuffle=True):
        self.args = args
        self.input_shape = (args.height, 
                            args.width,
                            args.channels)
        self.shuffle = shuffle
        self.get_dictionary()
        self.on_epoch_end()


    def get_dictionary(self):
        """Załaduj słownik z danymi odniesienia w formacie 
            nazwa pliku : maska segmentacji
        """
        path = os.path.join(self.args.data_path,
                            self.args.train_labels)
        self.dictionary = np.load(path,
                                  allow_pickle=True).flat[0]
        self.keys = np.array(list(self.dictionary.keys()))
        labels = self.dictionary[self.keys[0]]
        self.n_classes = labels.shape[-1]


    def __len__(self):
        """Liczba partii na epokę"""
        blen = np.floor(len(self.dictionary) / self.args.batch_size)
        return int(blen)


    def __getitem__(self, index):
        """Pobierz partię danych"""
        start_index = index * self.args.batch_size
        end_index = (index+1) * self.args.batch_size
        keys = self.keys[start_index : end_index]
        x, y = self.__data_generation(keys)
        return x, y


    def on_epoch_end(self):
        """Przemieszaj po każdej epoce"""
        if self.shuffle == True:
            np.random.shuffle(self.keys)


    def __data_generation(self, keys):
        """ Generowanie danych uczących: obrazów 
        oraz prawidłowych etykiet segmentacji 

        Argumenty:
            keys (array): losowo próbkowane klucze
            (kluczem jest nazwa obrazu)

        Zwraca:
            x (tensor): partia obrazów
            y (tensor): partia kategorii pikseli
        """
    

        # partia obrazów
        x = []
        # i odpowiadająca im maska segmentacji
        y = []

        for i, key in enumerate(keys):
            # założenie: obrazy są przechowywane w ścieżce self.args.data_path
            # kluczem jest nazwa pliku 
            image_path = os.path.join(self.args.data_path, key)
            image = skimage.img_as_float(imread(image_path))
            # dodanie obrazu do listy
            x.append(image)
            # i odpowiadającej mu etykiety (maski segmentacji) 
            labels = self.dictionary[key]
            y.append(labels)

        return np.array(x), np.array(y)


if __name__ == '__main__':
    parser = parser()
    args = parser.parse_args()
    data_gen = DataGenerator(args)
    images, labels = data_gen.__getitem__(0)
    
    import matplotlib.pyplot as plt
    
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Obraz wejsciowy', fontsize=14)
    plt.imshow(images[0])
    plt.savefig("input_image.png", bbox_inches='tight')
    plt.show()

    labels = labels * 255
    masks = labels[..., 1:]
    bgs = labels[..., 0]

    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Segmentacja semantyczna', fontsize=14)
    plt.imshow(masks[0])
    plt.savefig("segmentation.png", bbox_inches='tight')
    plt.show()

    shape = (bgs[0].shape[0], bgs[0].shape[1])
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Tlo', fontsize=14)
    plt.imshow(np.reshape(bgs[0], shape), cmap='gray', vmin=0, vmax=255)
    plt.savefig("tlo.png", bbox_inches='tight')
    plt.show()

