-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSarsa_for_cartpole.py
85 lines (67 loc) · 3.26 KB
/
Sarsa_for_cartpole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gym
import numpy as np
import math
class CartPoleAgent:
    """Tabular SARSA agent for the Gym ``CartPole-v0`` environment.

    Continuous observations are discretized into a grid of ``buckets`` cells
    and action values are stored in ``sarsa_table``, a NumPy array of shape
    ``buckets + (number of actions,)``. The exploration rate and learning rate
    both follow a per-episode logarithmic decay schedule (see ``get_epsilon``
    and ``get_learning_rate``).

    The environment calls here (``reset`` returning only the observation,
    ``step`` returning a 4-tuple, ``gym.wrappers.Monitor``) follow the older
    Gym API.
    """

    def __init__(self, buckets=(1, 1, 6, 12), num_episodes=1000, min_lr=0.1,
                 min_epsilon=0.1, discount=0.98, decay=25):
        """Create the environment and an all-zero SARSA table.

        Args:
            buckets: number of discretization bins per observation component
                ``[position, velocity, angle, angular velocity]``; a count of 1
                maps every value of that component to the same index.
            num_episodes: number of episodes run by ``train``.
            min_lr: lower bound of the learning-rate schedule.
            min_epsilon: lower bound of the exploration-rate schedule.
            discount: reward discount factor (gamma) used in the SARSA target.
            decay: episode scale of the logarithmic decay schedules.
        """
        # Accept any sequence (e.g. a list); the table shape below needs a tuple.
        self.buckets = tuple(buckets)
        self.num_episodes = num_episodes
        self.min_lr = min_lr
        self.min_epsilon = min_epsilon
        self.discount = discount
        self.decay = decay
        # Defaults so choose_action / update_sarsa work even before train()
        # has set episode-specific values (e.g. when run() is called directly).
        self.epsilon = min_epsilon
        self.learning_rate = min_lr
        self.env = gym.make('CartPole-v0')
        # [position, velocity, angle, angular velocity]. Position and angle use
        # the observation space limits; velocity and angular velocity use
        # hand-picked ranges instead.
        self.upper_bounds = [self.env.observation_space.high[0], 0.5,
                             self.env.observation_space.high[2], math.radians(50)]
        self.lower_bounds = [self.env.observation_space.low[0], -0.5,
                             self.env.observation_space.low[2], -math.radians(50)]
        self.sarsa_table = np.zeros(self.buckets + (self.env.action_space.n,))

    def discretize_state(self, obs):
        """Map a continuous observation to a tuple of bucket indices.

        Each component is linearly rescaled from ``[lower, upper]`` to
        ``[0, buckets - 1]``, rounded with Python's ``round`` (half-to-even)
        and clipped into range, so out-of-bounds values land in the first or
        last bucket.
        """
        discretized = []
        for value, low, high, n_buckets in zip(obs, self.lower_bounds,
                                               self.upper_bounds, self.buckets):
            # Offset by the lower bound itself (not its absolute value) so
            # bounds that are not symmetric around zero are handled correctly.
            scaling = (value - low) / (high - low)
            index = int(round((n_buckets - 1) * scaling))
            discretized.append(min(n_buckets - 1, max(0, index)))
        return tuple(discretized)

    def choose_action(self, state):
        """Pick an action epsilon-greedily with respect to ``sarsa_table``.

        With probability ``self.epsilon`` a random action is sampled from the
        environment's action space; otherwise the highest-valued action for
        ``state`` is returned (``np.argmax`` breaks ties by lowest index).
        """
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.sarsa_table[state])

    def update_sarsa(self, state, action, reward, new_state, new_action, done=False):
        """Apply one SARSA update to ``Q(state, action)``.

        The target is ``reward + discount * Q(new_state, new_action)``, or just
        ``reward`` when ``done`` is true: a terminal state has no future value
        to bootstrap from.
        """
        target = reward
        if not done:
            target += self.discount * self.sarsa_table[new_state][new_action]
        self.sarsa_table[state][action] += self.learning_rate * (
            target - self.sarsa_table[state][action])

    def _decayed(self, t, floor):
        """Return ``1 - log10((t + 1) / decay)`` clamped to ``[floor, 1]``."""
        return max(floor, min(1., 1. - math.log10((t + 1) / self.decay)))

    def get_epsilon(self, t):
        """Return the exploration rate for episode ``t``."""
        return self._decayed(t, self.min_epsilon)

    def get_learning_rate(self, t):
        """Return the learning rate for episode ``t``."""
        return self._decayed(t, self.min_lr)

    def train(self):
        """Run ``num_episodes`` episodes of on-policy SARSA control.

        The action selected for the next state is the one actually executed
        on the following step, as SARSA requires. Episode ends caused by the
        time limit (``info['TimeLimit.truncated']``, when the wrapper reports
        it) are still bootstrapped; other episode ends are treated as terminal.
        """
        for e in range(self.num_episodes):
            self.learning_rate = self.get_learning_rate(e)
            self.epsilon = self.get_epsilon(e)
            current_state = self.discretize_state(self.env.reset())
            action = self.choose_action(current_state)
            done = False
            while not done:
                obs, reward, done, info = self.env.step(action)
                new_state = self.discretize_state(obs)
                new_action = self.choose_action(new_state)
                terminal = done and not info.get('TimeLimit.truncated', False)
                self.update_sarsa(current_state, action, reward,
                                  new_state, new_action, terminal)
                # Carry the chosen action forward instead of re-sampling.
                current_state, action = new_state, new_action
        print('Finished training!')

    def run(self):
        """Play and render one episode while recording it with ``Monitor``.

        Actions come from ``choose_action``, so they are epsilon-greedy with
        the current ``self.epsilon``.

        Returns:
            The number of steps taken before the episode ended.
        """
        # Wrap into a local so repeated calls don't stack Monitor wrappers on
        # self.env; close the wrapper when done (also on error).
        env = gym.wrappers.Monitor(self.env, 'cartpole')
        try:
            t = 0
            done = False
            current_state = self.discretize_state(env.reset())
            while not done:
                env.render()
                t += 1
                action = self.choose_action(current_state)
                obs, reward, done, _ = env.step(action)
                current_state = self.discretize_state(obs)
        finally:
            env.close()
        return t
if __name__ == "__main__":
    # Train a fresh agent, then play one recorded episode and report its length.
    cartpole_agent = CartPoleAgent()
    cartpole_agent.train()
    steps_survived = cartpole_agent.run()
    print("Time", steps_survived)