import matplotlib.pyplot as plt
import numpy as np

# Wykres zanieczyszczenia Giniego dla przypadków binarnych
pos_fraction = np.linspace(0.00, 1.00, 1000)
gini = 1 - pos_fraction**2 - (1-pos_fraction)**2
plt.plot(pos_fraction, gini)
plt.xlabel('Udział klasy dodatniej')
plt.ylabel('Zanieczyszczenie Giniego')
plt.ylim(0, 1)
# plt.show()

# Funkcja wyliczająca zanieczyszczenie Giniego dla zadanych etykiet zbioru danych
def gini_impurity(labels):
    # Zbiór pusty jest również czysty
    if not labels:
        return 0
    # Zliczenie wystąpień każdej etykiety
    counts = np.unique(labels, return_counts=True)[1]
    fractions = counts / float(len(labels))
    return 1 - np.sum(fractions ** 2)

print(f'{gini_impurity([1, 1, 0, 1, 0]):.4f}')
print(f'{gini_impurity([1, 1, 0, 1, 0, 0]):.4f}')
print(f'{gini_impurity([1, 1, 1, 1]):.4f}')

# Entropia w zbiorze binarnym
pos_fraction = np.linspace(0.00, 1.00, 1000)
ent = - (pos_fraction * np.log2(pos_fraction) + (1 - pos_fraction) * np.log2(1 - pos_fraction))
plt.plot(pos_fraction, ent)
plt.xlabel('Udział klasy dodatniej')
plt.ylabel('Entropia')
plt.ylim(0, 1)
# plt.show()

# Funkcja wyliczająca entropię zbioru o zadanych etykietach
def entropy(labels):
    if not labels:
        return 0
    counts = np.unique(labels, return_counts=True)[1]
    fractions = counts / float(len(labels))
    return - np.sum(fractions * np.log2(fractions))

print(f'{entropy([1, 1, 0, 1, 0]):.4f}')
print(f'{entropy([1, 1, 0, 1, 0, 0]):.4f}')
print(f'{entropy([1, 1, 1, 1]):.4f}')

criterion_function = {'gini': gini_impurity, 'entropy': entropy}
def weighted_impurity(groups, criterion='gini'):
    """
    Funkcja wyliczająca ważone zanieczyszczenie zbioru po podziale
    @param groups: lista węzłów potomnych, z których każdy zawiera listę etykiet klas
    @param criterion: wskaźnik jakości podziału, 'gini' oznacza zanieczyszczenie Giniego, a 'entropy' przyrost informacji
    @return: float, weighted impurity
    """
    total = sum(len(group) for group in groups)
    weighted_sum = 0.0
    for group in groups:
        weighted_sum += len(group) / float(total) * criterion_function[criterion](group)
    return weighted_sum

children_1 = [[1, 0, 1], [0, 1]]
children_2 = [[1, 1], [0, 0, 1]]
print(f"Entropia 1: {weighted_impurity(children_1, 'entropy'):.4f}")
print(f"Entropia 2: {weighted_impurity(children_2, 'entropy'):.4f}")

def gini_impurity_np(labels):
    # Zbiór pusty jest również czysty
    if labels.size == 0:
        return 0
    # Zliczenie wystąpień każdej etykiety
    counts = np.unique(labels, return_counts=True)[1]
    fractions = counts / float(len(labels))
    return 1 - np.sum(fractions ** 2)

def entropy_np(labels):
    # Zbiór pusty jest również czysty
    if labels.size == 0:
        return 0
    counts = np.unique(labels, return_counts=True)[1]
    fractions = counts / float(len(labels))
    return - np.sum(fractions * np.log2(fractions))

criterion_function_np = {'gini': gini_impurity_np, 'entropy': entropy_np}
def weighted_impurity(groups, criterion='gini'):
    """
    Funkcja wyliczająca ważone zanieczyszczenie zbioru po podziale
    @param groups: lista węzłów potomnych, z których każdy zawiera listę etykiet klas
    @param criterion: wskaźnik jakości podziału, 'gini' oznacza zanieczyszczenie Giniego, a 'entropy' przyrost informacji
    @return: float, weighted impurity
    """
    total = sum(len(group) for group in groups)
    weighted_sum = 0.0
    for group in groups:
        weighted_sum += len(group) / float(total) * criterion_function_np[criterion](group)
    return weighted_sum

def split_node(X, y, index, value):
    """
    Funkcja dzieląca zbiór X na podstawie cechy i wartości
    @param X: numpy.ndarray, cechy zbioru
    @param y: numpy.ndarray, docelowy zbiór
    @param index: int, indeks cechy wykorzystywanej do dzielenia
    @param value: wartość cechy wykorzystywanej do dzielenia
    @return: dwie listy w formacie [X, y] reprezentujące węzły potomne lewy i prawy
    """
    x_index = X[:, index]
    # Jeżeli cecha jest liczbowa
    if X[0, index].dtype.kind in ['i', 'f']:
        mask = x_index >= value
    # Jeżeli cecha jest kategorialna
    else:
        mask = x_index == value
    # Podział na węzły potomne lewy i prawy
    left = [X[~mask, :], y[~mask]]
    right = [X[mask, :], y[mask]]
    return left, right

