#!/usr/bin/env python3
import gym
import time
import argparse
import numpy as np

import torch

from lib import wrappers
from lib import dqn_model

import collections

# Environment id used when --env is not given on the command line.
DEFAULT_ENV_NAME = "PongNoFrameskip-v4"
# Target playback rate (frames per second) when visualization is enabled;
# each step sleeps for whatever is left of its 1/FPS time slot.
FPS = 25


def main():
    """Play one episode with a trained DQN agent and report the results.

    Loads the network weights from ``--model``, builds the environment with
    ``wrappers.make_env`` and runs a single episode. At every step it takes
    the action with the highest predicted Q-value. Afterwards it prints the
    total reward and how many times each action was chosen.

    With ``--record DIR`` the environment is wrapped in
    ``gym.wrappers.Monitor`` writing into ``DIR``. Unless ``--no-vis`` is
    given, every frame is rendered and playback is throttled to roughly
    ``FPS`` frames per second.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", required=True,
                        help="Plik modelu")
    parser.add_argument("-e", "--env", default=DEFAULT_ENV_NAME,
                        help="Nazwa środowiska. Wartość domyślna=" +
                             DEFAULT_ENV_NAME)
    parser.add_argument("-r", "--record",
                        help="Nazwa katalogu do przechowywania pliku wideo")
    parser.add_argument("--no-vis", default=True, dest='vis',
                        help="Wyłączenie wizualizacji",
                        action='store_false')
    args = parser.parse_args()

    env = wrappers.make_env(args.env)
    if args.record:
        env = gym.wrappers.Monitor(env, args.record)
    net = dqn_model.DQN(env.observation_space.shape,
                        env.action_space.n)
    # Deserialize onto the CPU, whatever device the weights were saved from.
    state_dict = torch.load(args.model, map_location="cpu")
    net.load_state_dict(state_dict)

    state = env.reset()
    total_reward = 0.0
    action_counts = collections.Counter()

    while True:
        start_ts = time.time()
        if args.vis:
            env.render()
        # Wrapping the observation in a list adds the batch dimension.
        # np.asarray is used instead of np.array(..., copy=False): under
        # NumPy >= 2.0 the latter raises ValueError, because building an
        # array from a list always requires a copy.
        state_v = torch.tensor(np.asarray([state]))
        # Pure inference: skip building an autograd graph on every step.
        with torch.no_grad():
            q_vals = net(state_v).numpy()[0]
        # Cast to a plain int so the Counter output shows ordinary integers.
        action = int(np.argmax(q_vals))
        action_counts[action] += 1
        state, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break
        if args.vis:
            # Sleep for whatever is left of this frame's 1/FPS time slot.
            delta = 1 / FPS - (time.time() - start_ts)
            if delta > 0:
                time.sleep(delta)
    print("Nagroda sumaryczna: %.2f" % total_reward)
    print("Liczba akcji:", action_counts)
    if args.record:
        # NOTE(review): this closes the environment wrapped by Monitor, not
        # the Monitor itself (kept from the original). It is presumably a
        # workaround; confirm the recorded video is still finalized correctly.
        env.env.close()


if __name__ == "__main__":
    main()

