"""Konstruktor modelu SSD
Dostarcza również narzędzi do konstruowania warstw sieci
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from tensorflow.keras.layers import Activation, Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.layers import BatchNormalization, Concatenate
from tensorflow.keras.layers import ELU, MaxPooling2D, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

import numpy as np

def conv2d(inputs,
           filters=32,
           kernel_size=3,
           strides=1,
           name=None):

    conv = Conv2D(filters=filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  kernel_initializer='he_normal',
                  name=name,
                  padding='same')

    return conv(inputs)


def conv_layer(inputs,
               filters=32,
               kernel_size=3,
               strides=1,
               use_maxpool=True,
               postfix=None,
               activation=None):

    x = conv2d(inputs,
               filters=filters,
               kernel_size=kernel_size,
               strides=strides,
               name='conv'+postfix)
    x = BatchNormalization(name="bn"+postfix)(x)
    x = ELU(name='elu'+postfix)(x)
    if use_maxpool:
        x = MaxPooling2D(name='pool'+postfix)(x)
    return x


def build_ssd(input_shape,
              backbone,
              n_layers=4,
              n_classes=4,
              aspect_ratios=(1, 2, 0.5)):
    """Konstruowanie modelu SSD, mając szkielet

    Argumenty:
        input_shape (list): kształt obrazu wejściowego
        backbone (model): model szkieletu w Keras
        n_layers (int): liczba warstw czoła warstw ssd
        n_classes (int): liczba klas obiektów
        aspect_ratios (list): proporcje pól zakotwiczenia

    Zwraca:
        n_anchors (int): liczba pól zakotwiczenia na punkt mapy cech 
        feature_shape (tensor): mapa cech czoła sieci SSD
        model (Keras model): model SSD
    """
    # liczba pól zakotwiczenia na punkt mapy cech

    n_anchors = len(aspect_ratios) + 1

    inputs = Input(shape=input_shape)
    # liczba base_outputs zależna od n_layers
    base_outputs = backbone(inputs)
    
    outputs = []
    feature_shapes = []
    out_cls = []
    out_off = []

    for i in range(n_layers):
        # użyta jest każda warstwa splotowa z sieci szkieletowej
        # jako mapa cech dla predykcji klasy i przesunięć 
        # również znane jako predykcja wielkoskalowa
        conv = base_outputs if n_layers==1 else base_outputs[i]
        name = "cls" + str(i+1)
        classes  = conv2d(conv,
                          n_anchors*n_classes,
                          kernel_size=3,
                          name=name)

        # przesunięcie: (partia, wysokość, szerokość, liczba_kotwic·4)
        name = "off" + str(i+1)
        offsets  = conv2d(conv,
                          n_anchors*4,
                          kernel_size=3,
                          name=name)

        shape = np.array(K.int_shape(offsets))[1:]
        feature_shapes.append(shape)

        # zmiana formy predykcji klas, dając tensory 3D kształtów 
        # (partia, wysokość·szerokość·liczba_kotwic, liczba_klas)
        # na ostatnim wykonujemy softmax
        name = "cls_res" + str(i+1)
        classes = Reshape((-1, n_classes), 
                          name=name)(classes)

        # zmiana kształtu predykcji przesunięcia, dając tensory 3D o kształcie
        # (partia, wysokość·szerokość·liczba_kotwic,, 4)
        # na ostatnim obliczamy stratę: (gładką) L1 lub L2 
        name = "off_res" + str(i+1)
        offsets = Reshape((-1, 4),
                          name=name)(offsets)
        # złączenie — do wyrównania z rozmiarem referencyjnym
        # obliczonym z wartości referencyjnej przesunięć i maski o tym samym wymiarze
        # potrzebnym podczas obliczania straty
        offsets = [offsets, offsets]
        name = "off_cat" + str(i+1)
        offsets = Concatenate(axis=-1,
                              name=name)(offsets)

        # zebranie predykcji przesunięć na skalę
        out_off.append(offsets)

        name = "cls_out" + str(i+1)

        #activation = 'sigmoid' if n_classes==1 else 'softmax'
        #print("Activation:", activation)

        classes = Activation('softmax',
                             name=name)(classes)

        # zebranie predykcji klas na skalę
        out_cls.append(classes)

    if n_layers > 1:
        # połączenie wszystkich klas i przesunięć z każdej skali 
        name = "offsets"
        offsets = Concatenate(axis=1,
                              name=name)(out_off)
        name = "classes"
        classes = Concatenate(axis=1,
                              name=name)(out_cls)
    else:
        offsets = out_off[0]
        classes = out_cls[0]

    outputs = [classes, offsets]
    model = Model(inputs=inputs,
                  outputs=outputs,
                  name='ssd_head')

    return n_anchors, feature_shapes, model
