mdp.py
"""Tabular policy evaluation for a deterministic FrozenLake-style grid world."""
import gymnasium as gym
import matplotlib.pyplot as plt

# Action encoding used by FrozenLake-v1.
LEFT = 0
DOWN = 1
RIGHT = 2
UP = 3

# Rewards assigned by this MDP model (the environment's own rewards are ignored).
HOLE_REWARD = -1
GOAL_REWARD = 1000
DEFAULT_REWARD = 1
# Discount factor used in the Bellman backups below.
LEARNING_RATE = 0.99
ENV_NAME = "FrozenLake-v1"


class MDP:
    def __init__(self, grid):
        self.n_rows = len(grid)
        self.n_cols = len(grid[0])
        self.grid = grid
        # Flattened map string, e.g. "SFFF" + "FHFH" + ... for a 4x4 grid.
        self.m = "".join(grid)
        self.value_function = [0] * self.n_rows * self.n_cols
        self.actions = [LEFT, RIGHT, UP, DOWN]

    def get_next_state(self, state, action):
        """Return the state reached by taking `action`, staying put at walls and holes."""
        # Stay put at the leftmost column or when the move would enter a hole.
        if action == LEFT:
            return (
                state - 1
                if state % self.n_cols != 0 and self.m[state - 1] != "H"
                else state
            )
        # Stay put at the rightmost column or when the move would enter a hole.
        elif action == RIGHT:
            return (
                state + 1
                if state % self.n_cols != self.n_cols - 1 and self.m[state + 1] != "H"
                else state
            )
        # Stay put at the topmost row or when the move would enter a hole.
        elif action == UP:
            return (
                state - self.n_cols
                if state // self.n_cols != 0 and self.m[state - self.n_cols] != "H"
                else state
            )
        # Stay put at the bottom row or when the move would enter a hole.
        elif action == DOWN:
            return (
                state + self.n_cols
                if state // self.n_cols != self.n_rows - 1
                and self.m[state + self.n_cols] != "H"
                else state
            )
        return state

    def get_reward(self, state):
        # Holes are penalized, the goal is strongly rewarded, and every other
        # tile pays the small default reward.
        if self.m[state] == "H":
            return HOLE_REWARD
        elif self.m[state] == "G":
            return GOAL_REWARD
        else:
            return DEFAULT_REWARD

    def run_iterative_policy_evaluation(self, steps=1000):
        """Evaluate the uniform random policy with synchronous Bellman backups."""
        max_deltas = []
        for step in range(steps):
            new_value_function = [0.0] * self.n_rows * self.n_cols
            max_delta = 0.0
            for state in range(len(self.m)):
                new_value = 0.0
                # Bellman expectation backup: each action is taken with
                # probability 1 / len(self.actions).
                for action in self.actions:
                    next_state = self.get_next_state(state, action)
                    reward = self.get_reward(next_state)
                    new_value += (
                        1
                        / len(self.actions)
                        * (reward + LEARNING_RATE * self.value_function[next_state])
                    )
                new_value_function[state] = new_value
                delta = abs(new_value_function[state] - self.value_function[state])
                max_delta = max(max_delta, delta)
            max_deltas.append((step, max_delta))
            self.value_function = new_value_function
        self.plot_value_function_deltas(max_deltas)

    def run(self, episodes=1, video_folder=None):
        """Act greedily with respect to the evaluated value function in the environment."""
        if video_folder:
            tmp_env = gym.make(
                ENV_NAME, desc=self.grid, is_slippery=False, render_mode="rgb_array"
            )
            env = gym.wrappers.RecordVideo(
                env=tmp_env,
                video_folder=video_folder,
                name_prefix="run",
                episode_trigger=lambda e: True,
            )
        else:
            env = gym.make(
                ENV_NAME, desc=self.grid, is_slippery=False, render_mode="human"
            )
        for _ in range(episodes):
            state = env.reset()[0]
            while True:
                max_value_action = None
                max_action_value = None
                # Pick the action whose one-step lookahead value is largest,
                # skipping actions that leave the state unchanged.
                for action in self.actions:
                    next_state = self.get_next_state(state, action)
                    if next_state == state:
                        continue
                    reward = self.get_reward(next_state)
                    value = (
                        1
                        / len(self.actions)
                        * (reward + LEARNING_RATE * self.value_function[next_state])
                    )
                    if max_action_value is None or value > max_action_value:
                        max_action_value = value
                        max_value_action = action
                observation, _, terminated, truncated, _ = env.step(max_value_action)
                state = observation
                if terminated or truncated:
                    break
        env.close()

    def plot_value_function_deltas(self, deltas):
        # Plot how the largest per-state value change shrinks over iterations.
        steps, max_deltas = zip(*deltas)
        plt.plot(steps, max_deltas, linestyle="-", markersize=4)
        plt.xlabel("Step Number")
        plt.ylabel("Max Delta")
        plt.title(
            f"Max Delta by Step in Policy Evaluation ({self.n_rows} x {self.n_cols})"
        )
        plt.grid(True)
        # Assumes an ./assets directory exists next to this script.
        plt.savefig(f"./assets/max_delta_plot_{self.n_rows}_{self.n_cols}.png")
        plt.close()
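

# A minimal usage sketch, assuming the standard 4x4 FrozenLake map; the step and
# episode counts are illustrative, and the convergence plot is written to
# ./assets, which must already exist.
if __name__ == "__main__":
    grid = [
        "SFFF",
        "FHFH",
        "FFFH",
        "HFFG",
    ]
    mdp = MDP(grid)
    # Evaluate the uniform random policy and save the max-delta plot.
    mdp.run_iterative_policy_evaluation(steps=200)
    # Act greedily on the learned values; pass video_folder="./videos" to record a video.
    mdp.run(episodes=1)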