"""Uczenie ResNet na zbiorze CIFAR10.

ResNet v1
[a] Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf

ResNet v2
[b] Identity Mappings in Deep Residual Networks
https://arxiv.org/pdf/1603.05027.pdf
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import Dense, Conv2D
from tensorflow.keras.layers import BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input
from tensorflow.keras.layers import Flatten, add
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import to_categorical
import numpy as np
import os
import math

# parametry uczenia
batch_size = 32 # w oryginalnym artykule batch_size=128 dla uczenia
epochs = 200
data_augmentation = True
num_classes = 10

# włączenie subtract_pixel_mean polepsza dokładność
subtract_pixel_mean = True

# Parametry modelu
# ----------------------------------------------------------------------------
#           |      | 200-epok    | Oryg Artyk  | 200-epok    | Oryg Artyk| sek/epokę
# Model     |  n   | ResNet v1   | ResNet v1    | ResNet v2   | ResNet v2 | GTX1080Ti
#           |v1(v2)| %Dokładność | %Dokładność | %Dokładność | %Dokładność| v1 (v2)
# ----------------------------------------------------------------------------
# ResNet20  | 3 (2)| 92.16     | 91.25     | -----     | -----     | 35 (---)
# ResNet32  | 5(NA)| 92.46     | 92.49     | NA        | NA        | 50 ( NA)
# ResNet44  | 7(NA)| 92.50     | 92.83     | NA        | NA        | 70 ( NA)
# ResNet56  | 9 (6)| 92.71     | 93.03     | 93.01     | NA        | 90 (100)
# ResNet110 |18(12)| 92.65     | 93.39+-.16| 93.15     | 93.63     | 165(180)
# ResNet164 |27(18)| -----     | 94.07     | -----     | 94.54     | ---(---)
# ResNet1001| (111)| -----     | 92.39     | -----     | 95.08+-.14| ---(---)
# ---------------------------------------------------------------------------
n = 3

# wersja modelu:
# oryginalny artykuł: version = 1 (ResNet v1), 
# ulepszony ResNet: version = 2 (ResNet v2)
version = 1

# obliczenie głębokości z dostarczonych parametrów modelu 
if version == 1:
    depth = n * 6 + 2
elif version == 2:
    depth = n * 9 + 2

# nazwa modelu, głębokość i wersja
model_type = 'ResNet%dv%d' % (depth, version)

# załadowanie zbioru CIFAR10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# wymiary obrazów wejściowych
input_shape = x_train.shape[1:]

# normalizacja danych
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# jeśli subtract_pixel_mean jest True
if subtract_pixel_mean:
    x_train_mean = np.mean(x_train, axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

print('postać x_train:', x_train.shape)
print(x_train.shape[0], 'próbek uczących')
print(x_test.shape[0], 'próbek testowych')
print('postać y_train:', y_train.shape)

# konwersja wektorów klas na binarną macierz klas
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)


def lr_schedule(epoch):
    """Harmonogramowanie współczynnika uczenia

    Współczynnik uczenia ma być zredukowany po 80, 120, 160, 180 epoce.
    Wywoływanie automatyczne podczas terningu w każdej epoce jako 
    część wywołania zwrotnego

    # Argumenty
        epoch (int): Liczba epok

    # Zwraca
        lr (float32): współczynnik uczenia
    """
    lr = 1e-3
    if epoch > 180:
        lr *= 0.5e-3
    elif epoch > 160:
        lr *= 1e-3
    elif epoch > 120:
        lr *= 1e-2
    elif epoch > 80:
        lr *= 1e-1
    print('współczynnik uczenia ', lr)
    return lr


def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """Konstruktor stosu warstw 2D Convolution-Batch Normalization-Activation

    Argumenty:
        inputs (tensor): tensor wejściowy z obrazu wejściowego 
        lub z poprzedniej warstwy
        num_filters (int): liczba filtrów Conv2D
        kernel_size (int): wymiary kwadratowego jądra Conv2D
        strides (int): wymiary kroku Conv2D
        activation (string): nazwa funkcji aktywacji
        batch_normalization (bool): czy wykonać normalizację wsadową
        conv_first (bool): conv-bn-activation (True) or 
                           bn-activation-conv (False)

    Zwraca:
        x (tensor): tensor jako wejście dla następnej warstwy
    """
    conv = Conv2D(num_filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  padding='same',
                  kernel_initializer='he_normal',
                  kernel_regularizer=l2(1e-4))

    x = inputs
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
    else:
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
        x = conv(x)
    return x


def resnet_v1(input_shape, depth, num_classes=10):
    """Konstruktor sieci ResNet w wersji 1 [a]

    Stos 2 x (3x3) Conv2D-BN-ReLU
    Ostatnia warstwa ReLU jest po połączeniu skrótowym.
    Na początku każdego etapu rozmiar mapy cech jest zmniejszany o połowę
    (ang. downsampled) przez warstwę splotową z krokiem 2 tak długo, jak długo 
    jest spełniony warunek. Na każdym etapie warstwy mają taką samą liczbę filtrów.
    Rozmiar mapy cech:
    etap 0: 32x32, 16
    etap 1: 16x16, 32
    etap 2: 8x8, 64
    Liczba parametrów jest w przybliżeniu równa pokazanym w tabeli 6 dla [a]:
    ResNet20 0.27M
    ResNet32 0.46M
    ResNet44 0.66M
    ResNet56 0.85M
    ResNet110 1.7M

    Argumenty:
        input_shape (tensor): postać tensora obrazu wejściowego
        depth (int): liczba warstw splotowego jądra bazowego
        num_classes (int): liczba klas (CIFAR10 ma 10 klas)

    Zwraca:
        model (Model): instancję modelu Keras
    """
    if (depth - 2) % 6 != 0:
        raise ValueError('głębokość powinna wynosić 6n+2 (tzn. 20, 32, w [a])')
    # początek definicji modelu
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    # stworzenie instancji stosu jednostek resztkowych
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            # pierwsza warstwa (ale nie stos)
            if stack > 0 and res_block == 0:
                strides = 2  # zmniejszanie rozmiaru (ang. downsample)
            y = resnet_layer(inputs=x,
                             num_filters=num_filters,
                             strides=strides)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters,
                             activation=None)
            # pierwsza warstwa (ale nie stos)
            if stack > 0 and res_block == 0:
                # projekcja liniowa skrótu resztkowego
                # połączenie — dopasowanie zmienionych wymiarów
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = add([x, y])
            x = Activation('relu')(x)
        num_filters *= 2

    # dodanie klasyfikatora na górze
    # v1 nie używa BN po ostatnim połączeniu skrótowym-ReLU
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)

    # stworzenie instancji modelu
    model = Model(inputs=inputs, outputs=outputs)
    return model




def resnet_v2(input_shape, depth, num_classes=10):
    """Konstruktor sieci ResNet w wersji 2 [b]

    Stos warstw BN-ReLU-Conv2D (1x1)-(3x3)-(1x1) 
    również znanych jako warstwa ograniczająca.
    Pierwsze połączenie skrótowe na warstwę to 1x1 Conv2D.
    Drugie i następne połączenia są identyczne.
    Na początku każdego etapu rozmiar mapy cech jest zmniejszany o połowę
    (ang. downsampled) przez warstwę splotową z krokiem 2 tak długo, jak długo 
    jest spełniony warunek. Na każdym etapie warstwy mają taką samą liczbę filtrów  taki sam rozmiar filtrów map.
    Rozmiary map cech:
    conv1  : 32x32,   16
    stage 0: 32x32,   64
    stage 1: 16x16, 128
    stage 2: 8x8,    256

    Argumenty:
        input_shape (tensor): postać tensora obrazu wejściowego
        depth (int): liczba warstw splotowego jądra bazowego
        num_classes (int): liczba klas (CIFAR10 ma 10 klas)

    Zwraca:
        model (Model): instancję modelu Keras
    """
    if (depth - 2) % 9 != 0:
        raise ValueError('glebokosc powinna byc 9n+2 (tzn. 110 w [b])')
    # początek definicji modelu
    num_filters_in = 16
    num_res_blocks = int((depth - 2) / 9)

    inputs = Input(shape=input_shape)
    # w v2 występuje Conv2D z BN-ReLU na wejściu przed podziałem na dwie ścieżki
    x = resnet_layer(inputs=inputs,
                     num_filters=num_filters_in,
                     conv_first=True)

    # stworzenie instancji stosu jednostek resztkowych
    for stage in range(3):
        for res_block in range(num_res_blocks):
            activation = 'relu'
            batch_normalization = True
            strides = 1
            if stage == 0:
                num_filters_out = num_filters_in * 4
                # pierwsza warstwa i pierwszy stos
                if res_block == 0:
                    activation = None
                    batch_normalization = False
            else:
                num_filters_out = num_filters_in * 2
                # pierwsza warstwa (ale nie stos)
                if res_block == 0:
                    # zmniejszanie rozmiaru/próbkowanie w dół (ang. downsample)
                    strides = 2

            # resztkowa jednostka ograniczająca
            y = resnet_layer(inputs=x,
                             num_filters=num_filters_in,
                             kernel_size=1,
                             strides=strides,
                             activation=activation,
                             batch_normalization=batch_normalization,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters_in,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters_out,
                             kernel_size=1,
                             conv_first=False)
            if res_block == 0:
                # projekcja liniowa skrótu resztkowego, połączenie — dopasowanie zmienionych wymiarów
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters_out,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = add([x, y])

        num_filters_in = num_filters_out

    # dodanie klasyfikatora na górze, v2 ma BN-ReLU przed warstwą łączącą
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)

    # stworzenie instancji modelu
    model = Model(inputs=inputs, outputs=outputs)
    return model



if version == 2:
    model = resnet_v2(input_shape=input_shape, depth=depth)
else:
    model = resnet_v1(input_shape=input_shape, depth=depth)

model.compile(loss='categorical_crossentropy',
              optimizer=Adam(lr=lr_schedule(0)),
              metrics=['acc'])
model.summary()

# włącz jesli możesz uzyć pydot
# pip install pydot
#plot_model(model, to_file="%s.png" % model_type, show_shapes=True)
print(model_type)

# przygotowanie katalogu do zapisania modeli
save_dir = os.path.join(os.getcwd(), 'zapisane_modele')
model_name = 'cifar10_%s_model.{epoch:03d}.h5' % model_type
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
filepath = os.path.join(save_dir, model_name)

# przygotowanie wywołań zwrotnych 
# do zapisywania modeli i dopasowań współczynnika uczenia
checkpoint = ModelCheckpoint(filepath=filepath,
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True)

lr_scheduler = LearningRateScheduler(lr_schedule)

lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=0.5e-6)

callbacks = [checkpoint, lr_reducer, lr_scheduler]

# uruchomienie uczenia z lub bez dogenerowania danych
if not data_augmentation:
    print('Bez dogenerowania danych.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True,
              callbacks=callbacks)
else:
    print('Dogenerowanie danych rzeczywistych.')
    # wykonanie wstępnego przetwarzania i dogenerowania danych rzeczywistych
    datagen = ImageDataGenerator(
        # ustawienie średniej dla zbioru wejściowego na 0
        featurewise_center=False,
        # ustawienie średniej dla każdej próbki na 0
        samplewise_center=False,
        # podział wejść wg odchylenia standardowego zbioru
        featurewise_std_normalization=False,
        # podziałkazdego wejścia według jego odchylenia standardowego
        samplewise_std_normalization=False,
        # zastosowanie wybielania ZCA
        zca_whitening=False,
        # obrót obrazów o kąt z losowego zakresu (0, 180) stopni
        rotation_range=0,
        # przypadkowe przesunięcie obrazu w poziomie
        width_shift_range=0.1,
        # przypadkowe przesuniecie obrazu w pionie
        height_shift_range=0.1,
        # przypadkowe odbicie w poziomie
        horizontal_flip=True,
        # przypadkowe odbicie w pionie
        vertical_flip=False)

    # obliczenia wartości wymaganych do normalizacji w obrebie cech
    # (odchylenie standardowe, średnia 
    # i składowe główne jeśli stosujemy wybielanie ZAC).
    datagen.fit(x_train)

    steps_per_epoch =  math.ceil(len(x_train) / batch_size)
    # dopasowanie modelu na próbkach generowanych przez datagen.flow().
    model.fit(x=datagen.flow(x_train, y_train, batch_size=batch_size),
              verbose=1,
              epochs=epochs,
              validation_data=(x_test, y_test),
              steps_per_epoch=steps_per_epoch,
              callbacks=callbacks)


# punktacja wyuczonego modelu
scores = model.evaluate(x_test,
                        y_test,
                        batch_size=batch_size,
                        verbose=0)
print('Strata na zbiorze testowym:', scores[0])
print('Dokładność na zbiorze testowym:', scores[1])
