"""Przykład Q-learning do rozwiązania problemu w prostym modelu świata

Prosty deterministyczny MDP jest złożony z 6 pól (stanów)
---------------------------------
|         |          |          |
|  Start  |          |  Cel     |
|         |          |          |
---------------------------------
|         |          |          |
|         |          |  Dziura  |
|         |          |          |
---------------------------------

"""

from collections import deque
import numpy as np
import argparse
import os
import time
from termcolor import colored


class QWorld:
    def __init__(self):
        """symulowanie deterministycznego świata z 6 stanami.
        Q-uczenie wg równania Bellmana. 
        """
        # 4 akcje
        # 0 — lewo, 1 — dół, 2 — prawo, 3 — góra

        self.col = 4

        # 6 stanów
        self.row = 6

        # ustawienia otoczenia
        self.q_table = np.zeros([self.row, self.col])
        self.init_transition_table()
        self.init_reward_table()

        # współczynnik dyskontowy
        self.gamma = 0.9

        # 90% eksploracja, 10% eksploatacja
        self.epsilon = 0.9
        # eksploracja zmniejsza się o ten czynnik w każdym epizodzie
        self.epsilon_decay = 0.9
        # długofalowo, 10% eksploracja, 90% eksploatacja
        self.epsilon_min = 0.1

        # reset otoczenia
        self.reset()
        self.is_explore = True


    def reset(self):
        """początek epizodu"""
        self.state = 0
        return self.state

    def is_in_win_state(self):
        """agent wygrywa, kiedy osiągnie cel"""
        return self.state == 2


    def init_reward_table(self):
        """
        # 0 — lewo, 1 — dół, 2 — prawo, 3 — góra
        ----------------
        | 0 | 0 | 100  |
        ----------------
        | 0 | 0 | -100 |
        ----------------
        """

        self.reward_table = np.zeros([self.row, self.col])
        self.reward_table[1, 2] = 100.
        self.reward_table[4, 2] = -100.


    def init_transition_table(self):
        """
        # 0 — lewo, 1 — dół, 2 — prawo, 3 — góra
        -------------
        | 0 | 1 | 2 |
        -------------
        | 3 | 4 | 5 |
        -------------
        """

        self.transition_table = np.zeros([self.row, self.col],
                                         dtype=int)

        self.transition_table[0, 0] = 0
        self.transition_table[0, 1] = 3
        self.transition_table[0, 2] = 1
        self.transition_table[0, 3] = 0

        self.transition_table[1, 0] = 0
        self.transition_table[1, 1] = 4
        self.transition_table[1, 2] = 2
        self.transition_table[1, 3] = 1

        # graniczny stan docelowy
        self.transition_table[2, 0] = 2
        self.transition_table[2, 1] = 2
        self.transition_table[2, 2] = 2
        self.transition_table[2, 3] = 2

        self.transition_table[3, 0] = 3
        self.transition_table[3, 1] = 3
        self.transition_table[3, 2] = 4
        self.transition_table[3, 3] = 0

        self.transition_table[4, 0] = 3
        self.transition_table[4, 1] = 4
        self.transition_table[4, 2] = 5
        self.transition_table[4, 3] = 1

        # graniczny stan „dziura”
        self.transition_table[5, 0] = 5
        self.transition_table[5, 1] = 5
        self.transition_table[5, 2] = 5
        self.transition_table[5, 3] = 5
        
    
    def step(self, action):
        """wykonuje akcję na otoczeniu
        Argument:
            action (tensor): akcja w przestrzeni akcji
        Zwraca:
            next_state (tensor): następny stan otoczenia
            reward (float): nagroda dla agenta
            done (Bool): czy osiągnięto stan graniczny
        """
        # określenie następnego stanu (next_state), mając dany stan i akcję 

        next_state = self.transition_table[self.state, action]
        # done ma wartość True, jeśli next_state to cel lub dziura
        done = next_state == 2 or next_state == 5
        # nagroda dla stanu i akcji
        reward = self.reward_table[self.state, action]
        # otoczenie jest teraz w nowym stanie
        self.state = next_state
        return next_state, reward, done

    
    def act(self):
        """określenie następnej akcji: albo z tabeli Q (eksploatacja), albo wybór losowy
        Zwraca:
            action (tensor): akcja, którą musi wykonać agent
        """
        # 0 — lewo, 1 — dół, 2 — prawo, 3 — góra
        # akcja w ramach eksploracji 

        if np.random.rand() <= self.epsilon:
            # eksploruj — wykonaj losową akcję
            self.is_explore = True
            return np.random.choice(4,1)[0]

        # lub akcja jest eksploatacją — wybierz akcję 
        # z maksymalną wartością Q
        self.is_explore = False
        action = np.argmax(self.q_table[self.state])
        return action


    def update_q_table(self, state, action, reward, next_state):
        """Q-uczenie — aktualizacja tabeli Q przy użyciu Q(s, a)
        Argumenty:
            state (tensor): stan agenta
            action (tensor): akcja wykonana przez agenta
            reward (float): nagroda po wykonaniu akcji dla danego stanu
            next_state (tensor): następny stan po wykonaniu akcji dla danego stanu
        """
        # Q(s, a) = nagroda+gamma * max_a' Q(s', a')

        q_value = self.gamma * np.amax(self.q_table[next_state])
        q_value += reward
        self.q_table[state, action] = q_value


    def print_q_table(self):
        """UI to dump Q Table contents"""
        print("Q-Table (Epsilon: %0.2f)" % self.epsilon)
        print(self.q_table)


    def update_epsilon(self):
        """aktualizacja mieszanki eksploracja-eksploatacja"""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


    def print_cell(self, row=0):
        """Interfejs użytkownika do wyświetlania agenta poruszającego się po siatce"""
        print("")
        for i in range(13):
            j = i - 2
            if j in [0, 4, 8]: 
                if j == 8:
                    if self.state == 2 and row == 0:
                        marker = "\033[4mG\033[0m"
                    elif self.state == 5 and row == 1:
                        marker = "\033[4mH\033[0m"
                    else:
                        marker = 'G' if row == 0 else 'H'
                    color = self.state == 2 and row == 0
                    color = color or (self.state == 5 and row == 1)
                    color = 'red' if color else 'blue'
                    print(colored(marker, color), end='')
                elif self.state in [0, 1, 3, 4]:
                    cell = [(0, 0, 0), (1, 0, 4), (3, 1, 0), (4, 1, 4)]
                    marker = '_' if (self.state, row, j) in cell else ' '
                    print(colored(marker, 'red'), end='')
                else:
                    print(' ', end='')
            elif i % 4 == 0:
                    print('|', end='')
            else:
                print(' ', end='')
        print("")


    def print_world(self, action, step):
        """Interfejs użytkownika do wyświetlania trybu i akcji agenta"""
        actions = { 0: "(Lewo)", 1: "(Dol)", 2: "(Prawo)", 3: "(Góra)" }
        explore = "Eksploraca" if self.is_explore else "Eksploatacja"
        print("Krok", step, ":", explore, actions[action])
        for _ in range(13):
            print('-', end='')
        self.print_cell()
        for _ in range(13):
            print('-', end='')
        self.print_cell(row=1)
        for _ in range(13):
            print('-', end='')
        print("")