def get_best_split(X, y, criterion):
    """
    Funkcja wyszukująca najlepszy punkt podziału zbioru X, y i zwracająca węzły potomne
    @param X: numpy.ndarray, cechy zbioru
    @param y: numpy.ndarray, docelowy zbiór
    @param criterion: kryterium gini lub entropy
    @return: dict {index: indeks cechy, value: wartość cechy, children: węzły potomne lewy i prawy}
    """
    best_index, best_value, best_score, children = None, None, 1, None
    for index in range(len(X[0])):
        for value in np.sort(np.unique(X[:, index])):
            groups = split_node(X, y, index, value)
            impurity = weighted_impurity([groups[0][1], groups[1][1]], criterion)
            if impurity < best_score:
                best_index, best_value, best_score, children = index, value, impurity, groups
    return {'index': best_index, 'value': best_value, 'children': children}

def get_leaf(labels):
    # Zwrócenie liścia z główną etykietą
    return np.bincount(labels).argmax()

def split(node, max_depth, min_size, depth, criterion):
    """
    Funkcja dzieląca węzeł lub przypisująca mu końcową etykietę
    @param node: słownik z informacjami o węźle
    @param max_depth: int, maksymalna głębokość drzewa
    @param min_size: int, minimalna liczba próbek wymagana do podziału węzła
    @param depth: int, głębokość aktualnego węzła
    @param criterion: kryterium gini lub entropy
    """
    left, right = node['children']
    del (node['children'])
    if left[1].size == 0:
        node['right'] = get_leaf(right[1])
        return
    if right[1].size == 0:
        node['left'] = get_leaf(left[1])
        return
    # Sprawdzenie, czy aktualna głębokość nie przekracza maksymalnej
    if depth >= max_depth:
        node['left'], node['right'] = get_leaf(left[1]), get_leaf(right[1])
        return
    # Sprawdzenie, czy lewy węzeł potomny zawiera wystarczającą liczbę próbek
    if left[1].size <= min_size:
        node['left'] = get_leaf(left[1])
    else:
        # Jeżeli tak, dzielimy go dalej
        result = get_best_split(left[0], left[1], criterion)
        result_left, result_right = result['children']
        if result_left[1].size == 0:
            node['left'] = get_leaf(result_right[1])
        elif result_right[1].size == 0:
            node['left'] = get_leaf(result_left[1])
        else:
            node['left'] = result
            split(node['left'], max_depth, min_size, depth + 1, criterion)
    # Sprawdzenie, czy prawy węzeł potomny zawiera wystarczającą liczbę próbek
    if right[1].size <= min_size:
        node['right'] = get_leaf(right[1])
    else:
        # Jeżeli tak, dzielimy go dalej
        result = get_best_split(right[0], right[1], criterion)
        result_left, result_right = result['children']
        if result_left[1].size == 0:
            node['right'] = get_leaf(result_right[1])
        elif result_right[1].size == 0:
            node['right'] = get_leaf(result_left[1])
        else:
            node['right'] = result
            split(node['right'], max_depth, min_size, depth + 1, criterion)

def train_tree(X_train, y_train, max_depth, min_size, criterion='gini'):
    """
    Funkcja inicjująca budowanie drzewa
    @param X_train: lista próbek treningowych (cechy)
    @param y_train: lista próbek treningowych (cel)
    @param max_depth: int, maksymalna głębokość drzewa
    @param min_size: int, minimalna liczba próbek wymagana do podziału węzła
    @param criterion: kryterium gini lub entropy
    """
    X = np.array(X_train)
    y = np.array(y_train)
    root = get_best_split(X, y, criterion)
    split(root, max_depth, min_size, 1, criterion)
    return root

CONDITION = {'numerical': {'yes': '>=', 'no': '<'},
             'categorical': {'yes': 'to', 'no': 'to nie'}}
def visualize_tree(node, depth=0):
    if isinstance(node, dict):
        if node['value'].dtype.kind in ['i', 'f']:
            condition = CONDITION['numerical']
        else:
            condition = CONDITION['categorical']
        print('{}|- X{} {} {}'.format(depth * '  ', node['index'] + 1, condition['no'], node['value']))
        if 'left' in node:
            visualize_tree(node['left'], depth + 1)
        print('{}|- X{} {} {}'.format(depth * '  ', node['index'] + 1, condition['yes'], node['value']))
        if 'right' in node:
            visualize_tree(node['right'], depth + 1)
    else:
        print(f"{depth * '  '}[{node}]")

X_train = [['technika', 'specjalista'],
           ['moda', 'student'],
           ['moda', 'specjalista'],
           ['sport', 'student'],
           ['technika', 'student'],
           ['technika', 'emeryt'],
           ['sport', 'specjalista']]

y_train = [1,
           0,
           0,
           0,
           1,
           0,
           1]

tree = train_tree(X_train, y_train, 2, 2)
visualize_tree(tree)

X_train_n = [[6, 7],
           [2, 4],
           [7, 2],
           [3, 6],
           [4, 7],
           [5, 2],
           [1, 6],
           [2, 0],
           [6, 3],
           [4, 1]]

y_train_n = [0,
           0,
           0,
           0,
           0,
           1,
           1,
           1,
           1,
           1]

tree = train_tree(X_train_n, y_train_n, 2, 2)
visualize_tree(tree)

from sklearn.tree import DecisionTreeClassifier
tree_sk = DecisionTreeClassifier(criterion='gini', max_depth=2, min_samples_split=2)
tree_sk.fit(X_train_n, y_train_n)

from sklearn.tree import export_graphviz
export_graphviz(tree_sk, out_file='tree.dot', feature_names=['X1', 'X2'], impurity=False, filled=True, class_names=['0', '1'])