"""Konstruktor modelu ResNet jako szkieletu
Zaadaptowane z Rozdziału 2 Głębokie sieci neuronowe

ResNet v1
[a] Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf

ResNet v2
[b] Identity Mappings in Deep Residual Networks
https://arxiv.org/pdf/1603.05027.pdf

TODO: Połączyć z kodem do detekcji obiektów
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from tensorflow.keras.layers import Dense, Conv2D
from tensorflow.keras.layers import BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
from tensorflow.keras.layers import Add
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
import numpy as np

from model import conv_layer


#           |      | 200-epok    | Oryg Artykuł| 200-epok   | Oryg Artykuł   | sek/epokę
# Model     |  n   | ResNet v1   | ResNet v1   | ResNet v2   | ResNet v2     | GTX1080Ti
#           |v1(v2)| %Dokładność | %Dokładność | %Dokładność | %Dokładność   | v1 (v2)
# ------------------------------------------------------------------------------
# ResNet20  | 3 (2)| 92.16       | 91.25       | -----       | -----         | 35 (---)
# ResNet32  | 5(NA)| 92.46       | 92.49       | NA          | NA            | 50 ( NA)
# ResNet44  | 7(NA)| 92.50       | 92.83       | NA          | NA            | 70 ( NA)
# ResNet56  | 9 (6)| 92.71       | 93.03       | 93.01       | NA            | 90 (100)
# ResNet110 |18(12)| 92.65       | 93.39+-.16  | 93.15       | 93.63         | 165(180)
# ResNet164 |27(18)| -----       | 94.07       | -----       | 94.54         | ---(---)
# ResNet1001| (111)| -----       | 92.39       | -----       | 95.08+-.14    | ---(---)
# ---------------------------------------------------------------------------

def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """Konstruktor stosu warstw 2D Convolution-Batch Normalization-Activation

    Argumenty:
        inputs (tensor): tensor wejściowy z obrazu wejściowego 
        lub z poprzedniej warstwy
        num_filters (int): liczba filtrów Conv2D
        kernel_size (int): wymiary kwadratowego jądra Conv2D
        strides (int): wymiary kroku Conv2D
        activation (string): nazwa funkcji aktywacji
        batch_normalization (bool): czy wykonać normalizację wsadową
        conv_first (bool): conv-bn-activation (True) or 
                           bn-activation-conv (False)

    Zwraca:
        x (tensor): tensor jako wejście dla następnej warstwy
    """
    conv = Conv2D(num_filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  padding='same',
                  kernel_initializer='he_normal',
                  kernel_regularizer=l2(1e-4))

    x = inputs
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
    else:
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
        x = conv(x)
    return x


def resnet_v1(input_shape, depth, num_classes=10):
    """Konstruktor sieci ResNet w wersji 1 [a]

    Stos 2 x (3x3) Conv2D-BN-ReLU
    Ostatnia warstwa ReLU jest po połączeniu skrótowym.
    Na początku każdego etapu rozmiar mapy cech jest zmniejszany o połowę
    (ang. downsampled) przez warstwę splotową z krokiem 2 tak długo, jak długo 
    jest spełniony warunek. Na każdym etapie warstwy mają taką samą liczbę filtrów.
    Rozmiar mapy cech:
    etap 0: 32x32, 16
    etap 1: 16x16, 32
    etap 2: 8x8, 64
    Liczba parametrów jest w przybliżeniu równa pokazanym w tabeli 6 dla [a]:
    ResNet20 0.27M
    ResNet32 0.46M
    ResNet44 0.66M
    ResNet56 0.85M
    ResNet110 1.7M

    Argumenty:
        input_shape (tensor): postać tensora obrazu wejściowego
        depth (int): liczba warstw splotowego jądra bazowego
        num_classes (int): liczba klas (CIFAR10 ma 10 klas)

    Zwraca:
        model (Model): instancję modelu Keras
    """
    if (depth - 2) % 6 != 0:
        raise ValueError('głębokość powinna wynosić 6n+2 (tzn. 20, 32, w [a])')
    # początek definicji modelu
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    # stworzenie instancji stosu jednostek resztkowych
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0:              # pierwsza warstwa (ale nie stos)
                strides = 2   # zmniejszanie rozmiaru (ang. downsample)
            y = resnet_layer(inputs=x,
                             num_filters=num_filters,
                             strides=strides)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters,
                             activation=None)
            if stack > 0 and res_block == 0:             # pierwsza warstwa (ale nie stos)
                # projekcja liniowa skrótu resztkowego
                # połączenie — dopasowanie zmienionych wymiarów
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = Add()([x, y])
            x = Activation('relu')(x)
        num_filters *= 2

    # mapa cech
    outputs = features_pyramid(x, n_layers)
    

    # stworzenie instancji modelu
    name = 'ResNet%dv1' % (depth)
    model = Model(inputs=inputs,
                  outputs=outputs,
                  name=name)
    return model


def resnet_v2(input_shape, depth, n_layers=4):
    """Konstruktor sieci ResNet w wersji 2 [b]

    Stos warstw BN-ReLU-Conv2D (1x1)-(3x3)-(1x1) 
    również znanych jako warstwa ograniczająca.
    Pierwsze połączenie skrótowe na warstwę to 1x1 Conv2D.
    Drugie i następne połączenia są identyczne.
    Na początku każdego etapu rozmiar mapy cech jest zmniejszany o połowę
    (ang. downsampled) przez warstwę splotową z krokiem 2 tak długo, jak długo 
    jest spełniony warunek. Na każdym etapie warstwy mają taką samą liczbę filtrów  taki sam rozmiar filtrów map.
    Rozmiary map cech:
    conv1  : 32x32,   16
    stage 0: 32x32,   64
    stage 1: 16x16, 128
    stage 2: 8x8,    256

    Argumenty:
        input_shape (tensor): postać tensora obrazu wejściowego
        depth (int): liczba warstw splotowego jądra bazowego
        num_classes (int): liczba klas (CIFAR10 ma 10 klas)

    Zwraca:
        model (Model): instancję modelu Keras
    """
    if (depth - 2) % 9 != 0:
        raise ValueError('głębokość powinna byc 9n+2 (tzn. 110 w [b])')
    # początek definicji modelu
    num_filters_in = 16
    num_res_blocks = int((depth - 2) / 9)

    inputs = Input(shape=input_shape)
    # w v2 występuje Conv2D z BN-ReLU na wejściu przed podziałem na dwie ścieżki
    x = resnet_layer(inputs=inputs,
                     num_filters=num_filters_in,
                     conv_first=True)

    # stworzenie instancji stosu jednostek resztkowych
    for stage in range(3):
        for res_block in range(num_res_blocks):
            activation = 'relu'
            batch_normalization = True
            strides = 1
            if stage == 0:
                num_filters_out = num_filters_in * 4
                if res_block == 0: 
                    activation = None
                    batch_normalization = False
            else:
                num_filters_out = num_filters_in * 2
                if res_block == 0:  
                    strides = 2  

            # resztkowa jednostka ograniczająca
            y = resnet_layer(inputs=x,
                             num_filters=num_filters_in,
                             kernel_size=1,
                             strides=strides,
                             activation=activation,
                             batch_normalization=batch_normalization,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters_in,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters_out,
                             kernel_size=1,
                             conv_first=False)
            if res_block == 0:
                # projekcja liniowa skrótu resztkowego, połączenie — dopasowanie zmienionych wymiarów
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters_out,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = Add()([x, y])

        num_filters_in = num_filters_out

    # v2 ma BN-ReLU przed warstwą łączącą
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # pierwsza mapa cech

    #główna mapa cech(160, 120)
    # następne mapy cech sa skalowane w dół:  2, 4, 8
    outputs = features_pyramid(x, n_layers)

   # stworzenie instancji modelu
    name = 'ResNet%dv2' % (depth)
    model = Model(inputs=inputs,
                  outputs=outputs,
                  name=name)
    return model


def features_pyramid(x, n_layers):
    """Generowanie piramidy cech z wyjścia 
    ostatniej warstwy sieci szkieletowej (ResNetv1 lub v2)

    Argumenty:
        x (tensor): wyjściowa mapa cech sieci szkieletowej
        n_layers (int): liczba dodatkowych warstw piramidy

    Zwraca:
        outputs (list): piramida cech 
    """
    outputs = [x]
    conv = AveragePooling2D(pool_size=2, name='pool1')(x)
    outputs.append(conv)
    prev_conv = conv
    n_filters = 512

    # dodatkowe warstwy mapy cech
    for i in range(n_layers - 1):
        postfix = "_layer" + str(i+2)
        conv = conv_layer(prev_conv,
                          n_filters,
                          kernel_size=3,
                          strides=2,
                          use_maxpool=False,
                          postfix=postfix)
        outputs.append(conv)
        prev_conv = conv

    return outputs
    

def build_resnet(input_shape,
                 n_layers=4,
                 version=2,
                 n=6):
    """Konstruowanie szkieletu ResNet

    # Argumenty:
        input_shape (list): Rozmiar i kanały obrazu wejściowego
        n_layers (int): Liczba warstw cech 
        version (int): Obsługuje ResNetv1 oraz v2 (domyślnie v2)
        n (int): Określa liczbę warstw ResNet (domyślnie ResNet50)

    # Zwraca
        model (Keras Model)

    """
    # obliczona głębokość z dostarczonego modelu (parametr n)
    if version == 1:
        depth = n * 6 + 2
    elif version == 2:
        depth = n * 9 + 2

    # nazwa modelu, głębokość i wersja
    # input_shape (h, w, 3)
    if version==1:
        model = resnet_v1(input_shape=input_shape,
                          depth=depth,
                          n_layers=n_layers)
    else:
        model = resnet_v2(input_shape=input_shape,
                          depth=depth,
                          n_layers=n_layers)
    return model


if __name__ == '__main__':
    from model_utils import parser
    parser = parser()
    args = parser.parse_args()
    # postać danych wejściowych to domyślnie (480, 640, 3)
    input_shape = (args.height,
                   args.width,
                   args.channels)

    backbone = build_resnet(input_shape,
                            n_layers=args.layers)
    backbone.summary()
