import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.neighbors import NearestNeighbors


# Fix NumPy's global random seed so that results can be reproduced.
# NOTE(review): none of the calls visible in this script obviously draw from
# NumPy's global RNG; confirm whether this seed is still needed.
np.random.seed(seed=1000)


if __name__ == '__main__':
    # Wczytuje zestaw danych
    digits = load_digits()
    X_train = digits['data'] / np.max(digits['data'])

    # Realizuje algorytm kNN
    knn = NearestNeighbors(n_neighbors=50, algorithm='ball_tree')
    knn.fit(X_train)

    # Sprawdza model
    distances, neighbors = knn.kneighbors(X_train[100].reshape(1, -1), return_distance=True)

    print('Odległości: {}'.format(distances[0]))

    # Tworzy wykres sąsiadów
    fig, ax = plt.subplots(5, 10, figsize=(8, 8))

    for y in range(5):
        for x in range(10):
            idx = neighbors[0][(x + (y * 10))]
            ax[y, x].matshow(digits['images'][idx], cmap='gray')
            ax[y, x].set_xticks([])
            ax[y, x].set_yticks([])

    plt.show()

