Updated Deep Q-Learning for Atari Breakout example for Keras v3 #1803

Merged · 2 commits · Mar 18, 2024
111 changes: 56 additions & 55 deletions examples/rl/deep_q_network_breakout.py
@@ -2,9 +2,9 @@
Title: Deep Q-Learning for Atari Breakout
Author: [Jacob Chapman](https://twitter.com/jacoblchapman) and [Mathias Lechner](https://twitter.com/MLech20)
Date created: 2020/05/23
Last modified: 2020/06/17
Last modified: 2024/03/17
Description: Play Atari Breakout with a Deep Q-Network.
Accelerator: NONE
Accelerator: None
"""

"""
@@ -13,19 +13,6 @@
This script shows an implementation of Deep Q-Learning on the
`BreakoutNoFrameskip-v4` environment.

This example requires the following dependencies: `baselines`, `atari-py`, `roms`.
They can be installed via:

```
git clone https://github.com/openai/baselines.git
cd baselines
pip install -e .
git clone https://github.com/openai/atari-py
wget http://www.atarimania.com/roms/Roms.rar
unrar x Roms.rar .
python -m atari_py.import_roms .
```

### Deep Q-Learning

As an agent takes actions and moves through an environment, it learns to map
@@ -51,20 +38,29 @@
game experience in total)". However, this script will give good results at around 10
million frames, which can be processed in less than 24 hours on a modern machine.
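As a minimal, illustrative sketch (not part of the original example), the Q-learning
target that the training loop below computes for each sampled transition is
`reward + gamma * max_a Q_target(next_state, a)`, plus special handling of terminal frames:

```python
import numpy as np

gamma = 0.99  # discount factor (the script sets `gamma = 0.99` in its configuration)
reward = 1.0  # reward observed for a single transition (made-up value for illustration)
next_q_values = np.array([0.2, 0.7, 0.1, 0.4])  # target-network Q-values for the next state
target = reward + gamma * next_q_values.max()
print(target)  # approximately 1.693
```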

You can control the number of episodes by setting the `max_episodes` variable
to a value greater than 0.

### References

- [Q-Learning](https://link.springer.com/content/pdf/10.1007/BF00992698.pdf)
- [Deep Q-Learning](https://deepmind.com/research/publications/human-level-control-through-deep-reinforcement-learning)
- [Deep Q-Learning](https://www.semanticscholar.org/paper/Human-level-control-through-deep-reinforcement-Mnih-Kavukcuoglu/340f48901f72278f6bf78a04ee5b01df208cc508)
"""
"""
## Setup
"""

from baselines.common.atari_wrappers import make_atari, wrap_deepmind
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers

import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
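# Added note (not part of the original script): with Keras 3, the `KERAS_BACKEND`
# environment variable must be set before `import keras`, as done above. A quick,
# illustrative sanity check that the intended backend was picked up:
print(keras.backend.backend())  # expected to print: tensorflow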

# Configuration parameters for the whole setup
seed = 42
@@ -77,13 +73,16 @@
) # Rate at which to reduce chance of random action being taken
batch_size = 32 # Size of batch taken from replay buffer
max_steps_per_episode = 10000

# Use the Baselines Atari environment because of the DeepMind helper functions
env = make_atari("BreakoutNoFrameskip-v4")
# Warp the frames, grayscale, stack four frames and scale to a smaller ratio
env = wrap_deepmind(env, frame_stack=True, scale=True)
max_episodes = 10  # Limit training episodes; will run until solved if smaller than 1

# Use the Atari environment
# Specify the `render_mode` parameter to show the attempts of the agent in a pop-up window.
env = gym.make("BreakoutNoFrameskip-v4") # , render_mode="human")
# Environment preprocessing
env = AtariPreprocessing(env)
# Stack four frames
env = FrameStack(env, 4)
env.seed(seed)
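# Illustrative check (not part of the original example): after `AtariPreprocessing`
# and `FrameStack(env, 4)`, each observation is a stack of four 84x84 grayscale
# frames, which the model below transposes to channels-last.
_sample_obs, _ = env.reset(seed=seed)
print(np.array(_sample_obs).shape)  # expected: (4, 84, 84)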

"""
## Implement the Deep Q-Network

@@ -99,26 +98,23 @@

def create_q_model():
# Network defined by the DeepMind paper
inputs = layers.Input(
shape=(
84,
84,
4,
)
return keras.Sequential(
[
layers.Lambda(
lambda tensor: keras.ops.transpose(tensor, [0, 2, 3, 1]),
output_shape=(84, 84, 4),
input_shape=(4, 84, 84),
),
# Convolutions on the frames on the screen
layers.Conv2D(32, 8, strides=4, activation="relu", input_shape=(4, 84, 84)),
layers.Conv2D(64, 4, strides=2, activation="relu"),
layers.Conv2D(64, 3, strides=1, activation="relu"),
layers.Flatten(),
layers.Dense(512, activation="relu"),
layers.Dense(num_actions, activation="linear"),
]
)

# Convolutions on the frames on the screen
layer1 = layers.Conv2D(32, 8, strides=4, activation="relu")(inputs)
layer2 = layers.Conv2D(64, 4, strides=2, activation="relu")(layer1)
layer3 = layers.Conv2D(64, 3, strides=1, activation="relu")(layer2)

layer4 = layers.Flatten()(layer3)

layer5 = layers.Dense(512, activation="relu")(layer4)
action = layers.Dense(num_actions, activation="linear")(layer5)

return keras.Model(inputs=inputs, outputs=action)
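# Quick shape check (added for illustration, not part of the original example):
# the network maps a batch of stacked frames to one Q-value per action.
_demo_q_values = create_q_model()(keras.ops.zeros((1, 4, 84, 84)))
print(_demo_q_values.shape)  # expected: (1, num_actions), i.e. (1, 4) for Breakout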


# The first model makes the predictions for Q-values which are used to
# make an action.
@@ -160,13 +156,12 @@ def create_q_model():
# Using huber loss for stability
loss_function = keras.losses.Huber()
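# Note (added for clarity): the Huber loss is quadratic for small errors and linear for
# large ones, so occasional large TD errors do not produce exploding gradients.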

while True: # Run until solved
state = np.array(env.reset())
while True:
observation, _ = env.reset()
state = np.array(observation)
episode_reward = 0

for timestep in range(1, max_steps_per_episode):
# env.render(); Adding this line would show the attempts
# of the agent in a pop up window.
frame_count += 1

# Use epsilon-greedy for exploration
@@ -176,18 +171,18 @@ def create_q_model():
else:
# Predict action Q-values
# From environment state
state_tensor = tf.convert_to_tensor(state)
state_tensor = tf.expand_dims(state_tensor, 0)
state_tensor = keras.ops.convert_to_tensor(state)
state_tensor = keras.ops.expand_dims(state_tensor, 0)
action_probs = model(state_tensor, training=False)
# Take best action
action = tf.argmax(action_probs[0]).numpy()
action = keras.ops.argmax(action_probs[0]).numpy()

# Decay probability of taking random action
epsilon -= epsilon_interval / epsilon_greedy_frames
epsilon = max(epsilon, epsilon_min)
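# Note (added for clarity): with `epsilon_interval = epsilon_max - epsilon_min` and
# `epsilon_greedy_frames` set earlier in the script, this subtraction anneals epsilon
# linearly from `epsilon_max` down to `epsilon_min` over `epsilon_greedy_frames` frames.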

# Apply the sampled action in our environment
state_next, reward, done, _ = env.step(action)
state_next, reward, done, _, _ = env.step(action)
state_next = np.array(state_next)

episode_reward += reward
@@ -210,30 +205,30 @@ def create_q_model():
state_next_sample = np.array([state_next_history[i] for i in indices])
rewards_sample = [rewards_history[i] for i in indices]
action_sample = [action_history[i] for i in indices]
done_sample = tf.convert_to_tensor(
done_sample = keras.ops.convert_to_tensor(
[float(done_history[i]) for i in indices]
)

# Build the updated Q-values for the sampled future states
# Use the target model for stability
future_rewards = model_target.predict(state_next_sample)
# Q value = reward + discount factor * expected future reward
updated_q_values = rewards_sample + gamma * tf.reduce_max(
updated_q_values = rewards_sample + gamma * keras.ops.amax(
future_rewards, axis=1
)

# If it is the final frame, set the target value to -1
updated_q_values = updated_q_values * (1 - done_sample) - done_sample
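# Worked example (added for clarity): for a terminal transition `done_sample` is 1, so
# the target becomes `q * (1 - 1) - 1 = -1`; for non-terminal transitions it stays
# `reward + gamma * max(Q_target)`.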

# Create a mask so we only calculate loss on the updated Q-values
masks = tf.one_hot(action_sample, num_actions)
masks = keras.ops.one_hot(action_sample, num_actions)
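# Illustrative example (added): with `num_actions = 4`, a sampled action of 2 yields
# the mask [0, 0, 1, 0]; multiplying the predicted Q-values by this mask and summing
# (below) selects the Q-value of the action that was actually taken.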

with tf.GradientTape() as tape:
# Train the model on the states and updated Q-values
q_values = model(state_sample)

# Apply the masks to the Q-values to get the Q-value for action taken
q_action = tf.reduce_sum(tf.multiply(q_values, masks), axis=1)
q_action = keras.ops.sum(keras.ops.multiply(q_values, masks), axis=1)
# Calculate loss between new Q-value and old Q-value
loss = loss_function(updated_q_values, q_action)

@@ -271,6 +266,12 @@ def create_q_model():
print("Solved at episode {}!".format(episode_count))
break

if (
max_episodes > 0 and episode_count >= max_episodes
): # Maximum number of episodes reached
print("Stopped at episode {}!".format(episode_count))
break

"""
## Visualizations
Before any training: