"""Generator danych
Jest to wielowątkowy, skalowalny i wydajny sposób odczytywania jako zbiorów danych 
ogromnych obrazów znajdujących się w systemie plików.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from tensorflow.python.keras.utils.data_utils import Sequence

import numpy as np
import layer_utils
import label_utils
import os
import skimage

from layer_utils import get_gt_data
from layer_utils import anchor_boxes

from skimage.io import imread
from skimage.util import random_noise
from skimage import exposure


class DataGenerator(Sequence):
    """Wielowątkowy generator danych
    Każdy wątek wczytuje partię obrazów i etykiety obiektów 

    Argumenty:
        args: konfiguracja użytkownika
        dictionary: słownik nazw obrazów i etykiety obiektów
        n_classes (int): liczba klas obiektów
        feature_shapes (tensor): kształt map cech czoła sieci SSD
        n_anchors (int): liczba pól zakotwiczenia na punkt mapy cech 
        shuffle (Bool): jeśli zbiór danych ma być wymieszany przed próbkowaniem
    """

    def __init__(self,
                 args,
                 dictionary,
                 n_classes,
                 feature_shapes=[],
                 n_anchors=4,
                 shuffle=True):
        self.args = args
        self.dictionary = dictionary
        self.n_classes = n_classes
        self.keys = np.array(list(self.dictionary.keys()))
        self.input_shape = (args.height, 
                            args.width,
                            args.channels)
        self.feature_shapes = feature_shapes
        self.n_anchors = n_anchors
        self.shuffle = shuffle
        self.on_epoch_end()
        self.get_n_boxes()


    def __len__(self):
        """Liczba partii na epokę"""
        blen = np.floor(len(self.dictionary) / self.args.batch_size)
        return int(blen)


    def __getitem__(self, index):
        """Pobranie partii danych """
        start_index = index * self.args.batch_size
        end_index = (index+1) * self.args.batch_size
        keys = self.keys[start_index : end_index]
        x, y = self.__data_generation(keys)
        return x, y


    def on_epoch_end(self):
        """Wymieszaj po każdej epoce"""
        if self.shuffle == True:
            np.random.shuffle(self.keys)


    def get_n_boxes(self):
        """Całkowita liczba obwiedni"""
        self.n_boxes = 0
        for shape in self.feature_shapes:
            self.n_boxes += np.prod(shape) // self.n_anchors
        return self.n_boxes


    def apply_random_noise(self, image, percent=30):
        """Zastosuj losowy szum do obrazu (nie używane)"""
        random = np.random.randint(0, 100)
        if random < percent:
            image = random_noise(image)
        return image


    def apply_random_intensity_rescale(self, image, percent=30):
        """Zastosuj przeskalowanie o losowej intensywności do obrazu (nie używane)"""
        random = np.random.randint(0, 100)
        if random < percent:
            v_min, v_max = np.percentile(image, (0.2, 99.8))
            image = exposure.rescale_intensity(image, in_range=(v_min, v_max))
        return image


    def apply_random_exposure_adjust(self, image, percent=30):
        """Zastosuj przeskalowanie o losowej intensywności do obrazu (nie używane)"""
        random = np.random.randint(0, 100)
        if random < percent:
            image = exposure.adjust_gamma(image, gamma=0.4, gain=0.9)
            # another exposure algo
            # image = exposure.adjust_log(image)
        return image


    def __data_generation(self, keys):
        """Generowanie danych uczących: obrazów i etykiet referencyjnych obiektów

        Argumenty:
            keys (array): losowo próbkowane klucze (kluczem jest nazwa pliku)

        Zwraca:
            x (tensor): obrazy z partii
            y (tensor): klasy partii, przesunięcia i maski
        """
        # uczące dane wejściowe
        x = np.zeros((self.args.batch_size, *self.input_shape))
        dim = (self.args.batch_size, self.n_boxes, self.n_classes)
        # klasa referencyjna
        gt_class = np.zeros(dim)
        dim = (self.args.batch_size, self.n_boxes, 4)
        # referencyjne przesunięcie
        gt_offset = np.zeros(dim)
        # maska poprawnego pola zakotwiczenia
        gt_mask = np.zeros(dim)

        for i, key in enumerate(keys):
            # oczekujemy przechowywania obrazów w self.args.data_path
            # kluczem jest nazwa pliku
            image_path = os.path.join(self.args.data_path, key)
            image = skimage.img_as_float(imread(image_path))
            # przypisanie obrazu do indeksu partii
            x[i] = image
            # wpis etykiety ma cztery wymiary — współrzędne pola zakotwiczenia
            # plus jednowymiarową klasę etykiety
            labels = self.dictionary[key]
            labels = np.array(labels)
            # 4 współrzędne obwiedni są pierwszymi czterema elementami etykiet
            # ostatnią etykietą jest etykieta klasy obiektu
            boxes = labels[:,0:-1]
            for index, feature_shape in enumerate(self.feature_shapes):
                # generowanie pól zakotwiczenia
                anchors = anchor_boxes(feature_shape,
                                       image.shape,
                                       index=index,
                                       n_layers=self.args.layers)
                # każda warstwa cech ma wiersz pól zakotwiczenia
                anchors = np.reshape(anchors, [-1, 4])
                # obliczenie IoU każdego pola zakotwiczenia
                # w stosunku do każdej obwiedni
                iou = layer_utils.iou(anchors, boxes)

                # generowanie klas, przesunięć i maski referencyjnej
                gt = get_gt_data(iou,
                                 n_classes=self.n_classes,
                                 anchors=anchors,
                                 labels=labels,
                                 normalize=self.args.normalize,
                                 threshold=self.args.threshold)
                gt_cls, gt_off, gt_msk = gt
                if index == 0:
                    cls = np.array(gt_cls)
                    off = np.array(gt_off)
                    msk = np.array(gt_msk)
                else:
                    cls = np.append(cls, gt_cls, axis=0)
                    off = np.append(off, gt_off, axis=0)
                    msk = np.append(msk, gt_msk, axis=0)

            gt_class[i] = cls
            gt_offset[i] = off
            gt_mask[i] = msk


        y = [gt_class, np.concatenate((gt_offset, gt_mask), axis=-1)]

        return x, y
