# Standardowe skróty importowanych pakietów
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import patsy

import itertools as it
import collections as co
import functools as ft
import os.path as osp

import glob
import textwrap

import warnings
warnings.filterwarnings("ignore")
# Niektóre ostrzeżenia są skrajnie irytujące; nie 
# są one pożądane w tej książce
def warn(*args, **kwargs): pass
warnings.warn = warn

# Kod konfiguracji
np.set_printoptions(precision=4,
                    suppress=True)
pd.options.display.float_format = '{:20,.4f}'.format

# Istnieją powody do tego, by *NIE* stosować tej techniki w kodzie produkcyjnym;
# jednak w tym kontekście (pisanie książki z powtarzalnymi danymi wyjściowymi)
# *JEST* to odpowiednie rozwiązanie
np.random.seed(42)

# Wartości domyślne to [6.4, 4.8] (4:3)
mpl.rcParams['figure.figsize'] = [4.0, 3.0]

# Włączanie tabel zgodnych z systemem LaTeX
pd.set_option('display.latex.repr', True)
# Kod wyśrodkowujący ramki danych w komórkach Out[]
def _repr_latex_(self):
    return "{\centering\n%s\n\medskip}" % self.to_latex()
pd.DataFrame._repr_latex_ = _repr_latex_

# Używane tylko raz
markers = it.cycle(['+', '^', 'o', '_', '*', 'd', 'x', 's'])

# Wygodne narzędzia do wyświetlania danych
from IPython.display import Image

#
# Pakiety w bibliotece sklearn są tworzone w stylu typowym dla Javy :(
#
from sklearn import (cluster,
                     datasets,
                     decomposition,
                     discriminant_analysis,
                     dummy,
                     ensemble,
                     feature_selection as ftr_sel,
                     linear_model,
                     metrics,
                     model_selection as skms,
                     multiclass as skmulti,
                     naive_bayes,
                     neighbors,
                     pipeline,
                     preprocessing as skpre,
                     svm,
                     tree)

# Celem jest generowanie predykcji dla dużej siatki punktów danych
# http://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
def plot_boundary(ax, data, tgt, model, dims, grid_step = .01):
    # Pobieranie dwuwymiarowego obrazu danych i granic
    twoD = data[:, list(dims)]

    min_x1, min_x2 = np.min(twoD, axis=0) + 2 * grid_step
    max_x1, max_x2 = np.max(twoD, axis=0) - grid_step

    # Generowanie siatki punktów i predykcji
    xs, ys = np.mgrid[min_x1:max_x1:grid_step,
                      min_x2:max_x2:grid_step]
    grid_points = np.c_[xs.ravel(), ys.ravel()]
    # Ostrzeżenie — dopasowanie bez sprawdzianu krzyżowego
    preds = model.fit(twoD, tgt).predict(grid_points).reshape(xs.shape)

    # Wyświetlanie predykcji dla punktów
    ax.pcolormesh(xs,ys,preds,cmap=plt.cm.coolwarm)
    ax.set_xlim(min_x1, max_x1)#-grid_step)
    ax.set_ylim(min_x2, max_x2)#-grid_step)

def plot_separator(model, xs, ys, label='', ax=None):
    ''' xs i ys są jednowymiarowe, ponieważ wymagają tego
        wywołania contour i decision_function '''
    if ax is None:
        ax = plt.gca()

    xy = np_cartesian_product(xs, ys)
    z_shape = (xs.size, ys.size) # Użycie .size, ponieważ dane są jednowymiarowe
    zs = model.decision_function(xy).reshape(z_shape)

    contours = ax.contour(xs, ys, zs,
                          colors='k', levels=[0],
                          linestyles=['-'])
    fmt = {contours.levels[0] : label}
    labels = ax.clabel(contours, fmt=fmt, inline_spacing=10)
    [l.set_rotation(-90) for l in labels]