def print_episode(episode, delay=1):
    """Interfejs użytkownika do wyświetlania zliczenia epizodów
    Argumenty:
        episode (int): numer epizodu
        delay (int): opóźnienie [sek]

    """
    os.system('clear')
    for _ in range(13):
        print('=', end='')
    print("")
    print("Epizod ", episode)
    for _ in range(13):
        print('=', end='')
    print("")
    time.sleep(delay)


def print_status(q_world, done, step, delay=1):
    """Interfejs użytkownika do wyświetlenia Świata, 
        jednosekundowe opóźnienie by było łatwiej zrozumieć
    """
    os.system('clear')
    q_world.print_world(action, step)
    q_world.print_q_table()
    if done:
        print("------- EPIZOD UKONCZONY --------")
        delay *= 2
    time.sleep(delay)


# stan, akcja, nagroda, iteracja w następnym stanie
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Uczenie i pokazanie końcowej tabeli Q"
    parser.add_argument("-t",
                        "--train",
                        help=help_,
                        action='store_true')
    args = parser.parse_args()

    if args.train:
        maxwins = 2000
        delay = 0
    else:
        maxwins = 10
        delay = 1

    wins = 0
    episode_count = 10 * maxwins
    # punktacja (maksymalna liczba kroków przed osiągnięciem celu) - dobry wskaźnik postepow uczenia
    scores = deque(maxlen=maxwins)
    q_world = QWorld()
    step = 1

    # iteracja stan, akcja nagroda, następny stan 
    for episode in range(episode_count):
        state = q_world.reset()
        done = False
        print_episode(episode, delay=delay)
        while not done:
            action = q_world.act()
            next_state, reward, done = q_world.step(action)
            q_world.update_q_table(state, action, reward, next_state)
            print_status(q_world, done, step, delay=delay)
            state = next_state
            # jeśli epizod zakończony, zacznij od nowa
            if done:
                if q_world.is_in_win_state():
                    wins += 1
                    scores.append(step)
                    if wins > maxwins:
                        print(scores)
                        exit(0)
                # Eksploracja/eksploatacja jest aktualizowana po każdym epizodzie
                q_world.update_epsilon()
                step = 1
            else:
                step += 1

    print(scores)
    q_world.print_q_table()
