"""Generator danych dla zbioru MNIST: 
    przycięcie względem środka i transformacja obrazu.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.utils import Sequence
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

import numpy as np
from skimage.transform import resize, rotate


class DataGenerator(Sequence):
    def __init__(self,
                 args,
                 shuffle=True,
                 siamese=False,
                 mine=False,
                 crop_size=4):
        """Wielowątkowy generator danych. Każdy wątek odczytuje partię obrazów
            i dokonuje ich transformacji w taki sposób
            by klasa obrazu nie uległa zmianie.

        Argumenty:
            args (argparse): Opcje definiowane przez użytkownika,
                takie jak rozmiar partii itp.
            shuffle (Bool): Czy dokonać przemieszania zbioru przed próbkowaniem
                czy nie.
            siamese (Bool): Czy gnerować parę obrazów (X oraz X_kreskowane)
                czy nie
            mine (Bool): Użyj algorytmu MINE zamiast IIC
            crop_size (int): Liczba pikseli do przycięcia z każdej strony obrazu
        """
        self.args = args
        self.shuffle = shuffle
        self.siamese = siamese
        self.mine = mine
        self.crop_size = crop_size
        self._dataset()
        self.on_epoch_end()

    def __len__(self):
        """Liczba partii na epokę
        """
        return int(np.floor(len(self.indexes) / self.args.batch_size))


    def __getitem__(self, index):
        """Indeksy próbek obrazu dla bieżącej partii
        """
        start_index = index * self.args.batch_size
        end_index = (index+1) * self.args.batch_size
        return self.__data_generation(start_index, end_index)

    def _dataset(self):
        """Ładuje zbiór i przeprowadza normalizację
        """
        dataset = self.args.dataset
        if self.args.train:
            (self.data, self.label), (_, _) = dataset.load_data()
        else:
            (_, _), (self.data, self.label) = dataset.load_data()

        if self.args.dataset == mnist:
            self.n_channels = 1
        else:
            self.n_channels = self.data.shape[3]

        image_size = self.data.shape[1]
        side = image_size - self.crop_size
        self.input_shape = [side, side, self.n_channels]

        # Z rozproszonych etykiet na skategoryzowane
        self.n_labels = len(np.unique(self.label))
        self.label = to_categorical(self.label)

        # zmiana kształtu i normalizacja obrazów wejściowych
        orig_shape = [-1, image_size, image_size, self.n_channels]
        self.data = np.reshape(self.data, orig_shape)
        self.data = self.data.astype('float32') / 255
        self.indexes = [i for i in range(self.data.shape[0])]


    def on_epoch_end(self):
        """Jeśli wybrane, miesza zbiór danych po każdej epoce
        """
        if self.shuffle == True:
            np.random.shuffle(self.indexes)


    def random_crop(self, image, target_shape, crop_sizes):
        """Wykonuje losowe przycięcie i przeskalowanie do początkowego rozmiaru

        Argumenty:
            image (tensor): Obraz do przycięcia o zmiany rozmiaru
            target_shape (tensor): Wyjściowy kształt
            crop_sizes (list): Lista rozmiarów 
                do jakich obraz może zostać przycięty
        """
        height, width = image.shape[0], image.shape[1]
        crop_size_idx = np.random.randint(0, len(crop_sizes))
        d = crop_sizes[crop_size_idx]
        x = height - d
        y = width - d
        center = np.random.randint(0, 2)
        if center:
            dx = dy = d // 2
        else:
            dx = np.random.randint(0, d + 1)
            dy = np.random.randint(0, d + 1)

        image = image[dy:(y + dy), dx:(x + dx), :]
        image = resize(image, target_shape)
        return image


    def random_rotate(self,
                      image, 
                      deg=20, 
                      target_shape=(24, 24, 1)):
        """Losowy obrót obrazu

        Argumenty:
            image (tensor): Obraz do przycięcia i zmiany wielkości
            deg (int): O ile stopni obrót
            target_shape (tensor): Wyjściowy kształt
        """
        angle = np.random.randint(-deg, deg)
        image = rotate(image, angle)
        image = resize(image, target_shape)
        return image


    def __data_generation(self, start_index, end_index):
        """Algorytm generowanie danych. Metoda generuje
        partię sparowanych obrazów (obraz oryginalny X oraz
        obraz przekształcony X kreskowane). Partia bliźniaczych
        obrazów jest wykorzystywana do uczenia algorytmu opartego na MI:
        1) IIC oraz 2) MINE (sekcja 7.)

        Argumenty:
        start_index (int): przy zadanej macierzy obrazów
            jest to początkowy indeks do pobierania próbki
        end_index (int): przy zadanej macierzy obrazów
            jest to końcowy indeks do pobierania próbki
        """


        d = self.crop_size // 2
        crop_sizes = [self.crop_size*2 + i for i in range(0,5,2)]
        image_size = self.data.shape[1] - self.crop_size
        x = self.data[self.indexes[start_index : end_index]]
        y1 = self.label[self.indexes[start_index : end_index]]

        target_shape = (x.shape[0], *self.input_shape)
        x1 = np.zeros(target_shape)
        if self.siamese:
            y2 = y1 
            x2 = np.zeros(target_shape)

        for i in range(x1.shape[0]):
            image = x[i]
            x1[i] = image[d: image_size + d, d: image_size + d]
            if self.siamese:
                rotate = np.random.randint(0, 2)
                # 50-50% szans na przycięcie lub obrót
                if rotate == 1:
                    shape = target_shape[1:]
                    x2[i] = self.random_rotate(image,
                                               target_shape=shape)
                else:
                    x2[i] = self.random_crop(image,
                                             target_shape[1:],
                                             crop_sizes)

        # dla IIC interesują nas głównie obrazy sparowane
        # X oraz X kreskowane = G(X)
        if self.siamese:
            # Jeśli jest wybrany algorytm MINE, użyj tego do wygenerowania
            # danych uczących (patrz podrozdział 13.9)
            if self.mine:
                y = np.concatenate([y1, y2], axis=0)
                m1 = np.copy(x1)
                m2 = np.copy(x2)
                np.random.shuffle(m2)

                x1 =  np.concatenate((x1, m1), axis=0)
                x2 =  np.concatenate((x2, m2), axis=0)
                x = (x1, x2)
                return x, y

            x_train = np.concatenate([x1, x2], axis=0)
            y_train = np.concatenate([y1, y2], axis=0)
            y = []
            for i in range(self.args.heads):
                y.append(y_train)
            return x_train, y

        return x1, y1


if __name__ == '__main__':
    datagen = DataGenerator()

