"""Klasa SSD do konstruowania, uczenia i oceny sieci SSD

1)  Szkielet ResNet50 (v2).
    Trening z 6 warstwami map cech.
    Proszę dostosuj rozmiar partii w zależności od pamięci swojej GPU.
    Dla 1060 z 6GB, -b=1. Dla V100 z 32GB, -b=4

python3 ssd-11.6.1.py -t -b=4

2)  Szkielet ResNet50 (v2).
    Uczenie z wcześniej zapisanego modelu:

python3 ssd-11.6.1.py --restore-weights=saved_models/ResNet56v2_4-layer_weights-200.h5 -t -b=4

3)  Szkielet ResNet50 (v2).
    Ocena:

python3 ssd-11.6.1.py -e --restore-weights=saved_models/ResNet56v2_4-layer_weights-200.h5 \
        --image-file=dataset/drinks/0010000.jpg

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.losses import Huber

import layer_utils
import label_utils
import config

import os
import skimage
import numpy as np
import argparse

from skimage.io import imread
from data_generator import DataGenerator
from label_utils import build_label_dictionary
from boxes import show_boxes
from model import build_ssd
from loss import focal_loss_categorical, smooth_l1_loss, l1_loss
from model_utils import lr_scheduler, ssd_parser
from common_utils import print_log


class SSD:
    """Tworzenie modelu sieci SSD i generator zbioru danych.
    SSD definiuje funkcję do uczenia i walidacji 
    modelu sieci SSD

    Argumenty:
        args: konfiguracja użytkownika

    Atrybuty:
        ssd (model): model sieci SSD 
        train_generator: wielowątkowy generator danych uczących
    """
    def __init__(self, args):
        """kopiowanie konfiguracji zdefiniowanej przez użytkownika
        Tworzenie szkieletu i modelu sieci SSD
        """
        self.args = args
        self.ssd = None
        self.train_generator = None
        self.build_model()


    def build_model(self):
        """Konstruowanie szkieletu i modelu SSD"""
        # lista plików z obrazami i etykiety są przechowywane w słowniku
        self.build_dictionary()
        
        # wejściowe domyślne rozmiary to (480, 640, 3)
        self.input_shape = (self.args.height, 
                            self.args.width,
                            self.args.channels)

        # konstruowanie sieci szkieletowej (np. ResNet50)
        # liczba warstw cech jest równa n_layers
        # warstwy cech są wejściami dla czoła sieci SSD
        # dla predykcji klas i przesunięć
        self.backbone = self.args.backbone(self.input_shape,
                                           n_layers=self.args.layers)

        # użycie szkieletu, konstruowanie sieci SSD
        # wyjściem SSD jest predykcja klasy i przesunięcia
        anchors, features, ssd = build_ssd(self.input_shape,
                                           self.backbone,
                                           n_layers=self.args.layers,
                                           n_classes=self.n_classes)
        # n_anchors = liczba zakotwiczeń na punkt cechy (tj 4)
        self.n_anchors = anchors
        # feature_shapes jest listą kształtów mapy cech
        # na warstwę wyjściową — do obliczenia rozmiarów pól zakotwiczenia
        self.feature_shapes = features
        # model sieci SSD
        self.ssd = ssd


    def build_dictionary(self):
        """odczyt nazw plików wejściowych i etykiet wykrywania obiektów
        z pliku csv i przechowywanie w słowniku
        """
        # ścieżka do zbioru uczącego
        path = os.path.join(self.args.data_path,
                            self.args.train_labels)

        # tworzenie słownika 
        # klucz = nazwa pliku obrazu, wartość = współrzędne pola + etykieta klasy
        # self.classes to lista etykiet klas
        self.dictionary, self.classes = build_label_dictionary(path)
        self.n_classes = len(self.classes)
        self.keys = np.array(list(self.dictionary.keys()))


    def build_generator(self):
        """Konstruowanie wielowątkowego generatora danych uczących"""

        self.train_generator = \
                DataGenerator(args=self.args,
                              dictionary=self.dictionary,
                              n_classes=self.n_classes,
                              feature_shapes=self.feature_shapes,
                              n_anchors=self.n_anchors,
                              shuffle=True)


    def train(self):
        """Uczenie sieci SSD"""
        # konstruowanie generatora danych uczących
        if self.train_generator is None:
            self.build_generator()

        optimizer = Adam(lr=1e-3)
        # wybór funkcji straty przez argumenty
        if self.args.improved_loss:
            print_log("Focal loss and smooth L1", self.args.verbose)
            loss = [focal_loss_categorical, smooth_l1_loss]
        elif self.args.smooth_l1:
            print_log("Smooth L1", self.args.verbose)
            loss = ['categorical_crossentropy', smooth_l1_loss]
        else:
            print_log("Cross-entropy and L1", self.args.verbose)
            loss = ['categorical_crossentropy', l1_loss]

        self.ssd.compile(optimizer=optimizer, loss=loss)

        # wagi modelu sa zachowywane do przyszłej walidacji
        # przygotowanie katalogu do zapisywania modelu
        save_dir = os.path.join(os.getcwd(), self.args.save_dir)
        model_name = self.backbone.name
        model_name += '-' + str(self.args.layers) + "layer"
        if self.args.normalize:
            model_name += "-norm"
        if self.args.improved_loss:
            model_name += "-improved_loss"
        elif self.args.smooth_l1:
            model_name += "-smooth_l1"

        if self.args.threshold < 1.0:
            model_name += "-extra_anchors" 

        model_name += "-" 
        model_name += self.args.dataset
        model_name += '-{epoch:03d}.h5'

        log = "# of classes %d" % self.n_classes
        print_log(log, self.args.verbose)
        log = "Batch size: %d" % self.args.batch_size
        print_log(log, self.args.verbose)
        log = "Weights filename: %s" % model_name
        print_log(log, self.args.verbose)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        filepath = os.path.join(save_dir, model_name)

        # przygotowanie wywołań zwrotnych do zapisania wag modelu
        # i harmonogramu współczynnika uczenia
        # współczynnik uczenia zmniejsza się o 50% co 20 epok, począwszy od 60. epoki
        checkpoint = ModelCheckpoint(filepath=filepath,
                                     verbose=1,
                                     save_weights_only=True)
        scheduler = LearningRateScheduler(lr_scheduler)

        callbacks = [checkpoint, scheduler]
        # uczenie sieci SSD
        self.ssd.fit(self.train_generator,
                     use_multiprocessing=False,
                     callbacks=callbacks,
                     epochs=self.args.epochs)


    def restore_weights(self):
        """Ładowanie wag wcześniej wytrenowanego modelu"""
        if self.args.restore_weights:
            save_dir = os.path.join(os.getcwd(), self.args.save_dir)
            filename = os.path.join(save_dir, self.args.restore_weights)
            log = "Loading weights: %s" % filename
            print(log, self.args.verbose)
            self.ssd.load_weights(filename)


    def detect_objects(self, image):
        image = np.expand_dims(image, axis=0)
        classes, offsets = self.ssd.predict(image)
        image = np.squeeze(image, axis=0)
        classes = np.squeeze(classes)
        offsets = np.squeeze(offsets)
        return image, classes, offsets


    def evaluate(self, image_file=None, image=None):
        """Ocena obrazu na podstawie obrazu (tensor numpy) lub nazwy pliku"""
        show = False
        if image is None:
            image = skimage.img_as_float(imread(image_file))
            show = True

        image, classes, offsets = self.detect_objects(image)
        class_names, rects, _, _ = show_boxes(args,
                                              image,
                                              classes,
                                              offsets,
                                              self.feature_shapes,
                                              show=show)
        return class_names, rects


    def evaluate_test(self):
        # ścieżka etykiet testowych w csv
        path = os.path.join(self.args.data_path,
                            self.args.test_labels)
        # słownik testowy
        dictionary, _ = build_label_dictionary(path)
        keys = np.array(list(dictionary.keys()))
        # suma precyzji
        s_precision = 0
        # suma czułości
        s_recall = 0
        # suma IoU
        s_iou = 0
        # ocena (na obraz)
        for key in keys:
            # etykiety odniesienia
            labels = np.array(dictionary[key])
            # 4 współrzędnie pola są pierwszymi czterema elementami etykiet
            gt_boxes = labels[:, 0:-1]
            # ostatnia jest klasą
            gt_class_ids = labels[:, -1]
            # załaduj identyfikator obrazy jako klucz
            image_file = os.path.join(self.args.data_path, key)
            image = skimage.img_as_float(imread(image_file))
            image, classes, offsets = self.detect_objects(image)
            # wykonaj nms
            _, _, class_ids, boxes = show_boxes(args,
                                                image,
                                                classes,
                                                offsets,
                                                self.feature_shapes,
                                                show=False)

            boxes = np.reshape(np.array(boxes), (-1,4))
            # oblicz IoU
            iou = layer_utils.iou(gt_boxes, boxes)
            # opuść puste IoU
            if iou.size ==0:
                continue
            # klasa przewidywanego pola (z maksymalnym IoU)
            maxiou_class = np.argmax(iou, axis=1)

            # prawdziwie dodatni
            tp = 0
            # fałszywie dodatni
            fp = 0
            # suma IoU obiektów na obraz
            s_image_iou = []
            for n in range(iou.shape[0]):
                # referencyjne pole zakotwiczenia ma etykietę
                if iou[n, maxiou_class[n]] > 0:
                    s_image_iou.append(iou[n, maxiou_class[n]])
                    # prawdziwie dodatni ma tę sama klasę oraz pole/wartość odniesienia
                    if gt_class_ids[n] == class_ids[maxiou_class[n]]:
                        tp += 1
                    else:
                        fp += 1

            # obiekty pominięte (fałszywie negatywne)
            fn = abs(len(gt_class_ids) - tp - fp)
            s_iou += (np.sum(s_image_iou) / iou.shape[0])
            s_precision += (tp/(tp + fp))
            s_recall += (tp/(tp + fn))


        n_test = len(keys)
        print_log("mIoU: %f" % (s_iou/n_test),
                  self.args.verbose)
        print_log("Precision: %f" % (s_precision/n_test),
                  self.args.verbose)
        print_log("Recall : %f" % (s_recall/n_test),
                  self.args.verbose)


    def print_summary(self):
        """Wydrukuj podsumowanie sieci do debugowania."""
        from tensorflow.keras.utils import plot_model
        if self.args.summary:
            self.backbone.summary()
            self.ssd.summary()
            plot_model(self.backbone,
                       to_file="backbone.png",
                       show_shapes=True)


if __name__ == '__main__':
    parser = ssd_parser()
    args = parser.parse_args()
    ssd = SSD(args)

    if args.summary:
        ssd.print_summary()

    if args.restore_weights:
        ssd.restore_weights()
        if args.evaluate:
            if args.image_file is None:
                ssd.evaluate_test()
            else:
                ssd.evaluate(image_file=args.image_file)
            
    if args.train:
        ssd.train()
