# -*- coding: utf-8 -*-

'''Trening odszumiającej sieci autokodującej dla zbioru MNIST

Usuwanie zakłóceń to jedno z klasycznych zastosowań autokoderów
Proces odszumiania umożliwia usunięcie niepożądanego szumu zakłócającego dane

Zakłócenia+dane ----> autokodująca sieć odszumiająca (DAE) ---> Dane

Mając dany zbiór treningowy z uszkodzonymi danymi jako sygnały wejściowe 
i danymi bez zakłóceń jako pożądane (wyjściowe), DAE może odtworzyć 
ukrytą strukturę zakłóceń aby odczyścić dane.

Ten przykład ma budowę modułową. Koder, dekoder i autokoder są trzema modelami
współdzielącymi wagi. Na przykład po trenowaniu autokodera, koder może zostać 
wykorzystany do wygenerowania wektora niejawnego dla danych wejściowych 
do niskowymiarowej wizualizacji lub przekształceń typu PCA lub TSNE.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

np.random.seed(1337)

# załaduj zbiór danych MNIST
(x_train, _), (x_test, _) = mnist.load_data()

# zmień kształt na (28, 28, 1) i znormalizuj obrazy wejściowe 
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# generuj uszkodzone obrazy MNIST przez dodanie szumu w postaci 
# rozkładu normalnego o środku w 0.5 i std=0.5 
noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape)
x_train_noisy = x_train + noise
noise = np.random.normal(loc=0.5, scale=0.5, size=x_test.shape)
x_test_noisy = x_test + noise

# dodanie szumu może przekroczyć znormalizowane wartości pikseli
# >1.0 lub <0.0, więc wartości pikseli przycinaj: >1.0 do 1.0 i <0.0 do 0.0
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

# parametry sieci
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
latent_dim = 16
# liczba warstw CNN i filtrów na warstwę kodera i dekodera
layer_filters = [32, 64]

# budowanie modelu autokodera
# jako pierwszy budujemy koder (ang. encoder)
inputs = Input(shape=input_shape, name='koder_wejście')
x = inputs

# stos warstw Conv2D(32)-Conv2D(64)
for filters in layer_filters:
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=2,
               activation='relu',
               padding='same')(x)

# W celu zbudowania sieci autokodującej potrzebne są informacje 
# o kształcie (ang. shape) danych. By nie wykonywać obliczeń ręcznie
# najpierw wejście dekodera Conv2DTranspose będzie miało kształt
# (7, 7, 64), zostanie przekształcony z powrotem przez dekoder do
# (28, 28, 1)
shape = K.int_shape(x)

# generowanie wektora niejawnego
x = Flatten()(x)
latent = Dense(latent_dim, name='wektor_niejawny')(x)

# tworzenie instancji modelu kodera
encoder = Model(inputs, latent, name='Koder')
encoder.summary()

# budowanie modelu dekodera
latent_inputs = Input(shape=(latent_dim,), name='dekoder_wejście')
# używamy wcześniej zachowanego kształtu (7, 7, 64)
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
# z wektora do kształtu odpowiedniego dla transponowanej conv
x = Reshape((shape[1], shape[2], shape[3]))(x)

# stos warstw Conv2DTranspose(64)-Conv2DTranspose(32)
for filters in layer_filters[::-1]:
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        strides=2,
                        activation='relu',
                        padding='same')(x)

# rekonstrukcja odszumionych sygnałów wejściowych
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          padding='same',
                          activation='sigmoid',
                          name='Dekoder_wy')(x)

# stworzenie instancji modelu dekodera
decoder = Model(latent_inputs, outputs, name='Dekoder')
decoder.summary()

# sieć autokodująca = enkoder+dekoder
# utworzenie instancji sieci autokodującej
autoencoder = Model(inputs, decoder(encoder(inputs)), name='Autokoder')
autoencoder.summary()

# funkcja straty jako błąd średniokwadratowy (ang. Mean Square Error),
# optymalizator Adam
autoencoder.compile(loss='mse', optimizer='adam')

# trenowanie sieci autokodującej
autoencoder.fit(x_train_noisy,
                x_train,
                validation_data=(x_test_noisy, x_test),
                epochs=10,
                batch_size=batch_size)

# prognozowanie rezultatu działania sieci autokodującej na 
# testowych danych uszkodzonych 
x_decoded = autoencoder.predict(x_test_noisy)

# 3 zbiory obrazów z 9 cyframi MNIST
# Pierwszy wiersz = dane oryginalne
# Drugi wiersz = dane uszkodzone zakłóceniami
# Trzeci wiersz = dane odszumione
rows, cols = 3, 9
num = rows * cols
imgs = np.concatenate([x_test[:num], x_test_noisy[:num], x_decoded[:num]])
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
imgs = np.vstack(np.split(imgs, rows, axis=1))
imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
imgs = (imgs * 255).astype(np.uint8)
plt.figure()
plt.axis('off')
plt.title('Oryginalne obrazy: górne wiersze, '
          'Uszkodzone: środkowe wiersze, '
          'Odszumione: dolne wiersze')
plt.imshow(imgs, interpolation='none', cmap='gray')
Image.fromarray(imgs).save('uszkodzone i odszumione.png')
Image.fromarray(imgs).save('uszkodzone i odszumione.tif')
plt.show()
