import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
import pymc3.distributions.continuous as pmc
import pymc3.distributions.discrete as pmd
import pymc3.math as pmm

# Instrukcja instalacji biblioteki PyMC 3 (https://github.com/pymc-devs/pymc3)
# Pip: pip install pymc3
# Conda: conda install -c conda-forge pymc3
#
# W przypadku wystąpienia problemów z biblioteką h5py w dystrybucji Anaconda, zaktualizuj pakiet: conda install h5py

# Wyznacza ziarno losowe w celu odtworzenia rezultatów
np.random.seed(1000)


nb_samples = 500


if __name__ == '__main__':
    # Tworzy model PyMC3
    model = pm.Model()

    # Definiuje strukturę modelu
    with model:
        passenger_onboarding = pmc.Wald('Odprawa pasażerów', mu=0.5, lam=0.2)
        refueling = pmc.Wald('Tankowanie', mu=0.25, lam=0.5)
        departure_traffic_delay = pmc.Wald('Opóźnienie ruchu (wylot)', mu=0.1, lam=0.2)

        departure_time = pm.Deterministic('Czas wylotu',
                                          12.0 + departure_traffic_delay +
                                          pmm.switch(passenger_onboarding >= refueling,
                                                     passenger_onboarding,
                                                     refueling))

        rough_weather = pmd.Bernoulli('Warunki atmosferyczne', p=0.35)

        flight_time = pmc.Exponential('Czas trwania lotu', lam=0.5 - (0.1 * rough_weather))
        arrival_traffic_delay = pmc.Wald('Opóźnienie ruchu (przylot)', mu=0.1, lam=0.2)

        arrival_time = pm.Deterministic('Czas przylotu',
                                        departure_time +
                                        flight_time +
                                        arrival_traffic_delay)

    # Próbkuje z modelu
    # W systemie Windows z zainstalowaną platformą Anaconda 3.5 mogą występować problemy z parametrem joblib,
    # dlatego zalecam wyznaczenie jego wartości n_jobs=1
    with model:
        samples = pm.sample(draws=nb_samples, njobs=1, random_seed=1000)

    # Tworzy wykres podsumowania
    pm.summary(samples)

    # Wyświetla diagramy
    fig, ax = plt.subplots(8, 2, figsize=(14, 18))

    pm.traceplot(samples, ax=ax)

    for i in range(8):
        for j in range(2):
            ax[i, j].grid()

    ax[2, 0].set_xlim([0.05, 1.0])
    ax[3, 0].set_xlim([0.05, 0.4])
    ax[4, 0].set_xlim([12, 16])
    ax[5, 0].set_xlim([0, 10])
    ax[6, 0].set_xlim([0.05, 0.4])
    ax[7, 0].set_xlim([14, 20])

    plt.show()

