agent.py
import numpy as np
from stable_baselines3 import SAC, PPO


class Agent:
    """Trains an RL policy on the given environment and rolls out trajectories."""

    def __init__(self, env, total_timesteps, log_interval):
        self.env = env
        # SAC with the default MLP policy; a PPO variant is kept below for reference.
        self.model = SAC('MlpPolicy', env, verbose=1)
        # self.model = PPO('MlpPolicy', env, learning_rate=1e-4, n_steps=1024, gamma=0.9, gae_lambda=0.95, verbose=1)
        # e.g. 5000 timesteps for the di env, 20000 for the lane env
        self.model.learn(total_timesteps=total_timesteps, log_interval=log_interval)

    def generate_agent_traj(self, n_traj):
        """Roll out n_traj episodes with the trained policy (deterministic actions) and return the observation sequences."""
        trajs = []
        for _ in range(n_traj):
            obs = self.env.reset()
            single_traj = []
            done = False
            while not done:
                single_traj.append(obs)
                action, _states = self.model.predict(obs, deterministic=True)
                obs, _, done, _ = self.env.step(action)
            single_traj.append(obs)  # include the terminal observation
            trajs.append(single_traj)
        # Note: stacking assumes every episode has the same length; otherwise the array is ragged.
        return np.array(trajs)
    # def generate_test_traj(self, n_traj):
    #     trajs = []
    #     acts = []
    #     for i in range(n_traj):
    #         obs = self.env.reset()
    #         single_traj = []
    #         single_act = []
    #         done = False
    #         while not done:
    #             single_traj.append(self.env.state)
    #             action, _states = self.model.predict(obs, deterministic=True)
    #             single_act.append(action)
    #             obs, _, done, _ = self.env.step(action)
    #         single_traj.append(self.env.state)
    #         trajs.append(single_traj)
    #         acts.append(single_act)
    #     return np.array(trajs), np.array(acts)
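

# Illustrative usage sketch (not part of the original file). The environment
# name and training budget here are assumptions: SAC needs a continuous (Box)
# action space, and the class above uses the old 4-tuple Gym step API, so any
# classic-control Gym env such as Pendulum would fit.
if __name__ == '__main__':
    import gym

    env = gym.make('Pendulum-v1')  # assumed env; replace with the repo's own environment
    agent = Agent(env, total_timesteps=5000, log_interval=10)
    trajs = agent.generate_agent_traj(n_traj=10)
    print(trajs.shape)  # (n_traj, episode_length + 1, obs_dim) when all episodes share a length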