import sys
import numpy as np

import lib.dqn_extra

sys.path.append("./")

from lib import common

import matplotlib as mpl
mpl.use("Agg")
import matplotlib.pyplot as plt


# Support of the categorical value distribution: N_ATOMS atoms evenly spaced
# over the closed interval [Vmin, Vmax].
Vmin, Vmax = -10, 10
N_ATOMS = 51
# Distance between neighbouring atoms (0.4 for the values above).
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)


def save_distr(src, proj, name, width=0.5):
    """Plot a source distribution above its projection and save the figure.

    Both distributions are drawn as bar charts over the N_ATOMS support atoms
    spanning [Vmin, Vmax]. The current pyplot figure is cleared first and the
    result is written to ``name + ".png"``.

    :param src: source values, one per atom (must have N_ATOMS entries)
    :param proj: projected values, one per atom (must have N_ATOMS entries)
    :param name: output path without the ``.png`` extension
    :param width: bar width passed to ``plt.bar``
    """
    plt.clf()
    # linspace yields exactly N_ATOMS points ending on Vmax. np.arange with a
    # float step can emit an extra point through rounding, which would make
    # plt.bar fail with a shape mismatch against src/proj.
    support = np.linspace(Vmin, Vmax, N_ATOMS)
    # Polish panel titles: "Źródło" = source, "Rzutowanie" = projection.
    panels = ((1, src, "Źródło"), (2, proj, "Rzutowanie"))
    for position, values, title in panels:
        plt.subplot(2, 1, position)
        plt.bar(support, values, width=width)
        plt.title(title)
    plt.savefig(name + ".png")


if __name__ == "__main__":
    # Visual check of lib.dqn_extra.distr_projection: each case projects a
    # source distribution with gamma=0.9 and saves a source/projection plot.
    np.random.seed(123)

    def _project(distr, rewards, dones):
        """Call distr_projection with this script's support and gamma=0.9.

        Only the per-case inputs vary: ``distr`` is the batch of source rows,
        ``rewards`` and ``dones`` hold one entry per row.
        """
        return lib.dqn_extra.distr_projection(
            distr, np.array(rewards, dtype=np.float32), np.array(dones),
            Vmin, Vmax, N_ATOMS, gamma=0.9)

    # Single-peak distribution: all mass on the atom just right of the centre.
    src_hist = np.zeros(shape=(1, N_ATOMS), dtype=np.float32)
    src_hist[0, N_ATOMS // 2 + 1] = 1.0
    proj_hist = _project(src_hist, [2], [False])
    save_distr(src_hist[0], proj_hist[0], "peak-r=2")

    # Normal distribution, histogrammed into one bin per atom. The N_ATOMS + 1
    # edges run from Vmin - DELTA_Z/2 to Vmax + DELTA_Z/2, so each bin is
    # centred on an atom. linspace pins the edge count exactly, whereas
    # np.arange with a float step can add an extra edge through rounding.
    data = np.random.normal(size=1000, scale=3)
    bin_edges = np.linspace(Vmin - DELTA_Z / 2, Vmax + DELTA_Z / 2, N_ATOMS + 1)
    counts, _ = np.histogram(data, bins=bin_edges)
    # `normed=` was removed from np.histogram in NumPy 1.24. Normalise the
    # counts to probability masses so the source sums to 1, like the
    # single-peak case above. density=True would make it sum to 1 / DELTA_Z.
    src_hist = (counts / counts.sum()).astype(np.float32)

    proj_hist = _project(np.array([src_hist]), [2], [False])
    save_distr(src_hist, proj_hist[0], "normal-r=2")

    # Normal distribution, episode finished (done=True).
    proj_hist = _project(np.array([src_hist]), [2], [True])
    save_distr(src_hist, proj_hist[0], "normal-done-r=2")

    # Clipping: assuming distr_projection shifts atoms to r + gamma * z,
    # reward 10 moves the support to [1, 19], past Vmax.
    proj_hist = _project(np.array([src_hist]), [10], [False])
    save_distr(src_hist, proj_hist[0], "normal-r=10")

    # Batches that mix an unfinished (row 0) and a finished (row 1) episode:
    # no clipping (r=2), clipping at Vmax (r=10), and clipping at Vmin (r=-10).
    paired_cases = (
        (2, "both_not_clip"),
        (10, "both_clip-right"),
        (-10, "both_clip-left"),
    )
    for reward, prefix in paired_cases:
        proj_hist = _project(np.array([src_hist, src_hist]),
                             [reward, reward], [False, True])
        save_distr(src_hist, proj_hist[0], prefix + "-01-incomplete")
        save_distr(src_hist, proj_hist[1], prefix + "-02-complete")