def high_school_style(ax):
    ' Funkcja pomocnicza do definiowania osi wyglądających jak na wykresach ze szkoły '
    ax.spines['left'].set_position(('data', 0.0))
    ax.spines['bottom'].set_position(('data', 0.0))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

def make_ticks(lims):
        lwr, upr = sorted(lims) # Zakres osi x i y można zmienić w bibliotece mpl
        lwr = np.round(lwr).astype('int') # Zwracanie obiektów z biblioteki np
        upr = np.round(upr).astype('int')

        if lwr * upr < 0:
            return list(range(lwr, 0)) + list(range(1,upr+1))
        else:
            return list(range(lwr, upr+1))

    import matplotlib.ticker as ticker
    xticks = make_ticks(ax.get_xlim())
    yticks = make_ticks(ax.get_ylim())

    ax.xaxis.set_major_locator(ticker.FixedLocator(xticks))
    ax.yaxis.set_major_locator(ticker.FixedLocator(yticks))
    ax.set_aspect('equal')

def get_model_name(model):
    ' Zwraca nazwę modelu (klasę) jako łańcuch znaków '
    return str(model.__class__).split('.')[-1][:-2]

def rdot(w,x):
    ' Wywołuje np.dot dla argumentów po przestawieniu ich kolejności '
    return np.dot(x,w)

from sklearn.base import BaseEstimator, ClassifierMixin
class DLDA(BaseEstimator, ClassifierMixin):
    def __init__(self):
        pass

    def fit(self, train_ftrs, train_tgts):
        self.uniq_tgts = np.unique(train_tgts)
        self.means, self.priors = {}, {}

        self.var = train_ftrs.var(axis=0) # Z obciążeniem
        for tgt in self.uniq_tgts:
            cases = train_ftrs[train_tgts==tgt]
            self.means[tgt] = cases.mean(axis=0)
            self.priors[tgt] = len(cases) / len(train_ftrs)
        return self

    def predict(self, test_ftrs):
        disc = np.empty((test_ftrs.shape[0],
                         self.uniq_tgts.shape[0]))
        for tgt in self.uniq_tgts:
            # Technicznie odległość Mahalanobisa to kwadrat następującej wartości:
            mahalanobis_dists = ((test_ftrs - self.means[tgt])**2 /
                                 self.var)
            disc[:,tgt] = (-np.sum(mahalanobis_dists, axis=1) +
                           2 * np.log(self.priors[tgt]))
        return np.argmax(disc,axis=1)

def plot_lines_and_projections(axes, lines, points, xs):
    data_xs, data_ys = points[:,0], points[:,1]
    mean = np.mean(points, axis=0, keepdims=True)
    centered_data = points - mean

    for (m,b), ax in zip(lines, axes):
        mb_line = m*xs + b
        v_line = np.array([[1, 1/m if m else 0]])

        ax.plot(data_xs, data_ys, 'r.') # Niewyśrodkowane
        ax.plot(xs, mb_line, 'y')       # Niewyśrodkowane
        ax.plot(*mean.T, 'ko')

        # Wycentrowanie danych ułatwia obliczenia!
        # To odległość od punktów czerwonych do niebieskich na żółtej linii,
        # czyli odległość od projekcji punktów do średniej
        y_lengths = centered_data.dot(v_line.T) / v_line.dot(v_line.T)
        projs = y_lengths.dot(v_line)

        # Powrót do pierwotnych współrzędnych
        final = projs + mean
        ax.plot(*final.T, 'b.')

        # Łączenie punktów z ich projekcjami
        from matplotlib import collections as mc
        proj_lines = mc.LineCollection(zip(points,final))
        ax.add_collection(proj_lines)

        hypots = zip(points, np.broadcast_to(mean, points.shape))
        mean_lines = mc.LineCollection(hypots, linestyles='dashed')
        ax.add_collection(mean_lines)

