
Commit

Updated Deep Deterministic Policy Gradient (DDPG) example to Keras v3 (#1812)

* Updated Deep Deterministic Policy Gradient (DDPG) example to Keras v3

* add generated files
lpizzinidev authored Mar 25, 2024
1 parent ee887ff commit a66fff8
Showing 5 changed files with 400 additions and 261 deletions.
104 changes: 57 additions & 47 deletions examples/rl/ddpg_pendulum.py
@@ -2,9 +2,9 @@
Title: Deep Deterministic Policy Gradient (DDPG)
Author: [amifunny](https://github.com/amifunny)
Date created: 2020/06/04
-Last modified: 2020/09/21
+Last modified: 2024/03/23
Description: Implementing DDPG algorithm on the Inverted Pendulum Problem.
-Accelerator: NONE
+Accelerator: None
"""

"""
@@ -15,11 +15,10 @@
It combines ideas from DPG (Deterministic Policy Gradient) and DQN (Deep Q-Network).
It uses Experience Replay and slow-learning target networks from DQN, and it is based on
-DPG,
-which can operate over continuous action spaces.
+DPG, which can operate over continuous action spaces.
This tutorial closely follows this paper -
-[Continuous control with deep reinforcement learning](https://arxiv.org/pdf/1509.02971.pdf)
+[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
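
In update-rule form, these two ideas are the Bellman target used to train the critic
and the slow ("soft") update of the target networks (restated here from the paper
linked above):

$$
y = r + \gamma \, Q'\big(s', \mu'(s')\big),
\qquad
\theta' \leftarrow \tau \theta + (1 - \tau) \theta', \quad \tau \ll 1
$$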
## Problem
@@ -61,19 +60,25 @@
Now, let's see how it is implemented.
"""
-import gym
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
import tensorflow as tf
-from tensorflow.keras import layers
+import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

"""
-We use [OpenAIGym](http://gym.openai.com/docs) to create the environment.
+We use [Gymnasium](https://gymnasium.farama.org/) to create the environment.
We will use the `upper_bound` parameter to scale our actions later.
"""

-problem = "Pendulum-v1"
-env = gym.make(problem)
+# Specify the `render_mode` parameter to show the attempts of the agent in a pop up window.
+env = gym.make("Pendulum-v1")  # , render_mode="human")

num_states = env.observation_space.shape[0]
print("Size of State Space -> {}".format(num_states))
@@ -104,7 +109,7 @@ def __init__(self, mean, std_deviation, theta=0.15, dt=1e-2, x_initial=None):
        self.reset()

    def __call__(self):
-        # Formula taken from https://www.wikipedia.org/wiki/Ornstein-Uhlenbeck_process.
+        # Formula taken from https://www.wikipedia.org/wiki/Ornstein-Uhlenbeck_process
        x = (
            self.x_prev
            + self.theta * (self.mean - self.x_prev) * self.dt
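        # Together with the sigma * sqrt(dt) * N(0, 1) term in the lines this diff
        # does not show, the expression above implements the discretized
        # Ornstein-Uhlenbeck step:
        #   x_{t+dt} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1)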
@@ -193,7 +198,7 @@ def update(
                [next_state_batch, target_actions], training=True
            )
            critic_value = critic_model([state_batch, action_batch], training=True)
-            critic_loss = tf.math.reduce_mean(tf.math.square(y - critic_value))
+            critic_loss = keras.ops.mean(keras.ops.square(y - critic_value))

        critic_grad = tape.gradient(critic_loss, critic_model.trainable_variables)
        critic_optimizer.apply_gradients(
@@ -205,7 +210,7 @@
            critic_value = critic_model([state_batch, actions], training=True)
            # Used `-value` as we want to maximize the value given
            # by the critic for our actions
-            actor_loss = -tf.math.reduce_mean(critic_value)
+            actor_loss = -keras.ops.mean(critic_value)

        actor_grad = tape.gradient(actor_loss, actor_model.trainable_variables)
        actor_optimizer.apply_gradients(
@@ -220,21 +225,27 @@ def learn(self):
        batch_indices = np.random.choice(record_range, self.batch_size)

        # Convert to tensors
-        state_batch = tf.convert_to_tensor(self.state_buffer[batch_indices])
-        action_batch = tf.convert_to_tensor(self.action_buffer[batch_indices])
-        reward_batch = tf.convert_to_tensor(self.reward_buffer[batch_indices])
-        reward_batch = tf.cast(reward_batch, dtype=tf.float32)
-        next_state_batch = tf.convert_to_tensor(self.next_state_buffer[batch_indices])
+        state_batch = keras.ops.convert_to_tensor(self.state_buffer[batch_indices])
+        action_batch = keras.ops.convert_to_tensor(self.action_buffer[batch_indices])
+        reward_batch = keras.ops.convert_to_tensor(self.reward_buffer[batch_indices])
+        reward_batch = keras.ops.cast(reward_batch, dtype="float32")
+        next_state_batch = keras.ops.convert_to_tensor(
+            self.next_state_buffer[batch_indices]
+        )

        self.update(state_batch, action_batch, reward_batch, next_state_batch)


# This updates target parameters slowly
# Based on rate `tau`, which is much less than one.
-@tf.function
-def update_target(target_weights, weights, tau):
-    for a, b in zip(target_weights, weights):
-        a.assign(b * tau + a * (1 - tau))
+def update_target(target, original, tau):
+    target_weights = target.get_weights()
+    original_weights = original.get_weights()
+
+    for i in range(len(target_weights)):
+        target_weights[i] = original_weights[i] * tau + target_weights[i] * (1 - tau)
+
+    target.set_weights(target_weights)
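
# A one-pass equivalent of the helper above (an illustrative sketch; the
# `update_target_polyak` name is hypothetical, not part of the commit):
def update_target_polyak(target, original, tau):
    # Blend a small fraction `tau` of the online weights into the target weights.
    target.set_weights(
        [tau * w + (1.0 - tau) * tw
         for w, tw in zip(original.get_weights(), target.get_weights())]
    )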


"""
@@ -250,7 +261,7 @@ def update_target(target_weights, weights, tau):

def get_actor():
    # Initialize weights between -3e-3 and 3e-3
-    last_init = tf.random_uniform_initializer(minval=-0.003, maxval=0.003)
+    last_init = keras.initializers.RandomUniform(minval=-0.003, maxval=0.003)

    inputs = layers.Input(shape=(num_states,))
    out = layers.Dense(256, activation="relu")(inputs)
@@ -259,18 +270,18 @@ def get_actor():

    # Our upper bound is 2.0 for Pendulum.
    outputs = outputs * upper_bound
-    model = tf.keras.Model(inputs, outputs)
+    model = keras.Model(inputs, outputs)
    return model


def get_critic():
    # State as input
-    state_input = layers.Input(shape=(num_states))
+    state_input = layers.Input(shape=(num_states,))
    state_out = layers.Dense(16, activation="relu")(state_input)
    state_out = layers.Dense(32, activation="relu")(state_out)

    # Action as input
-    action_input = layers.Input(shape=(num_actions))
+    action_input = layers.Input(shape=(num_actions,))
    action_out = layers.Dense(32, activation="relu")(action_input)

    # Both are passed through separate layers before concatenating
@@ -281,7 +292,7 @@ def get_critic():
    outputs = layers.Dense(1)(out)

    # Outputs a single value for a given state-action pair
-    model = tf.keras.Model([state_input, action_input], outputs)
+    model = keras.Model([state_input, action_input], outputs)

    return model
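
# This diff does not touch the model wiring; in the full example, the four networks
# used by the training loop below are created along these lines (a sketch):
actor_model = get_actor()
critic_model = get_critic()

target_actor = get_actor()
target_critic = get_critic()

# Targets start out identical to the online networks.
target_actor.set_weights(actor_model.get_weights())
target_critic.set_weights(critic_model.get_weights())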

@@ -293,7 +304,7 @@ def get_critic():


def policy(state, noise_object):
-    sampled_actions = tf.squeeze(actor_model(state))
+    sampled_actions = keras.ops.squeeze(actor_model(state))
    noise = noise_object()
    # Adding noise to action
    sampled_actions = sampled_actions.numpy() + noise
@@ -325,8 +336,8 @@ def policy(state, noise_object):
critic_lr = 0.002
actor_lr = 0.001

-critic_optimizer = tf.keras.optimizers.Adam(critic_lr)
-actor_optimizer = tf.keras.optimizers.Adam(actor_lr)
+critic_optimizer = keras.optimizers.Adam(critic_lr)
+actor_optimizer = keras.optimizers.Adam(actor_lr)

total_episodes = 100
# Discount factor for future rewards
@@ -349,29 +360,28 @@

# Takes about 4 min to train
for ep in range(total_episodes):
-    prev_state = env.reset()
+    prev_state, _ = env.reset()
    episodic_reward = 0

    while True:
-        # Uncomment this to see the Actor in action
-        # But not in a python notebook.
-        # env.render()
-
-        tf_prev_state = tf.expand_dims(tf.convert_to_tensor(prev_state), 0)
+        tf_prev_state = keras.ops.expand_dims(
+            keras.ops.convert_to_tensor(prev_state), 0
+        )

        action = policy(tf_prev_state, ou_noise)
        # Receive state and reward from environment.
-        state, reward, done, info = env.step(action)
+        state, reward, done, truncated, _ = env.step(action)

        buffer.record((prev_state, action, reward, state))
        episodic_reward += reward

        buffer.learn()
-        update_target(target_actor.variables, actor_model.variables, tau)
-        update_target(target_critic.variables, critic_model.variables, tau)
-
-        # End this episode when `done` is True
-        if done:
+        update_target(target_actor, actor_model, tau)
+        update_target(target_critic, critic_model, tau)
+
+        # End this episode when `done` or `truncated` is True
+        if done or truncated:
            break

        prev_state = state
@@ -387,7 +397,7 @@
# Episodes versus Avg. Rewards
plt.plot(avg_reward_list)
plt.xlabel("Episode")
-plt.ylabel("Avg. Epsiodic Reward")
+plt.ylabel("Avg. Episodic Reward")
plt.show()

"""
@@ -399,16 +409,16 @@
The Inverted Pendulum problem has low complexity, but DDPG works great on many other
problems.
-Another great environment to try this on is `LunarLandingContinuous-v2`, but it will take
+Another great environment to try this on is the continuous `LunarLander-v2`, but it will take
more episodes to obtain good results.
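
For example, swapping in that environment would look roughly like this (a sketch;
`LunarLander-v2` takes a `continuous=True` flag and needs the `gymnasium[box2d]` extra):

```python
env = gym.make("LunarLander-v2", continuous=True)
num_states = env.observation_space.shape[0]  # 8 state variables
num_actions = env.action_space.shape[0]  # 2 continuous engine controls
```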
"""

# Save the weights
actor_model.save_weights("pendulum_actor.h5")
critic_model.save_weights("pendulum_critic.h5")
actor_model.save_weights("pendulum_actor.weights.h5")
critic_model.save_weights("pendulum_critic.weights.h5")

-target_actor.save_weights("pendulum_target_actor.h5")
-target_critic.save_weights("pendulum_target_critic.h5")
+target_actor.save_weights("pendulum_target_actor.weights.h5")
+target_critic.save_weights("pendulum_target_critic.weights.h5")

"""
Before Training:
Binary file modified examples/rl/img/ddpg_pendulum/ddpg_pendulum_16_1.png

