import os
import pandas as pd
import numpy as np
import torch
from pathlib import Path
import torch_geometric as pyg
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder
from typing import Tuple
import json
import urllib.request
import zipfile


class FacebookDataPreprocessor:
    """
    Preprocesor dla zbioru danych Facebook ze Stanford.

    Klasa obsługuje pobieranie, rozpakowywanie i przetwarzanie zbioru danych Facebook
    do zadań klasyfikacji wierzchołków.
    """

    def __init__(self, data_dir: str = "data"):
        """
        Inicjalizacja preprocesora.

        Parameters
        ----------
        data_dir : str
            Katalog, w którym będą przechowywane dane
        """
        self.data_dir = data_dir
        self.dataset_url = "https://snap.stanford.edu/data/facebook_large.zip"
        self.zip_path = os.path.join(data_dir, "facebook_large.zip")
        self.extracted_dir = os.path.join(data_dir, "facebook_large")

        # Utwórz katalog danych, jeśli nie istnieje
        os.makedirs(data_dir, exist_ok=True)

    def download_and_extract(self) -> None:
        """Pobierz i rozpakuj zbiór danych, jeśli jeszcze nie istnieje."""
        # Pobierz jeśli nie istnieje
        if not os.path.exists(self.zip_path):
            print("Pobieranie zbioru danych Facebook...")
            urllib.request.urlretrieve(self.dataset_url, self.zip_path)

        # Rozpakuj jeśli jeszcze nie rozpakowano
        if not os.path.exists(self.extracted_dir):
            print("Rozpakowywanie zbioru danych...")
            with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
                zip_ref.extractall(self.data_dir)

    def read_edges(self) -> pd.DataFrame:
        """
        Wczytaj i przetwórz plik krawędzi.

        Returns
        -------
        pd.DataFrame
            DataFrame zawierający krawędzie z przemianowanymi kolumnami
        """
        edges_path = os.path.join(self.extracted_dir, "musae_facebook_edges.csv")
        return pd.read_csv(edges_path)

    def read_features(self) -> dict:
        """Wczytaj plik cech."""
        features_path = os.path.join(self.extracted_dir, "musae_facebook_features.json")
        with open(features_path, "r") as f:
            features = json.load(f)
        return features

    def read_target(self) -> pd.DataFrame:
        """Wczytaj plik etykiet docelowych."""
        target_path = os.path.join(self.extracted_dir, "musae_facebook_target.csv")
        return pd.read_csv(target_path)

    def preprocess_features(
        self, features_dict: dict, target_df: pd.DataFrame, pad_to_max: bool = True
    ) -> torch.Tensor:
        """
        Przetwórz cechy wierzchołków z paddingiem.

        Parameters
        ----------
        features_dict : dict
            Słownik mapujący ID wierzchołków na listy cech
        target_df : pd.DataFrame
            DataFrame zawierający wierzchołki docelowe
        pad_to_max : bool, optional
            Jeśli True, uzupełnij do maksymalnej długości, w przeciwnym razie do mediany

        Returns
        -------
        torch.Tensor
            Przetworzona macierz cech z paddingiem
        """
        xs = []
        lengths = []

        # Przetwórz każdy wierzchołek w zbiorze docelowym
        for node_id in target_df["id"]:
            node_id = str(node_id)
            if node_id not in features_dict:
                # Jeśli wierzchołek nie ma cech, użyj [1.0] jako domyślnej
                features = [1.0]
            else:
                features = features_dict[
                    node_id
                ]  # Konwertuj na str, ponieważ klucze JSON są stringami
                lengths.append(len(features))

            xs.append(features)

        # Określ długość paddingu
        if lengths:
            pad_length = max(lengths) if pad_to_max else int(np.median(lengths))
        else:
            pad_length = 1

        # Uzupełnij sekwencje
        padded_features = []
        for features in xs:
            if len(features) > pad_length:
                # Przytnij jeśli dłuższe niż pad_length
                padded = features[:pad_length]
            else:
                # Uzupełnij zerami jeśli krótsze
                padded = features + [0.0] * (pad_length - len(features))
            padded_features.append(padded)
        # Najpierw konwertuj na tensor
        features_tensor = torch.tensor(padded_features, dtype=torch.float)

        # Normalizuj używając normalizacji torch_geometric
        normalized_features = torch.nn.functional.normalize(
            features_tensor, p=2, dim=-1
        )
        return normalized_features.to(torch.float)

    def preprocess_data(self) -> Tuple[Data, LabelEncoder]:
        """
        Przetwórz zbiór danych Facebook i przekonwertuj go na format PyG Data.

        Returns
        -------
        Tuple[Data, LabelEncoder]
            Obiekt PyG Data i koder etykiet użyty dla wartości docelowych
        """
        print("Wczytywanie plików danych...")
        edges_df = self.read_edges()
        features_dict = self.read_features()
        target_df = self.read_target()

        print("Przetwarzanie krawędzi...")
        # Konwertuj krawędzie na format COO
        edge_index = torch.tensor(
            [edges_df["id_1"].values, edges_df["id_2"].values],
            dtype=torch.long,
        )

        print("Przetwarzanie cech...")
        # Przetwórz cechy używając nowej metody
        x = self.preprocess_features(features_dict, target_df)

        print("Przetwarzanie etykiet docelowych...")
        # Koduj etykiety docelowe
        label_encoder = LabelEncoder()
        y = torch.tensor(
            label_encoder.fit_transform(target_df["page_type"]), dtype=torch.long
        )

        # Utwórz obiekt PyG Data
        data = Data(
            x=x,
            edge_index=edge_index,
            y=y)

        print("Tworzenie masek treningowych/walidacyjnych/testowych...")
        # Utwórz maski treningowe/walidacyjne/testowe (podział 80/10/10)
        # Użyj RandomNodeSplit do podziału danych na zbiory treningowy/walidacyjny/testowy
        splitter = pyg.transforms.RandomNodeSplit(
            num_train_per_class=0.7,
            num_val=0.15,
            num_test=0.15
        )
        data = splitter(data)

        return data, label_encoder

    def process(self) -> Tuple[Data, LabelEncoder]:
        """
        Główna funkcja przetwarzająca, która obsługuje pobieranie i przetwarzanie.

        Returns
        -------
        Tuple[Data, LabelEncoder]
            Obiekt PyG Data i koder etykiet użyty dla wartości docelowych
        """
        self.download_and_extract()
        return self.preprocess_data()


def load_facebook_data(data_dir: str = "data") -> Tuple[Data, LabelEncoder]:
    """
    Funkcja pomocnicza do wczytywania i przetwarzania zbioru danych Facebook.

    Parameters
    ----------
    data_dir : str
        Katalog, w którym będą przechowywane dane

    Returns
    -------
    Tuple[Data, LabelEncoder]
        Obiekt PyG Data i koder etykiet użyty dla wartości docelowych
    """
    preprocessor = FacebookDataPreprocessor(data_dir)
    return preprocessor.process()


if __name__ == "__main__":
    # Przykład użycia
    data, label_encoder = load_facebook_data(data_dir=Path("graph_nn_book/data"))
    print("\nZbiór danych wczytany pomyślnie!")