# Dobrze byłoby dodać orientację
def sane_quiver(vs, ax=None, colors=None, origin=(0,0)):
    ''' Wyświetlanie surowych wektorów od środka układu '''
    vs = np.asarray(vs)
    assert vs.ndim == 2 and vs.shape[1] == 2 # Upewnianie się, że używane są wektory kolumnowe

    n = vs.shape[0]
    if not ax: ax = plt.gca()

    # zs = np.zeros(n)
    # zs = np.broadcast_to(origin, vs.shape)
    orig_x, orig_y = origin

        xs = vs.T[0] # Przekształcanie kolumn w wiersze, wiersz[0] to xs
        ys = vs.T[1]

        props = {"angles":'xy', 'scale':1, 'scale_units':'xy'}
        ax.quiver(orig_x, orig_y, xs, ys, color=colors, **props)

        ax.set_aspect('equal')
        # ax.set_axis_off()
        _min, _max = min(vs.min(), 0) -1, max(0, vs.max())+1
        ax.set_xlim(_min, _max)
        ax.set_ylim(_min, _max)

    def reweight(examples, weights):
        ''' Przekształcanie wag na liczby przykładów z użyciem dwóch
            znaczących cyfr z wag

            Istnieje pewnie z setka powodów, aby nie stosować takiego rozwiązania.
            Oto dwa najważniejsze z nich:
              1. boosting może wymagać bardziej precyzyjnych wartości 
                (lub randomizacji), aby uniknąć obciążenia
              2. to podejście *znacznie* zwiększa zbiór danych
                (co oznacza marnowanie zasobów)
        '''
        from math import gcd
        from functools import reduce

        # Która z wag jest najmniejsza?
        min_wgt = min(weights)
        min_replicate = 1 / min_wgt # Na przykład 0,25 -> 4

        # Proste duplikowanie z dokładnością do dwóch miejsc po przecinku
        counts = (min_replicate * weights * 100).astype(np.int64)

        # Przycinanie wartości, jeśli jest to możliwe
        our_gcd = reduce(gcd, counts)
        counts = counts // our_gcd

        # Funkcja repeat wymaga użycia odpowiedniego typu danych
        return np.repeat(examples, counts, axis=0)

    # examples = np.array([1, 10, 20])
    # weights = np.array([.25, .33, 1-(.25+.33)])
    # print(pd.Series(reweight(examples, weights)))

    def enumerate_outer(outer_seq):
        ''' Powiela indeks zewnętrznej kolekcji na podstawie długości wewnętrznej sekwencji '''
        return np.repeat(*zip(*enumerate(map(len, outer_seq))))

    def np_array_fromiter(itr, shape, dtype=np.float64):
        ''' Funkcja pomocnicza, ponieważ funkcja np.fromiter 
            działa tylko dla danych jednowymiarowych '''
        arr = np.empty(shape, dtype=dtype)
        for idx, itm in enumerate(itr):
            arr[idx] = itm
        return arr

    # Jak zrozumieć skomplikowany kod?
    # Zaczynaj od środka, używaj prostych danych wejściowych, zwracaj uwagę na typy danych.
    # Spróbuj wykonać błędne i poprawne wywołania z prostszymi danymi wejściowymi.
    # Czytaj dokumentację *i jednocześnie* wykonuj eksperymenty
    # [sama dokumentacja rzadko jest dla mnie zrozumiała, jeśli nie 
    # wykonuję eksperymentów w trakcie jej czytania]

    # Różnica względem „surowego” wywołania np.meshgrid polega na tym, że tu tworzone
    # są dwie kolumny wyników (tworzona jest tabela dla par tablic)
    def np_cartesian_product(*arrays):
        ''' Sztuczki z biblioteką numpy, aby wygenerować
            wszystkie możliwe kombinacje tablic wejściowych '''
        ndim = len(arrays)
        return np.stack(np.meshgrid(*arrays), axis=-1).reshape(-1, ndim)
