forked from rasbt/machine-learning-book
Showing 19 changed files with 1,828 additions and 0 deletions.
@@ -0,0 +1,51 @@
## Chapter 19: Reinforcement Learning for Decision Making in Complex Environments

### Chapter Outline

- Introduction: learning from experience
- Understanding reinforcement learning
- Defining the agent-environment interface of a reinforcement learning system
- The theoretical foundations of RL
- Markov decision processes
- The mathematical formulation of Markov decision processes
- Visualization of a Markov process
- Episodic versus continuing tasks
- RL terminology: return, policy, and value function
- The return
- Policy
- Value function
- Dynamic programming using the Bellman equation
- Reinforcement learning algorithms
- Dynamic programming
- Policy evaluation – predicting the value function with dynamic programming
- Improving the policy using the estimated value function
- Policy iteration
- Value iteration
- Reinforcement learning with Monte Carlo
- State-value function estimation using MC
- Action-value function estimation using MC
- Finding an optimal policy using MC control
- Policy improvement – computing the greedy policy from the action-value function
- Temporal difference learning
- TD prediction
- On-policy TD control (SARSA)
- Off-policy TD control (Q-learning)
- Implementing our first RL algorithm
- Introducing the OpenAI Gym toolkit
- Working with the existing environments in OpenAI Gym
- A grid world example
- Implementing the grid world environment in OpenAI Gym
- Solving the grid world problem with Q-learning
- Implementing the Q-learning algorithm
- A glance at deep Q-learning
- Training a DQN model according to the Q-learning algorithm
- Replay memory
- Determining the target values for computing the loss
- Implementing a deep Q-learning algorithm
- Chapter and book summary
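
For quick reference, the off-policy TD control (Q-learning) update implemented by this chapter's code examples, in tabular form in `agent.py` and with a neural-network function approximator in the CartPole DQN script, has the standard form

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_{t+1} + \gamma \max_{a} Q(s_{t+1}, a) - Q(s_t, a_t) \right]$$

where $\alpha$ is the learning rate and $\gamma$ is the discount factor; when an episode terminates, the bracketed target reduces to $r_{t+1}$, matching the `done` branches in the code.
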
**Please refer to the [README.md](../ch01/README.md) file in [`../ch01`](../ch01) for more information about running the code examples.**
@@ -0,0 +1,169 @@
# coding: utf-8

# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
#
# Code Repository: https://github.com
#
# Code License: MIT License (https://github.com/ /LICENSE.txt)

#################################################################################
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
#################################################################################

# Script: cartpole/main.py

import gym
import numpy as np
import torch
import torch.nn as nn
import random
import matplotlib.pyplot as plt
from collections import namedtuple
from collections import deque

np.random.seed(1)
torch.manual_seed(1)

Transition = namedtuple(
    'Transition', ('state', 'action', 'reward',
                   'next_state', 'done'))


class DQNAgent:
    def __init__(
            self, env, discount_factor=0.95,
            epsilon_greedy=1.0, epsilon_min=0.01,
            epsilon_decay=0.995, learning_rate=1e-3,
            max_memory_size=2000):
        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.action_size = env.action_space.n

        # Replay memory: a bounded buffer of past transitions
        self.memory = deque(maxlen=max_memory_size)

        self.gamma = discount_factor
        self.epsilon = epsilon_greedy
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.lr = learning_rate
        self._build_nn_model()

    def _build_nn_model(self):
        # Fully connected network mapping a state to one Q-value per action
        self.model = nn.Sequential(nn.Linear(self.state_size, 256),
                                   nn.ReLU(),
                                   nn.Linear(256, 128),
                                   nn.ReLU(),
                                   nn.Linear(128, 64),
                                   nn.ReLU(),
                                   nn.Linear(64, self.action_size))

        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), self.lr)

    def remember(self, transition):
        self.memory.append(transition)

    def choose_action(self, state):
        # Epsilon-greedy action selection: explore with probability epsilon,
        # otherwise act greedily with respect to the predicted Q-values
        if np.random.rand() <= self.epsilon:
            return np.random.choice(self.action_size)
        with torch.no_grad():
            q_values = self.model(torch.tensor(state,
                                               dtype=torch.float32))[0]
        return torch.argmax(q_values).item()  # returns action

    def _learn(self, batch_samples):
        # Build the Q-learning targets for a mini-batch of transitions
        # and take one gradient step on the mean squared error
        batch_states, batch_targets = [], []
        for transition in batch_samples:
            s, a, r, next_s, done = transition

            with torch.no_grad():
                if done:
                    target = r
                else:
                    pred = self.model(torch.tensor(next_s,
                                                   dtype=torch.float32))[0]
                    target = r + self.gamma * pred.max()

                target_all = self.model(torch.tensor(s,
                                                     dtype=torch.float32))[0]
                target_all[a] = target

            batch_states.append(s.flatten())
            batch_targets.append(target_all)
            self._adjust_epsilon()

        self.optimizer.zero_grad()
        pred = self.model(torch.tensor(np.array(batch_states),
                                       dtype=torch.float32))

        loss = self.loss_fn(pred, torch.stack(batch_targets))
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _adjust_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def replay(self, batch_size):
        # Sample a mini-batch from the replay memory and learn from it
        samples = random.sample(self.memory, batch_size)
        return self._learn(samples)


def plot_learning_history(history):
    fig = plt.figure(1, figsize=(14, 5))
    ax = fig.add_subplot(1, 1, 1)
    episodes = np.arange(len(history)) + 1
    plt.plot(episodes, history, lw=4,
             marker='o', markersize=10)
    ax.tick_params(axis='both', which='major', labelsize=15)
    plt.xlabel('Episodes', size=20)
    plt.ylabel('Total rewards', size=20)
    plt.show()


# General settings
EPISODES = 200
batch_size = 32
init_replay_memory_size = 500

if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    agent = DQNAgent(env)
    state = env.reset()
    state = np.reshape(state, [1, agent.state_size])

    # Filling up the replay-memory
    for i in range(init_replay_memory_size):
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        next_state = np.reshape(next_state, [1, agent.state_size])
        agent.remember(Transition(state, action, reward,
                                  next_state, done))
        if done:
            state = env.reset()
            state = np.reshape(state, [1, agent.state_size])
        else:
            state = next_state

    total_rewards, losses = [], []
    for e in range(EPISODES):
        state = env.reset()
        if e % 10 == 0:
            env.render()
        state = np.reshape(state, [1, agent.state_size])
        for i in range(500):
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, agent.state_size])
            agent.remember(Transition(state, action, reward,
                                      next_state, done))
            state = next_state
            if e % 10 == 0:
                env.render()
            if done:
                total_rewards.append(i)
                print(f'Episode: {e}/{EPISODES}, Total reward: {i}')
                break
        loss = agent.replay(batch_size)
        losses.append(loss)
    plot_learning_history(total_rewards)
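
A note on running this script: it follows the pre-0.26 Gym interface, in which `env.reset()` returns only the observation and `env.step()` returns four values. With a newer Gym release installed, a small shim along the following lines can bridge the two interfaces; this is only a sketch, and the helper names `reset_compat` and `step_compat` are hypothetical rather than part of this commit.

# Hypothetical compatibility helpers (not part of the original script):
# adapt Gym >= 0.26, where reset() returns (obs, info) and step() returns
# five values, back to the older interface assumed above.
def reset_compat(env):
    out = env.reset()
    # Newer Gym returns (observation, info); older versions return the observation only
    return out[0] if isinstance(out, tuple) else out

def step_compat(env, action):
    out = env.step(action)
    if len(out) == 5:  # newer API: obs, reward, terminated, truncated, info
        obs, reward, terminated, truncated, info = out
        return obs, reward, terminated or truncated, info
    return out  # older API already returns obs, reward, done, info
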
Large diffs are not rendered by default.
@@ -0,0 +1,65 @@
# coding: utf-8

# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
#
# Code Repository:
#
# Code License: MIT License (https://github.com/ /LICENSE.txt)

#################################################################################
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
#################################################################################

# Script: agent.py

from collections import defaultdict
import numpy as np


class Agent(object):
    def __init__(
            self, env,
            learning_rate=0.01,
            discount_factor=0.9,
            epsilon_greedy=0.9,
            epsilon_min=0.1,
            epsilon_decay=0.95):
        self.env = env
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon_greedy
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        # Define the q_table
        self.q_table = defaultdict(lambda: np.zeros(self.env.nA))

    def choose_action(self, state):
        # Epsilon-greedy action selection
        if np.random.uniform() < self.epsilon:
            action = np.random.choice(self.env.nA)
        else:
            # Permute the actions first so that ties in the Q-values
            # are broken at random
            q_vals = self.q_table[state]
            perm_actions = np.random.permutation(self.env.nA)
            q_vals = [q_vals[a] for a in perm_actions]
            perm_q_argmax = np.argmax(q_vals)
            action = perm_actions[perm_q_argmax]
        return action

    def _learn(self, transition):
        s, a, r, next_s, done = transition
        q_val = self.q_table[s][a]
        if done:
            q_target = r
        else:
            q_target = r + self.gamma*np.max(self.q_table[next_s])

        # Update the q_table
        self.q_table[s][a] += self.lr * (q_target - q_val)

        # Adjust the epsilon
        self._adjust_epsilon()

    def _adjust_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
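
The grid world environment and the Q-learning driver that use this class are among the files not rendered in this diff. As a rough sketch only, a training loop for the `Agent` class could look like the following, assuming a Gym-style discrete environment that exposes `nA`, `reset()`, and a `step()` returning `(next_state, reward, done, info)`; the function name `run_qlearning` is a hypothetical placeholder, not taken from this commit.

# Hypothetical driver sketch for the Agent class above; assumes a discrete,
# Gym-style environment exposing nA, reset(), and step().
def run_qlearning(agent, env, num_episodes=50):
    history = []
    for episode in range(num_episodes):
        state = env.reset()
        episode_reward, done = 0.0, False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            # Transition layout matches Agent._learn: (s, a, r, next_s, done)
            agent._learn((state, action, reward, next_state, done))
            episode_reward += reward
            state = next_state
        history.append(episode_reward)
    return history
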