#!/usr/bin/env python
# coding: utf-8

# Kody źródłowe do książki: Python. Uczenie maszynowe w przykładach
#  
# Rozdział 15.: Podejmowanie decyzji w skomplikowanych warunkach z wykorzystaniem uczenia przez wzmacnianie
#  
# Autor: Yuxi (Hayden) Liu (yuxi.liu.ece@gmail.com)

# # Przygotowanie środowiska do uczenia przez wzmacnianie

# ## Instalowanie Gymnasium

import gymnasium as gym
print(gym.envs.registry.keys())


# # Problem FrozenLake i programowanie dynamiczne

# ## Utworzenie środowiska FrozenLake

env = gym.make("FrozenLake-v1", render_mode="rgb_array")
 
n_state = env.observation_space.n
print(n_state)
n_action = env.action_space.n
print(n_action)


env.reset(seed=0)


import matplotlib.pyplot as plt
plt.imshow(env.render())  


new_state, reward, terminated, truncated, info = env.step(2)
is_done = terminated or truncated
    
env.render()
print(new_state)
print(reward)
print(is_done)
print(info)


plt.imshow(env.render())


def run_episode(env, policy):
    state, _ = env.reset()
    total_reward = 0
    is_done = False
    while not is_done:
        action = policy[state].item()
        state, reward, terminated, truncated, info = env.step(action)
        is_done = terminated or truncated
        total_reward += reward
        if is_done:
            break
    return total_reward


import torch

n_episode = 1000

total_rewards = []
for episode in range(n_episode):
    random_policy = torch.randint(high=n_action, size=(n_state,))
    total_reward = run_episode(env, random_policy)
    total_rewards.append(total_reward)

print(f'Średnia sumaryczna nagroda w losowej polityce: {sum(total_rewards)/n_episode}')


print(env.env.P[6])


# ## Rozwiązanie problemu przy użyciu algorytmu iteracji wartości

def value_iteration(env, gamma, threshold):
    """
    Rozwiązanie problemu FrozenLake przy użyciu algorytmu iteracji wartości
    @param env: środowisko Gymnasium
    @param gamma: współczynnik dyskontowy
    @param threshold: algorytm zakończy się, gdy zmiana wartości dla wszystkich stanów będzie mniejsza od zadanego progu
    @return: wartości optymalnej polityki
    """
    n_state = env.observation_space.n
    n_action = env.action_space.n
    V = torch.zeros(n_state)
    while True:
        V_temp = torch.empty(n_state)
        for state in range(n_state):
            v_actions = torch.zeros(n_action)
            for action in range(n_action):
                for trans_prob, new_state, reward, _ in env.env.P[state][action]:
                    v_actions[action] += trans_prob * (reward + gamma * V[new_state])
            V_temp[state] = torch.max(v_actions)
        max_delta = torch.max(torch.abs(V - V_temp))
        V = V_temp.clone()
        if max_delta <= threshold:
            break
    return V


gamma = 0.99
threshold = 0.0001

V_optimal = value_iteration(env, gamma, threshold)
print('Optymalne wartości:\n', V_optimal)



def extract_optimal_policy(env, V_optimal, gamma):
    """
    Implementacja optymalnej polityki na podstawie optymalnych wartości
    @param env: środowisko Gymnasium
    @param V_optimal: optymalne wartości
    @param gamma: współczynnik dyskontowy
    @return: optymalna polityka
    """
    n_state = env.observation_space.n
    n_action = env.action_space.n
    optimal_policy = torch.zeros(n_state)
    for state in range(n_state):
        v_actions = torch.zeros(n_action)
        for action in range(n_action):
            for trans_prob, new_state, reward, _ in env.env.P[state][action]:
                v_actions[action] += trans_prob * (reward + gamma * V_optimal[new_state])
        optimal_policy[state] = torch.argmax(v_actions)
    return optimal_policy


optimal_policy = extract_optimal_policy(env, V_optimal, gamma)
print('Optymalna polityka:\n', optimal_policy)


def run_episode(env, policy):
    state, _ = env.reset()
    total_reward = 0
    is_done = False
    while not is_done:
        action = policy[state].item()
        state, reward, terminated, truncated, info = env.step(action)
        is_done = terminated or truncated
        total_reward += reward
        if is_done:
            break
    return total_reward


n_episode = 1000
total_rewards = []
for episode in range(n_episode):
    total_reward = run_episode(env, optimal_policy)
    total_rewards.append(total_reward)

print('Średnia sumaryczna nagroda po zastosowaniu optymalnej polityki:', sum(total_rewards) / n_episode)



# ## Rozwiązanie problemu przy użyciu  algorytmu iteracji polityki

def policy_evaluation(env, policy, gamma, threshold):
    """
    Ocena polityki
    @param env: środowisko Gymnasium
    @param policy: macierz polityki zawierająca akcje i prawdopodobieństwa ich wykonania w każdym stanie
    @param gamma: współczynnik dyskontowy
    @param threshold: ocenianie zakończy się, gdy zmiana wartości dla wszystkich stanów będzie mniejsza od zadanego progu
    @return: wartości zadanej polityki
    """
    n_state = policy.shape[0]
    V = torch.zeros(n_state)
    while True:
        V_temp = torch.zeros(n_state)
        for state in range(n_state):
            action = policy[state].item()
            for trans_prob, new_state, reward, _ in env.env.P[state][action]:
                V_temp[state] += trans_prob * (reward + gamma * V[new_state])
        max_delta = torch.max(torch.abs(V - V_temp))
        V = V_temp.clone()
        if max_delta <= threshold:
            break
    return V


def policy_improvement(env, V, gamma):
    """
    Ulepszenie polityki na podstawie zadanych wartości
    @param env: środowisko Gymnasium
    @param V: wartości polityki
    @param gamma: współczynnik dyskontowy
    @return: polityka
    """
    n_state = env.observation_space.n
    n_action = env.action_space.n
    policy = torch.zeros(n_state)
    for state in range(n_state):
        v_actions = torch.zeros(n_action)
        for action in range(n_action):
            for trans_prob, new_state, reward, _ in env.env.P[state][action]:
                v_actions[action] += trans_prob * (reward + gamma * V[new_state])
        policy[state] = torch.argmax(v_actions)
    return policy


def policy_iteration(env, gamma, threshold):
    """
    Rozwiązanie problemu FrozenLake przy użyciu algorytmu iteracji polityki
    @param env: środowisko Gymnasium
    @param gamma: współczynnik dyskontowy
    @param threshold: algorytm zakończy się, gdy zmiana wartości dla wszystkich stanów będzie mniejsza od zadanego progu
    @return: optymalne wartości i optymalna polityka dla zadanego środowiska
    """    
    n_state = env.observation_space.n
    n_action = env.action_space.n
    policy = torch.randint(high=n_action, size=(n_state,)).float()
    while True:
        V = policy_evaluation(env, policy, gamma, threshold)
        policy_improved = policy_improvement(env, V, gamma)
        if torch.equal(policy_improved, policy):
            return V, policy_improved
        policy = policy_improved


gamma = 0.99
threshold = 0.0001


V_optimal, optimal_policy = policy_iteration(env, gamma, threshold)
print('Optymalne wartości:\n', V_optimal)
print('Optymalna polityka:\n', optimal_policy)


# ---

# Czytelnicy mogą pominąć następną komórkę.



