'''
Sieć RNN do klasyfikacji cyfr MNIST
Dokładność 98.3% na zbiorze testowym po 20 epokach
https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, SimpleRNN
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.datasets import mnist

# załaduj zbiór MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# zlicz liczbę etykiet
num_labels = len(np.unique(y_train))

# konwersja na wektor OH
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# zmiana rozmiaru i normalizacja
image_size = x_train.shape[1]
x_train = np.reshape(x_train,[-1, image_size, image_size])
x_test = np.reshape(x_test,[-1, image_size, image_size])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# parametry sieci
input_shape = (image_size, image_size)
batch_size = 128
units = 256
dropout = 0.2

# model to sieć RNN z 256 jednostkami, wejście to 28-wymiarowy wektor
# 28 kroków czasowych
model = Sequential()
model.add(SimpleRNN(units=units,
                    dropout=dropout,
                    input_shape=input_shape))
model.add(Dense(num_labels))
model.add(Activation('softmax'))
model.summary()
plot_model(model, to_file='rnn-mnist.png', show_shapes=True)

# funkcja straty dla wektora OH
# sgd jako optymalizator
# dokładność jest odpowiednią miarą oceny jakości klasyfikatora
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
# uczenie sieci
model.fit(x_train, y_train, epochs=20, batch_size=batch_size)

_, acc = model.evaluate(x_test,
                        y_test,
                        batch_size=batch_size,
                        verbose=0)
print("\nDokładność na zbiorze testowym: %.1f%%" % (100.0 * acc)) 
