# -*-coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K

import numpy as np
import matplotlib.pyplot as plt

# załadowanie zbioru MNIST
(x_train, _), (x_test, _) = mnist.load_data()

# zmiana wymiarów na (28, 28, 1) oraz normalizacja obrazów wejściowych
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# parametry sieci
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
latent_dim = 16

# Liczba liczba warstw CNN pełniacych funkcję enkoderów/dekoderów oraz liczba filtrów na warstwę
layer_filters = [32, 64]

# budowanie modelu sieci autokodującej
# jako pierwszy budujemy koder (ang. encoder)

inputs = Input(shape=input_shape, name='koder_wejscie')
x = inputs
# stos warstw Conv2D(32)-Conv2D(64)
for filters in layer_filters:
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)

# w celu zbudowania sieci autokodującej potrzebne są informacje 
# o kształcie (ang. shape) danych. By nie wykonywać obliczeń ręcznie
# najpierw wejście dekodera Conv2DTranspose będzie miało kształt
# (7, 7, 64) - zostanie przekształcone z powrotem przez dekoder do
# (28, 28, 1)

shape = K.int_shape(x)

# generowanie wektora niejawnego (ang. latent)
x = Flatten()(x)
latent = Dense(latent_dim, name='wektor_niejawny')(x)

# tworzenie instancji modelu kodera
encoder = Model(inputs,
                latent,
                name='koder')
encoder.summary()
plot_model(encoder,
           to_file='koder.png',
           show_shapes=True)

# budowanie modelu dekodera
latent_inputs = Input(shape=(latent_dim,), name='dekoder_wejscie')
# używamy wcześniej zachowanego kształtu (7, 7, 64)
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
# z wektora do kształtu odpowiedniego dla transponowanej conv
x = Reshape((shape[1], shape[2], shape[3]))(x)

# stos warstw Conv2DTranspose(64)-Conv2DTranspose(32)
for filters in layer_filters[::-1]:
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)

# rekonstrukcja sygnałów wejściowych
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='dekoder_wyjście')(x)

# utworzenie instancji modelu dekodera
decoder = Model(latent_inputs, outputs, name='dekoder')
decoder.summary()
plot_model(decoder, to_file='dekoder.png', show_shapes=True)

# sieć autokodująca = enkoder+dekoder
# stworzenie instancji sieci autokodującej
autoencoder = Model(inputs,
                    decoder(encoder(inputs)),
                    name='autokoder')
autoencoder.summary()
plot_model(autoencoder,
           to_file='autokoder.png',
           show_shapes=True)

# funkcja straty jako błąd średniokwadratowy (ang. Mean Square Error), optymalizator Adam
autoencoder.compile(loss='mse', optimizer='adam')

# trenowanie sieci autokodującej
autoencoder.fit(x_train,
                x_train,
                validation_data=(x_test, x_test),
                epochs=1,
                batch_size=batch_size)

# prognozowanie rezultatu działania sieci autokodującej na danych testowych
x_decoded = autoencoder.predict(x_test)

# wyświetlenie pierwszych ośmiu wejściowych i odtworzonych obrazów
imgs = np.concatenate([x_test[:8], x_decoded[:8]])
imgs = imgs.reshape((4, 4, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure()
plt.axis('off')
plt.title('Wejście: dwa pierwsze wiersze \nWyjście: dwa ostatnie wiersze')
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.savefig('we_wy.png')
plt.show()
