"""Narzędzia warstw

Funkcja narzędziowe do obliczania IOU, pól zakotwiczenia, masek,
oraz przesunięć pól obwiedni

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import config
import math
from tensorflow.keras import backend as K

def anchor_sizes(n_layers=4):
    """Generuje liniowy rozkład rozmiarów zależnych od liczby najwyższych warstw ssd

    Argument:
        n_layers (int): Liczba czołowych warstw ssd

    Zwraca:
        sizes (list): Listę rozmiarów kotwic
    """
    s = np.linspace(0.2, 0.9, n_layers + 1)
    sizes = []
    for i in range(len(s) - 1):
        # size = [s[i], (s[i] * 0.5)]
        size = [s[i], math.sqrt(s[i] * s[i + 1])]
        sizes.append(size)

    return sizes


def anchor_boxes(feature_shape,
                 image_shape,
                 index=0,
                 n_layers=4,
                 aspect_ratios=(1, 2, 0.5)):
    """Obliczenie pól zakotwiczenia dla zadanej mapy cech
    Pola zakotwiczenia są w formacie minmax.

    Argumenty:
        feature_shape (list): kształt mapy cech
        image_shape (list): postać rozmiaru obrazu
        index (int): wskazuje, do którego czoła ssd się odwołujemy
        n_layers (int): liczba bloków czoła ssd

    Zwraca:
        boxes (tensor): pola zakotwiczenia dla mapy cech
    """

    # rozmiar pola zakotwiczenia przy zadanym indeksie warstwy w czole ssd

    sizes = anchor_sizes(n_layers)[index]
    # liczba pól zakotwiczenia na punkt mapy cech 
    n_boxes = len(aspect_ratios) + 1
    # zignoruj liczbę kanałów (ostatni)
    image_height, image_width, _ = image_shape
    # zignoruj liczbę map cech (ostatnia) 
    feature_height, feature_width, _ = feature_shape

    # znormalizowana szerokość i wysokość
    # sizes[0] jest rozmiarem skali, sizes[1] to sqrt(scale*(scale+1))
    norm_height = image_height * sizes[0]
    norm_width = image_width * sizes[0]

    # lista pól zakotwiczenia (szerokość, wysokość)
    width_height = []
    # pole zakotwiczenia do proporcji na wymiarach obrazu po zmianie rozmiaru
    # równanie 11.3
    for ar in aspect_ratios:
        box_width = norm_width * np.sqrt(ar)
        box_height = norm_height / np.sqrt(ar)
        width_height.append((box_width, box_height))
    # mnożenie wymiaru pola zakotwiczenia przez size[1] dla aspect_ratio = 1
    # równanie 11.4
    box_width = image_width * sizes[1]
    box_height = image_height * sizes[1]
    width_height.append((box_width, box_height))

    # tablica (szerokość, wysokość)
    width_height = np.array(width_height)

    # wymiary każdego pola receptywnego w pikselach
    grid_width = image_width / feature_width
    grid_height = image_height / feature_height

    # obliczenie środka pola receptywnego na punkt cechy
    # format (cx, cy)
    # zaczynamy w środku pierwszego pola receptywnego
    start = grid_width * 0.5 
    # kończymy w środku ostatniego pola receptywnego
    end = (feature_width - 0.5) * grid_width
    cx = np.linspace(start, end, feature_width)

    start = grid_height * 0.5
    end = (feature_height - 0.5) * grid_height
    cy = np.linspace(start, end, feature_height)

    # siatka środków pól
    cx_grid, cy_grid = np.meshgrid(cx, cy)

    # dla np.tile()
    cx_grid = np.expand_dims(cx_grid, -1) 
    cy_grid = np.expand_dims(cy_grid, -1)

    # tensor = (feature_map_height, feature_map_width, n_boxes, 4)
    # wyrównanie z obrazem tensora (wysokość, szerokość, kanały)
    # ostatni wymiar = (cx, cy, w, h)
    boxes = np.zeros((feature_height, feature_width, n_boxes, 4))
    
    # (cx, cy)
    boxes[..., 0] = np.tile(cx_grid, (1, 1, n_boxes))
    boxes[..., 1] = np.tile(cy_grid, (1, 1, n_boxes))

    # (w, h)
    boxes[..., 2] = width_height[:, 0]
    boxes[..., 3] = width_height[:, 1]

    # zamiana (cx, cy, w, h) na (xmin, xmax, ymin, ymax)
    # dodanie jednego wymiaru do pola, by uwzględnić rozmiar partii równy 1 
    boxes = centroid2minmax(boxes)
    boxes = np.expand_dims(boxes, axis=0)
    return boxes


def centroid2minmax(boxes):
    """Zamian formatu środek_ciężkości na format minmax 
    (cx, cy, w, h) na (xmin, xmax, ymin, ymax)

    Argumenty:
        boxes (tensor): partia pól w formacie środek_ciężkości

    Zwraca:
        minmax (tensor): partia pól w formacie minmax
    """
    minmax= np.copy(boxes).astype(np.float)
    minmax[..., 0] = boxes[..., 0] - (0.5 * boxes[..., 2])
    minmax[..., 1] = boxes[..., 0] + (0.5 * boxes[..., 2])
    minmax[..., 2] = boxes[..., 1] - (0.5 * boxes[..., 3])
    minmax[..., 3] = boxes[..., 1] + (0.5 * boxes[..., 3])
    return minmax


def minmax2centroid(boxes):
    """Zamian formatu środek_ciężkości na format minmax 
    (cx, cy, w, h) na (xmin, xmax, ymin, ymax)

    Argumenty:
        boxes (tensor): partia pól w formacie środek_ciężkości

    Zwraca:
        minmax (tensor): partia pól w formacie minmax
    """

    centroid = np.copy(boxes).astype(np.float)
    centroid[..., 0] = 0.5 * (boxes[..., 1] - boxes[..., 0])
    centroid[..., 0] += boxes[..., 0] 
    centroid[..., 1] = 0.5 * (boxes[..., 3] - boxes[..., 2])
    centroid[..., 1] += boxes[..., 2] 
    centroid[..., 2] = boxes[..., 1] - boxes[..., 0]
    centroid[..., 3] = boxes[..., 3] - boxes[..., 2]
    return centroid



def intersection(boxes1, boxes2):
    """Obliczenie części wspólnej partii boxes1 oraz boxes2
    
    Argumenty:
        boxes1 (tensor): Współrzędne pól w pikselach
        boxes2 (tensor): Współrzędne pól w pikselach

    Zwraca:
        intersection_areas (tensor): część wspólna obszarów boxes1 and boxes2
    """
    m = boxes1.shape[0] # The number of boxes in `boxes1`
    n = boxes2.shape[0] # The number of boxes in `boxes2`

    xmin = 0
    xmax = 1
    ymin = 2
    ymax = 3

    boxes1_min = np.expand_dims(boxes1[:, [xmin, ymin]], axis=1)
    boxes1_min = np.tile(boxes1_min, reps=(1, n, 1))
    boxes2_min = np.expand_dims(boxes2[:, [xmin, ymin]], axis=0)
    boxes2_min = np.tile(boxes2_min, reps=(m, 1, 1))
    min_xy = np.maximum(boxes1_min, boxes2_min)

    boxes1_max = np.expand_dims(boxes1[:, [xmax, ymax]], axis=1)
    boxes1_max = np.tile(boxes1_max, reps=(1, n, 1))
    boxes2_max = np.expand_dims(boxes2[:, [xmax, ymax]], axis=0)
    boxes2_max = np.tile(boxes2_max, reps=(m, 1, 1))
    max_xy = np.minimum(boxes1_max, boxes2_max)

    side_lengths = np.maximum(0, max_xy - min_xy)

    intersection_areas = side_lengths[:, :, 0] * side_lengths[:, :, 1]
    return intersection_areas


def union(boxes1, boxes2, intersection_areas):
    """Oblicza złączenie partii boxes1 and boxes2

    Argumenty:
        boxes1 (tensor): Współrzędne pól w pikselach
        boxes2 (tensor): Współrzędne pól w pikselach

    Zwraca:
        union_areas (tensor): złączenie obszarów boxes1 and boxes2
    """
    m = boxes1.shape[0] # number of boxes in boxes1
    n = boxes2.shape[0] # number of boxes in boxes2

    xmin = 0
    xmax = 1
    ymin = 2
    ymax = 3

    width = (boxes1[:, xmax] - boxes1[:, xmin])
    height = (boxes1[:, ymax] - boxes1[:, ymin])
    areas = width * height
    boxes1_areas = np.tile(np.expand_dims(areas, axis=1), reps=(1,n))
    width = (boxes2[:,xmax] - boxes2[:,xmin])
    height = (boxes2[:,ymax] - boxes2[:,ymin])
    areas = width * height
    boxes2_areas = np.tile(np.expand_dims(areas, axis=0), reps=(m,1))

    union_areas = boxes1_areas + boxes2_areas - intersection_areas
    return union_areas


def iou(boxes1, boxes2):
    """Oblicza IoU dla partii boxes1 oraz boxes2

    Arguments:
        boxes1 (tensor): Współrzędne pól w pikselach
        boxes2 (tensor): Współrzędne pól w pikselach

    Zwraca:
        iou (tensor): część wspólna względem sumy dla obszarów boxes1 and boxes2
    """
    intersection_areas = intersection(boxes1, boxes2)
    union_areas = union(boxes1, boxes2, intersection_areas)
    return intersection_areas / union_areas


def get_gt_data(iou,
                n_classes=4,
                anchors=None,
                labels=None,
                normalize=False,
                threshold=0.6):
    """Pobranie klas referencyjnych, przesunięć obwiedni i masek

    Argumenty:
        iou (tensor): IoU każdej obwiedni z każdym polem zakotwiczenia
        n_classes (int): liczba klas obiektów
        anchors (tensor): pola zakotwiczenia na warstwę cechy
        labels (list): etykiety referencyjne
        normalize (bool): czy ma być wykonana normalizacja
        threshold (float): jeśli mniejsze niż 1.0, to pole zakotwiczenia > wartość progowa
            również jest pozytywnym polem zakotwiczenia

    Zwraca:
        gt_class, gt_offset, gt_mask (tensor): klasy referencyjne, przesunięcia i maski
    """
    # każde maxiou_per_get jest indeksem kotwicy z max iou dla danej obwiedni
    maxiou_per_gt = np.argmax(iou, axis=0)
    
    # dodatkowe pola zakotwiczenia w oparciu o wartość IoU
    if threshold < 1.0:
        iou_gt_thresh = np.argwhere(iou>threshold)
        if iou_gt_thresh.size > 0:
            extra_anchors = iou_gt_thresh[:,0]
            extra_classes = iou_gt_thresh[:,1]
            #extra_labels = labels[:,:][extra_classes]
            extra_labels = labels[extra_classes]
            indexes = [maxiou_per_gt, extra_anchors]
            maxiou_per_gt = np.concatenate(indexes,
                                           axis=0)
            labels = np.concatenate([labels, extra_labels],
                                    axis=0)

    # generowanie masek
    gt_mask = np.zeros((iou.shape[0], 4))
    # tylko indeksy maxiou_per_gt są prawidłowymi obwiedniami
    gt_mask[maxiou_per_gt] = 1.0

    # generowanie klas
    gt_class = np.zeros((iou.shape[0], n_classes))
    # domyślnie wszystko jest tłem (index 0)
    gt_class[:, 0] = 1
    # ale to, co jest w maxiou_per_gt — nie jest
    gt_class[maxiou_per_gt, 0] = 0
    # musimy znaleźć indeksy (klasy) tych kolumn
    maxiou_col = np.reshape(maxiou_per_gt,
                            (maxiou_per_gt.shape[0], 1))
    label_col = np.reshape(labels[:,4],
                           (labels.shape[0], 1)).astype(int)
    row_col = np.append(maxiou_col, label_col, axis=1)
    # etykieta obiektu w maxio_per_gt
    gt_class[row_col[:,0], row_col[:,1]]  = 1.0
    
    # generowanie przesunięć
    gt_offset = np.zeros((iou.shape[0], 4))

    # format (cx, cy, w, h)
    if normalize:
        anchors = minmax2centroid(anchors)
        labels = minmax2centroid(labels)
        # bbox = obwiednia, p_z = pole zakotwiczenia
        # ((bbox xcenter — p_z xcenter)/p_z width)/.1
        # ((bbox ycenter — p_z ycenter)/p_z height)/.1
        # równanie 11.7
        offsets1 = labels[:, 0:2] - anchors[maxiou_per_gt, 0:2]
        offsets1 /= anchors[maxiou_per_gt, 2:4]
        offsets1 /= 0.1

        # log(bbox width / anchor box width) / 0.2
        # log(bbox height / anchor box height) / 0.2
        # równanie 11.7
        offsets2 = np.log(labels[:, 2:4]/anchors[maxiou_per_gt, 2:4])
        offsets2 /= 0.2  

        offsets = np.concatenate([offsets1, offsets2], axis=-1)

    # format (xmin, xmax, ymin, ymax)
    else:
        offsets = labels[:, 0:4] - anchors[maxiou_per_gt]

    gt_offset[maxiou_per_gt] = offsets

    return gt_class, gt_offset, gt_mask
