tabular_common.py
import math
import random
from typing import Any

from gymnasium import Env
import numpy as np
import numpy.typing as npt


class TabularQAgent:
    """Tabular agent trained with Q-learning or SARSA under an
    epsilon-greedy behaviour policy.

    `q_table` is only annotated here; it must be created by the caller or a
    subclass before training, since `__init__` does not initialise it.
    """

    q_table: npt.NDArray[np.float64] | dict[Any, npt.NDArray[np.float64]]
    gamma: float
    alpha: float
    steps_done: int
    eps_end: float
    eps_start: float
    eps_decay: float
    use_sarsa: bool
    td_errors: list[float]

    def __init__(
        self,
        gamma: float,
        alpha: float,
        eps_end: float,
        eps_start: float,
        eps_decay: float,
        use_sarsa: bool = False,
    ) -> None:
        self.gamma = gamma
        self.alpha = alpha
        self.steps_done = 0
        self.eps_end = eps_end
        self.eps_start = eps_start
        self.eps_decay = eps_decay
        self.use_sarsa = use_sarsa
        self.td_errors = []

    def _obs_processing(self, obs: Any) -> Any:
        # Identity by default; override to map raw observations to Q-table
        # keys or indices.
        return obs

    def best_value_and_action(self, obs: Any) -> tuple[float, int]:
        action_values = self.q_table[obs]
        best_value = np.max(action_values)
        best_action = np.argmax(action_values)
        return best_value, best_action  # type: ignore

    def select_epsilon_greedy_action(
        self, env: Env, obs: Any
    ) -> tuple[int, float]:
        sample = random.random()
        # Epsilon decays exponentially from eps_start towards eps_end as
        # steps_done grows.
        eps_threshold = self.eps_end + (
            self.eps_start - self.eps_end
        ) * math.exp(-1.0 * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > eps_threshold:
            # Exploit: greedy action w.r.t. the current Q estimates.
            action = np.argmax(self.q_table[obs])
        else:
            # Explore: sample a random action from the action space.
            action = env.action_space.sample()
        return int(action), eps_threshold

    def simulate_one_episode(self, env: Env) -> tuple[float, int, float]:
        if self.use_sarsa:
            return self._simulate_one_episode_sarsa(env)
        else:
            return self._simulate_one_episode_q_learning(env)

    def _simulate_one_episode_sarsa(self, env: Env) -> tuple[float, int, float]:
        total_reward = 0
        obs, _ = env.reset()
        obs = self._obs_processing(obs)
        action, _ = self.select_epsilon_greedy_action(env, obs)
        terminated = False
        truncated = False
        episode_duration = 0
        eps_threshold = 0
        while not terminated and not truncated:
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_obs = self._obs_processing(next_obs)
            total_reward += reward  # type: ignore

            # Update Q table
            # SARSA: Q(s,a) = Q(s,a) + alpha * (r + gamma * Q(s',a') - Q(s,a))
            new_action, eps_threshold = self.select_epsilon_greedy_action(
                env, next_obs
            )
            new_value = reward + self.gamma * self.q_table[next_obs][new_action]
            old_value = self.q_table[obs][action]
            self.q_table[obs][action] = old_value + self.alpha * (
                new_value - old_value
            )
            self.td_errors.append(new_value - old_value)

            obs = next_obs
            action = new_action
            episode_duration += 1

        return total_reward, episode_duration, eps_threshold

    def _simulate_one_episode_q_learning(
        self, env: Env
    ) -> tuple[float, int, float]:
        total_reward = 0
        obs, _ = env.reset()
        obs = self._obs_processing(obs)
        terminated = False
        truncated = False
        episode_duration = 0
        eps_threshold = 0
        while not terminated and not truncated:
            action, eps_threshold = self.select_epsilon_greedy_action(env, obs)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_obs = self._obs_processing(next_obs)
            total_reward += reward  # type: ignore

            # Update Q table
            # Q learning: Q(s,a) = Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
            best_value, _ = self.best_value_and_action(next_obs)
            new_value = reward + self.gamma * best_value  # type: ignore
            old_value = self.q_table[obs][action]
            self.q_table[obs][action] = old_value + self.alpha * (
                new_value - old_value
            )
            self.td_errors.append(new_value - old_value)

            obs = next_obs
            episode_duration += 1

        return total_reward, episode_duration, eps_threshold


def get_moving_average_for_plot(
    data: list[float], rolling_length: int = 500
) -> npt.NDArray[np.float64]:
    # Rolling mean over a window of `rolling_length` points; "valid" mode
    # means the output has len(data) - rolling_length + 1 entries.
    t = np.array(data, dtype=np.float64)
    return (
        np.convolve(t, np.ones(rolling_length), mode="valid") / rolling_length
    )
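

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: it assumes
    # gymnasium's FrozenLake-v1 environment and a dense numpy Q-table indexed
    # by the discrete observation; the hyperparameters below are arbitrary.
    import gymnasium as gym

    env = gym.make("FrozenLake-v1", is_slippery=False)
    n_states = env.observation_space.n  # type: ignore[attr-defined]
    n_actions = env.action_space.n  # type: ignore[attr-defined]

    agent = TabularQAgent(
        gamma=0.99, alpha=0.1, eps_end=0.05, eps_start=1.0, eps_decay=5000
    )
    agent.q_table = np.zeros((n_states, n_actions), dtype=np.float64)

    episode_returns: list[float] = []
    for _ in range(2000):
        episode_return, _, _ = agent.simulate_one_episode(env)
        episode_returns.append(episode_return)

    smoothed = get_moving_average_for_plot(episode_returns, rolling_length=100)
    print("final smoothed return:", smoothed[-1])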