import torch
import gym
env = gym.make('FrozenLake-v0')
gamma = 0.99
threshold = 0.0001

def policy_evaluation(env, policy, gamma, threshold):
    """
    Ocena polityki
    @param env: środowisko OpenAI Gym
    @param policy: macierz polityki zawierająca akcje i prawdopodobieństwa ich wykonania w każdym stanie
    @param gamma: współczynnik dyskontowy
    @param threshold: ocenianie zakończy się, gdy zmiana wartości dla wszystkich stanów będzie mniejsza od zadanego progu
    @return: wartości zadanej polityki
    """
    n_state = policy.shape[0]
    V = torch.zeros(n_state)
    while True:
        V_temp = torch.zeros(n_state)
        for state in range(n_state):
            action = policy[state].item()
            for trans_prob, new_state, reward, _ in env.env.P[state][action]:
                V_temp[state] += trans_prob * (reward + gamma * V[new_state])
        max_delta = torch.max(torch.abs(V - V_temp))
        V = V_temp.clone()
        if max_delta <= threshold:
            break
    return V

def policy_improvement(env, V, gamma):
    """
    Ulepszenie polityki na podstawie zadanych wartości
    @param env: środowisko OpenAI Gym
    @param V: wartości polityki
    @param gamma: współczynnik dyskontowy
    @return: polityka
    """
    n_state = env.observation_space.n
    n_action = env.action_space.n
    policy = torch.zeros(n_state)
    for state in range(n_state):
        v_actions = torch.zeros(n_action)
        for action in range(n_action):
            for trans_prob, new_state, reward, _ in env.env.P[state][action]:
                v_actions[action] += trans_prob * (reward + gamma * V[new_state])
        policy[state] = torch.argmax(v_actions)
    return policy

def policy_iteration(env, gamma, threshold):
    """
    Rozwiązanie problemu FrozenLake przy użyciu algorytmu iteracji polityki
    @param env: środowisko OpenAI Gym
    @param gamma: współczynnik dyskontowy
    @param threshold: algorytm zakończy się, gdy zmiana wartości dla wszystkich stanów będzie mniejsza od zadanego progu
    @return: optymalne wartości i optymalna polityka dla zadanego środowiska
    """
    n_state = env.observation_space.n
    n_action = env.action_space.n
    policy = torch.randint(high=n_action, size=(n_state,)).float()
    while True:
        V = policy_evaluation(env, policy, gamma, threshold)
        policy_improved = policy_improvement(env, V, gamma)
        if torch.equal(policy_improved, policy):
            return V, policy_improved
        policy = policy_improved

V_optimal, optimal_policy = policy_iteration(env, gamma, threshold)
print('Optymalne wartości:\n', V_optimal)
print('Optymalna polityka:\n', optimal_policy)