"""Moduł zawiera bazowe klasy, które będą używane w tym rozdziale i dołączonych do niego materiałach"""

import torch as th
import torch.nn as nn

from typing import Tuple, Callable


class NaiveGraphConv(nn.Module):
    """Bazowa klasa dla uproszczonych implementacji warstw splotu grafowego"""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        aggr: Callable = th.sum,
        sigma: nn.Module = th.nn.Identity(),
        add_self_loops: bool = True,
    ):
        """Inicjalizacja warstwy splotu grafowego.

        Parameters
        ----------
        dim_in : int
            Wymiarowość reprezentacji wejściowej wierzchołków.
        dim_out : int
            Wymiarowość reprezentacji wyjściowej wierzchołków.
        aggr : Callable, optional
            Funkcja agregująca, która będzie używana do agregacji wiadomości od sąsiadów, domyślnie suma.
        sigma : th.Module, optional
            Funkcja aktywacji używana w warstwie splotu, domyślnie tożsamość - a zatem brak nieliniowości.
        add_self_loops: bool
            Czy dodać pętle do samego siebie.
        """
        super(NaiveGraphConv, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.aggr = aggr
        self.sigma = sigma
        self.add_self_loops = add_self_loops

    def message(self, v_neighbors: th.Tensor, *args, **kwargs) -> th.Tensor:
        """Funkcja budująca wiadomość od sąsiadów wierzchołka v.

        Parameters
        ----------
        v_neighbors : th.Tensor
            Macierz reprezentacji sąsiadów wierzchołka v: N(v).
            Wymiary: N x dim_in, gdzie N to liczba sąsiadów wierzchołka v.

        Returns
        -------
        th.Tensor
            Macierz wiadomości od sąsiadów wierzchołka v. Wymiar: 1 x dim_out.
        """
        raise NotImplementedError(
            "Funkcja message musi zostać zaimplementowana w klasie dziedziczącej"
        )

    def aggregate(self, messages: th.Tensor, *args, **kwargs) -> th.Tensor:
        """Funkcja agregująca wiadomości od sąsiadów wierzchołka v.

        Parameters
        ----------
        messages : th.Tensor
            Macierz wiadomości od sąsiadów wierzchołka v: M(v).
            Wymiary: N x dim_out, gdzie N to liczba sąsiadów wierzchołka v.

        Returns
        -------
        th.Tensor
            Macierz wiadomości od sąsiadów wierzchołka v. Wymiar: 1 x dim_out.
        """
        return self.aggr(messages, dim=0)

    def update(self, v: th.Tensor, message: th.Tensor, *args, **kwargs) -> th.Tensor:
        """Funkcja aktualizująca reprezentację wierzchołka v na podstawie wiadomości od sąsiadów.

        Parameters
        ----------
        v : th.Tensor
            Macierz reprezentacji wierzchołka v: h_v.
            Wymiary: 1 x dim_in.
        message : th.Tensor
            Macierz wiadomości od sąsiadów wierzchołka v: M(v).
            Wymiary: 1 x dim_out.

        Returns
        -------
        th.Tensor
            Macierz reprezentacji wierzchołka v po aktualizacji. Wymiar: 1 x dim_out.
        """
        raise NotImplementedError(
            "Funkcja update musi zostać zaimplementowana w klasie dziedziczącej"
        )

    def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:
        """Funkcja przepływu danych w warstwie splotu grafowego. Rozpoczyna cykl
        przekazywania wiadomości, dla każdego wierzchołka w grafie: v po kolei:
        1. Znajdując sąsiadów;
        2. Budując wiadomości od sąsiadów;
        3. Agregując wiadomości od sąsiadów;
        4. Aktualizując reprezentację wierzchołka.

        Parameters
        ----------
        x : th.Tensor
            Macierz cech wierzchołków.

        edge_index : th.Tensor
            Macierz sąsiedztwa grafu.

        Returns
        -------
        th.Tensor
            Zaktualizowane cechy wierzchołków.
        """
        # Alokacja nowej macierzy reprezentacji wierzchołków
        N = x.shape[0]
        h_new = []

        # Dla każdego wierzchołka v w grafie
        for v in range(x.shape[0]):
            # Znajdź indeksy w macierzy sąsiedztwa, gdzie odbiera on wiadomość
            v_index_as_receiver = th.where(edge_index[1] == v)[0]

            # Znajdź sąsiadów, którzy wysyłają tę wiadomość
            # |N(v)| x dim_in
            neighbors = edge_index[0, v_index_as_receiver]

            # Zbuduj wiadomości od sąsiadów
            # |N(v)| x dim_out
            neighbors_x = x[neighbors]

            if self.add_self_loops:
                neighbors_x = self.apply_self_loops(x, v, neighbors)
            m_nv = self.message(neighbors_x)

            # Agreguj wiadomości sąsiadów
            # 1 x dim_out
            agg = self.aggregate(m_nv)

            # Zaktualizuj wierzchołek v
            # 1 x dim_out
            new_v = self.update(x[v], agg)
            h_new.append(new_v)
        return th.cat(h_new)

    def apply_self_loops(
        self, x: th.Tensor, v: th.Tensor, neighbors: th.Tensor
    ) -> th.Tensor:
        """Funkcja odpowiadająca za dodanie pętli wierzchołka v do samego siebie.
        Krok ten jest co do zasady rekomendowany, przydatny też w przypadku braku sąsiadów - pozwala
        na zachowanie informacji o wierzchołku w przypadku braku jakichkolwiek krawędzi.

        Parameters
        ----------
        x : th.Tensor
            Macierz cech wierzchołków. Wymiar: N x dim_in.

        v : th.Tensor
            Indeks wierzchołka v w macierzy cech x.

        neighbors : th.Tensor
            Indeksy sąsiadów wierzchołka v w macierzy cech x.

        Returns
        -------
        th.Tensor
            Macierz cech wierzchołków z dodaną pętlą wierzchołka v do samego siebie.
        """
        # Dodanie pętli wierzchołka v do samego siebie.
        # Krok co do zasady rekomendowany, przydatny też w przypadku braku sąsiadów.
        v_self_x = x[v, :].unsqueeze(0)
        neighbors_x = th.cat((x[neighbors], v_self_x))
        return neighbors_x


class NodesClassifierModel(nn.Module):
    """Model klasyfikacji, który może przyjąć dowolną ilość warstw splotu grafowego (w implementacji naiwnej lub PyG)
    i będzie w stanie dokonać klasyfikacji wierzchołków."""

    def __init__(self, gnn_convs: Tuple[nn.Module, ...], lin_sizes: Tuple[int, ...]):
        super(NodesClassifierModel, self).__init__()
        # Zbudowanie sekwencji warstw warstw splotu grafowego
        self.gnn_convs = nn.ModuleList(gnn_convs)

        # Zbudowanie sekwencji warstw MLP
        lin_layers = []
        for i in range(1, len(lin_sizes)):
            lin_layers.append(
                nn.Linear(
                    lin_sizes[i - 1],
                    lin_sizes[i],
                )
            )
            lin_layers.append(nn.ReLU())
        self.lin_layers = nn.Sequential(*lin_layers)

    def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:
        last_x = x
        # Dla każdej warstwy splotu grafowego: przekazujemy ostatnią
        # reprezentację wierzchołków i macierz sąsiedztwa (niezmienioną)
        for conv in self.gnn_convs:
            last_x = conv(last_x, edge_index)
        # Wynik ostatniej warstwy splotu jest przekazywany do MLP i służy do predykcji klas
        lin_pred = self.lin_layers(last_x)
        return lin_pred
