''' CNN przy użyciu funkcyjnego API

~99.3% dokładność na zbiorze testowym
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical


# załadowanie zbioru MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# zmiana etykiet z rzadkich na kategorie
num_labels = len(np.unique(y_train))
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# zmiana postaci (ang. shape) i normalizacja obrazów wejściowych
image_size = x_train.shape[1]
x_train = np.reshape(x_train,[-1, image_size, image_size, 1])
x_test = np.reshape(x_test,[-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# parametry sieci
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
filters = 64
dropout = 0.3

# konstruowanie warstw CNN z użyciem funkcyjnego API
inputs = Input(shape=input_shape)
y = Conv2D(filters=filters,
           kernel_size=kernel_size,
           activation='relu')(inputs)
y = MaxPooling2D()(y)
y = Conv2D(filters=filters,
           kernel_size=kernel_size,
           activation='relu')(y)
y = MaxPooling2D()(y)
y = Conv2D(filters=filters,
           kernel_size=kernel_size,
           activation='relu')(y)
# przekształcenie obrazu na wektor przed połączeniem do warstwy gęstej — spłaszczenie
y = Flatten()(y)
# regularyzacja — pomijanie
y = Dropout(dropout)(y)
outputs = Dense(num_labels, activation='softmax')(y)

# konstruowanie modelu przez podanie wejść/wyjść
model = Model(inputs=inputs, outputs=outputs)
# model w formie tekstowej
model.summary()

# funkcja straty klasyfikatora, optymalizacja Adam, dokładność
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# uczenie modelu z obrazami wejściowymi i etykietami
model.fit(x_train,
          y_train,
          validation_data=(x_test, y_test),
          epochs=20,
          batch_size=batch_size)

# dokładność modelu na zbiorze testowym
score = model.evaluate(x_test,
                       y_test,
                       batch_size=batch_size,
                       verbose=0)
print("\nDokładność na zbiorze testowym: %.1f%%" % (100.0 * score[1]))