bittle_watch.py
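"""Watch a trained Bittle policy in the simulation GUI.

Loads a saved Stable-Baselines3 PPO model (optionally with its VecNormalize
statistics), rolls it out in BittleEnv, and can plot the cumulative reward.
"""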
import os

import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

from bittle_env import BittleEnv
def loadEnv(model_name, vec_norm=False):
    '''Load the environment and the saved model based on the model name.'''
    # Create OpenCatGym environment from class, with the GUI enabled
    env = BittleEnv(GUI=True)

    # If trained with the VecNormalize wrapper, reload its saved statistics
    if vec_norm:
        # Get saved normalization stats
        stats_path = os.path.join('Training', 'VecNormalize_Stats', model_name + '.pkl')
        env = make_vec_env(lambda: env, n_envs=1)
        env = VecNormalize.load(stats_path, env)
        # Do not update stats at test time
        env.training = False
        # Reward normalization is not needed at test time
        env.norm_reward = False

    # Get path to model
    model_path = os.path.join('Training', 'Saved_Models', model_name)
    # Reload the saved model
    model = PPO.load(model_path, env=env)
    return model, env
def watch(model, env, plot=False):
    '''Run the policy in the environment and optionally plot the running score.'''
    obs = env.reset()

    # Running score, step counter and series for the plot
    score = 0
    count = 0
    x = []
    y = []

    for _ in range(20000):
        # Sample an action from the policy (stochastic, not deterministic)
        action, _state = model.predict(obs, deterministic=False)
        obs, reward, done, info = env.step(action)
        # env.render()
        print(reward)

        x.append(count)
        count += 1
        # A vectorized environment returns the reward as a length-1 array
        if isinstance(reward, np.ndarray):
            score += reward[0]
        else:
            score += reward
        y.append(score)

        if done and plot:
            plt.scatter(x, y)
            plt.show()
            break
        elif done:
            obs = env.reset()

    # Close environment
    env.close()
if __name__ == '__main__':
    model, env = loadEnv('PPO_Model_Bittle89', vec_norm=True)
    watch(model, env, plot=False)
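# Usage sketch: the paths mirror the os.path.join calls in loadEnv and assume
# the repository's Training/ layout (file names here are illustrative):
#   Training/Saved_Models/PPO_Model_Bittle89           <- saved PPO model
#   Training/VecNormalize_Stats/PPO_Model_Bittle89.pkl <- saved VecNormalize stats
#
#   python bittle_watch.py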