'''Implementacja sieci Y z użyciem funkcyjnego API

~99.3% dokładność na zbiorze testowym
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Flatten, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.utils import plot_model

# załadowanie zbioru MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# zmiana etykiet z rzadkich na kategorie
num_labels = len(np.unique(y_train))
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# zmiana postaci (ang. shape)  i normalizacja obrazów wejściowych
image_size = x_train.shape[1]
x_train = np.reshape(x_train,[-1, image_size, image_size, 1])
x_test = np.reshape(x_test,[-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# parametry sieci
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
dropout = 0.4
n_filters = 32

# lewa gałąź sieci Y
left_inputs = Input(shape=input_shape)
x = left_inputs
filters = n_filters
# 3 zestawy warstw Conv2D-Dropout-MaxPooling2D
# liczba filtrów podwaja się po każdym zestawie warstw (32-64-128)
for i in range(3):
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding='same',
               activation='relu')(x)
    x = Dropout(dropout)(x)
    x = MaxPooling2D()(x)
    filters *= 2

# prawa gałąź sieci Y
right_inputs = Input(shape=input_shape)
y = right_inputs
filters = n_filters
# 3 zestawy warstw Conv2D-Dropout-MaxPooling2D
# liczba filtrów podwaja się po każdym zestawie warstw (32-64-128)
for i in range(3):
    y = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding='same',
               activation='relu',
               dilation_rate=2)(y)
    y = Dropout(dropout)(y)
    y = MaxPooling2D()(y)
    filters *= 2

# łączymy wyjścia obu gałęzi 
y = concatenate([x, y])
# przekształcenie mapy cech na wektor przed połączeniem do warstwy gęstej
# - spłaszczenie
y = Flatten()(y)
y = Dropout(dropout)(y)
outputs = Dense(num_labels, activation='softmax')(y)

# konstruowanie modelu przez podanie wejść/wyjść korzystając z funkcyjnego API
model = Model([left_inputs, right_inputs], outputs)

# weryfikacja struktury modely w postaci graficznej
# włącz jeśli możesz zainstalować pydot
# pip install pydot
#plot_model(model, to_file='cnn-y-network.png', show_shapes=True)

# weryfikacja budowy modelu z użyciem opisu warstw w formie tekstu
model.summary()

# funkcja straty klasyfikatora, optymalizacja Adam, dokładność
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# uczenie modelu z obrazami wejściowymi i etykietami
model.fit([x_train, x_train],
          y_train, 
          validation_data=([x_test, x_test], y_test),
          epochs=20,
          batch_size=batch_size)

# dokładność modelu na zbiorze testowym
score = model.evaluate([x_test, x_test],
                       y_test,
                       batch_size=batch_size,
                       verbose=0)
print("\nDokładność na zbiorze testowym: %.1f%%" % (100.0 * score[1]))
