"""Klasa FCN do zbudowania, uczenia i oceny modelu FCN 
    do segmentacji semantycznej. 

1) Szkielet ResNet50 (v2).
    Uczenie z 6 warstwami map cech.
    Proszę dostosuj rozmiar próbki w zależności od pamięci Twojego GPU.
    dla 1060 z 6GB, --batch-size=1. Dla V100 z 32GB, 
    --batch-size=4

python3 fcn-12.3.1.py --train --batch-size=4

2)  Szkielet ResNet50 (v2).
    Trenowanie z uprzednio zachowanego modelu:

python3 fcn-12.3.1.py --restore-weights=ResNet56v2-3layer-drinks-200.h5 \
        --train --batch-size=4

3)  Szkielet ResNet50 (v2).
    Ocena:

python3 fcn-12.3.1.py --restore-weights=ResNet56v2-3layer-drinks-200.h5 \
        --evaluate --image-file=dataset/drinks/0010018.jpg

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import LearningRateScheduler

import os
import skimage
import numpy as np

from data_generator import DataGenerator
from model_utils import parser, lr_scheduler
os.sys.path.append("../lib")
from common_utils import print_log, AccuracyCallback
from model import build_fcn
from skimage.io import imread


class FCN:
    """Klasa złożona z modelu FCN i generatora danych.
    Definuje funkcje to uczenia i walidowania modelu FCN.

    Argumenty:
        args: konfiguracja zdefiniowana przez użytkownika

    Attrybuty:
        fcn (model): Model sieci FCN
        train_generator: Wielowątkowy generator danych do uczenia
    """
    def __init__(self, args):
        """Kopia konfiguracji zdefiniowanej przez użytkownika.
            Konstruowanie szkieletu i modelu sieci FCN.
        """
        self.args = args
        self.fcn = None
        self.train_generator = DataGenerator(args)
        self.build_model()
        self.eval_init()


    def build_model(self):
        """Konstruowanie sieci szkieletowej i jej użycie 
        w celu wykonania segmentacji semantycznej. 
        Bazą sieci szkieletowej jest sieć FCN.
        """
        
        # wejściowe wymiary to domyślnie (480, 640, 3)
        self.input_shape = (self.args.height, 
                            self.args.width,
                            self.args.channels)

        # konstruowanie sieci szkieletowej (np. ResNet50)
        # sieć szkieletowa jest używana do pierwszego zbioru cech piramidy cech
        self.backbone = self.args.backbone(self.input_shape,
                                           n_layers=self.args.layers)

        # użycie sieci szkieletowej, konstruowanie sieci fcn
        # warstwa wyjściowa jest klasyfikatorem dla wszystkich pikseli
        self.n_classes =  self.train_generator.n_classes
        self.fcn = build_fcn(self.input_shape,
                             self.backbone,
                             self.n_classes)


    def eval_init(self):
        """Porządkowanie ustawień do oceny uczonych modeli"""
        # model weights are saved for future validation
        # prepare model model saving directory.
        save_dir = os.path.join(os.getcwd(), self.args.save_dir)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        model_name = self.backbone.name
        model_name += '-' + str(self.args.layers) + "layer-"
        model_name += self.args.dataset
        model_name += '-best-iou.h5'
        log = "Nazwa pliku z wagami: %s" % model_name
        print_log(log, self.args.verbose)
        self.weights_path = os.path.join(save_dir, model_name)
        self.preload_test()
        self.miou = 0
        self.miou_history = []
        self.mpla_history = []


    def preload_test(self):
        """Wstępne załadowanie zbioru testowego by zaoszczędzić czas"""
        path = os.path.join(self.args.data_path,
                            self.args.test_labels)

        # ground truth data is stored in an npy file
        self.test_dictionary = np.load(path,
                                       allow_pickle=True).flat[0]
        self.test_keys = np.array(list(self.test_dictionary.keys()))
        print_log("Załadowano %s" % path, self.args.verbose)


    def train(self):
        """Uczenie sieci FCN"""
        optimizer = Adam(lr=1e-3)
        loss = 'categorical_crossentropy'
        self.fcn.compile(optimizer=optimizer, loss=loss)

        log = "# of classes %d" % self.n_classes
        print_log(log, self.args.verbose)
        log = "Batch size: %d" % self.args.batch_size
        print_log(log, self.args.verbose)

        # przygotowanie wywołań zwrotnych do zapisywania wag modelu
        # i harmonogramu współczynnika uczenia
        # wagi modelu są zapisywane, gdy testowa IoU jest najwyższa
        # współczynnik uczenia zmniejsza się o 50% co 20 epok po 40. epoce

        accuracy = AccuracyCallback(self)
        scheduler = LearningRateScheduler(lr_scheduler)

        callbacks = [accuracy, scheduler]
        # uczenie sieci fcn
        self.fcn.fit(x=self.train_generator,
                     use_multiprocessing=False,
                     callbacks=callbacks,
                     epochs=self.args.epochs)
                     #workers=self.args.workers)


    def restore_weights(self):
        """Załadowanie wag uprzednio wytrenowanego modelu"""
        if self.args.restore_weights:
            save_dir = os.path.join(os.getcwd(), self.args.save_dir)
            filename = os.path.join(save_dir, self.args.restore_weights)
            log = "Ladowanie wag: %s" % filename
            print_log(log, self.args.verbose)
            self.fcn.load_weights(filename)


    def segment_objects(self, image, normalized=True):
        """Uruchomienie predykcji segmentacji dla zadanego obrazu
    
        Argumenty:
            image (tensor): Obraz ładowany do tensora numpy.
                 składowe RGB z zakresu [0.0, 1.0]
            normalized (Bool): Użyj normalized=True dla predykcji kategorii 
                pikseli; False jeżeli segmentacja ma byc wyświetlana
                w formie obrazu RGB.
        """

        from tensorflow.keras.utils import to_categorical
        image = np.expand_dims(image, axis=0)
        segmentation = self.fcn.predict(image)
        segmentation = np.squeeze(segmentation, axis=0)
        segmentation = np.argmax(segmentation, axis=-1)
        segmentation = to_categorical(segmentation,
                                      num_classes=self.n_classes)
        if not normalized:
            segmentation = segmentation * 255
        segmentation = segmentation.astype('uint8')
        return segmentation


    def evaluate(self, imagefile=None, image=None):
        """Wykonanie segmentacji na pliku o podanej nazwie 
            i wyświetlenie rezultatów.
        """
        import matplotlib.pyplot as plt
        save_dir = "prediction"
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        if image is not None:
            imagefile = os.path.splitext(imagefile)[0]
        elif self.args.image_file is not None:
            image = skimage.img_as_float(imread(self.args.image_file))
            imagefile = os.path.split(self.args.image_file)[-1]
            print("imagefile:", imagefile)
        else:
            raise ValueError("Musi być podany plik obrazu")

        maskfile = imagefile + "-mask.png"
        mask_path = os.path.join(save_dir, maskfile)
        inputfile = imagefile + "-input.png"
        input_path = os.path.join(save_dir, inputfile)
        segmentation = self.segment_objects(image,
                                            normalized=False)
        mask = segmentation[..., 1:]
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title('Obraz wejsciowy', fontsize=14)
        plt.imshow(image)
        plt.savefig(input_path)
        #plt.show()

        plt.xlabel('x')
        plt.ylabel('y')
        plt.title('Segmentacja semantyczna', fontsize=14)
        plt.imshow(mask)
        plt.savefig(mask_path)
        #plt.show()


    def eval(self):
        """Ocena wytrenowanego modelu FCN z użyciem metryki 'średnia IoU'.
        """
        s_iou = 0
        s_pla = 0
        # oceń iou odnośnie obrazu testowego
        eps = np.finfo(float).eps
        for key in self.test_keys:
            # Załaduj obraz testowy
            image_path = os.path.join(self.args.data_path, key)
            image = skimage.img_as_float(imread(image_path))
            segmentation = self.segment_objects(image) 
            # Załaduj etykiety odniesienia dla obrazu testowego
            gt = self.test_dictionary[key]
            i_pla = 100 * (gt == segmentation).all(axis=(2)).mean()
            s_pla += i_pla
            
            i_iou = 0
            n_masks = 0
            # oblicz maskę dla każdego obiektu obrazu testowego, włącznie z tłem
            for i in range(self.n_classes):
                if np.sum(gt[..., i]) < eps: 
                    continue
                mask = segmentation[..., i]
                intersection = mask * gt[..., i]
                union = np.ceil((mask + gt[..., i]) / 2.0)
                intersection = np.sum(intersection) 
                union = np.sum(union) 
                if union > eps:
                    iou = intersection / union
                    i_iou += iou
                    n_masks += 1
            
            # średnia iou na obraz
            i_iou /= n_masks
            if not self.args.train:
                log = "%s: %d objs, miou=%0.4f ,pla=%0.2f%%"\
                      % (key, n_masks, i_iou, i_pla)
                print_log(log, self.args.verbose)

            # skumuluj iou dla wszystkich obrazów
            s_iou += i_iou
            if self.args.plot:
                self.evaluate(key, image)

        n_test = len(self.test_keys)
        m_iou = s_iou / n_test 
        self.miou_history.append(m_iou)
        np.save("miou_history.npy", self.miou_history)
        m_pla = s_pla / n_test
        self.mpla_history.append(m_pla)
        np.save("mpla_history.npy", self.mpla_history)
        if m_iou > self.miou and self.args.train:
            log = "\nNajlepsze stara mIoU=%0.4f, Najlepsza nowa mIoU=%0.4f, Dokładność na poziomie piksela=%0.2f%%"\
                    % (self.miou, m_iou, m_pla)
            print_log(log, self.args.verbose)
            self.miou = m_iou
            print_log("Zapisywanie wag... %s"\
                      % self.weights_path,\
                      self.args.verbose)
            self.fcn.save_weights(self.weights_path)
        else:
            log = "\nBieżąca mIoU=%0.4f, Dokładność na poziomie piksela=%0.2f%%"\
                    % (m_iou, m_pla)
            print_log(log, self.args.verbose)


    def print_summary(self):
        """Wydrukuj podsumowanie dla sieci (do debugowania)."""
        from tensorflow.keras.utils import plot_model
        if self.args.summary:
            self.backbone.summary()
            self.fcn.summary()
            plot_model(self.fcn,
                       to_file="fcn.png",
                       show_shapes=True)
            plot_model(self.backbone,
                       to_file="backbone.png",
                       show_shapes=True)


if __name__ == '__main__':
    parser = parser()
    args = parser.parse_args()
    fcn = FCN(args)

    if args.summary:
        fcn.print_summary()

    if args.restore_weights:
        fcn.restore_weights()

    if args.evaluate:
        if args.image_file is None:
            fcn.eval()
        else:
            fcn.evaluate()
            
    if args.train:
        fcn.train()
