# -*- coding: utf-8 -*-

''' Klasyfikacja cyfr MNIST z użyciem CNN

Trójwarstwowa CNN do klasyfikacji cyfr MNIST
Pierwsze dwie warstwy to Conv2D-ReLU-MaxPool
Trzecia warstwa to Conv2D-ReLU-Dropout
Czwarta warstwa to Dense(10)
Funkcja aktywacji wyjscia to softmax
Optymalizator to Adam

99.4% dokładność na zbiorze testowym po 10 epokach

https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.datasets import mnist

# załadowanie zbioru MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# zliczenie liczby etykiet
num_labels = len(np.unique(y_train))

# konwersja na wektor „jeden-aktywny” (OH)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# wymiary obrazów wejściowych
image_size = x_train.shape[1]
# zmiana rozmiaru i normalizacja
x_train = np.reshape(x_train,[-1, image_size, image_size, 1])
x_test = np.reshape(x_test,[-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# parametry sieci
# obraz jest przetwarzany "tak-jak-jest" (kwadrat, skala szarości)
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
pool_size = 2
filters = 64
dropout = 0.2

# model jest stosem warstw CNN-ReLU-MaxPooling
model = Sequential()
model.add(Conv2D(filters=filters,
                 kernel_size=kernel_size,
                 activation='relu',
                 input_shape=input_shape))
model.add(MaxPooling2D(pool_size))
model.add(Conv2D(filters=filters,
                 kernel_size=kernel_size,
                 activation='relu'))
model.add(MaxPooling2D(pool_size))
model.add(Conv2D(filters=filters,
                 kernel_size=kernel_size,
                 activation='relu'))
model.add(Flatten())
# pomijanie jako regularyzacja
model.add(Dropout(dropout))
# warstwa wyjściowa jest 10-wymiarowym wektorem OH
model.add(Dense(num_labels))
model.add(Activation('softmax'))
model.summary()
plot_model(model, to_file='cnn-mnist.png', show_shapes=True)

# funkcja straty dla wektora OH, optymalizator Adam
# dokładność jest odpowiednią miarą do oceny jakości klasyfikatora
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
# trenowanie sieci
model.fit(x_train, y_train, epochs=10, batch_size=batch_size)

_, acc = model.evaluate(x_test,
                        y_test,
                        batch_size=batch_size,
                        verbose=0)
print("\nDokładność na zbiorze testowym: %.1f%%" % (100.0 * acc))