Merge pull request #16 from chriamue/dqn
Addition of DQN Trainer and Integration with DFDX Library
milanboers authored Nov 30, 2023
2 parents 40d0fa7 + 675626a commit 2a84249
Showing 4 changed files with 477 additions and 2 deletions.
13 changes: 11 additions & 2 deletions Cargo.toml
@@ -7,16 +7,21 @@ documentation = "https://docs.rs/rurel"
homepage = "https://github.com/milanboers/rurel"
repository = "https://github.com/milanboers/rurel"
readme = "README.md"
keywords = ["reinforcement", "q", "learning"]
keywords = ["reinforcement", "q", "learning", "dqn"]
categories = ["science", "algorithms"]
license = "MPL-2.0"
edition = "2021"

[badges]
travis-ci = { repository = "milanboers/rurel", branch = "master" }

[features]
default = []
dqn = ["dfdx"]

[dependencies]
rand = "0.8"
dfdx = { version = "0.11.2", optional = true }

[[example]]
name = "eucdist"
@@ -25,4 +30,8 @@ path = "src/examples/eucdist.rs"

[[example]]
name = "weightedcoin"
path = "src/examples/weightedcoin.rs"
path = "src/examples/weightedcoin.rs"

[[example]]
name = "eucdist_dqn"
path = "src/examples/eucdist_dqn.rs"
254 changes: 254 additions & 0 deletions src/dqn.rs
@@ -0,0 +1,254 @@
// source: https://raw.githubusercontent.com/coreylowman/dfdx/main/examples/rl-dqn.rs
use dfdx::{
nn,
optim::{Momentum, Sgd, SgdConfig},
prelude::*,
};

use crate::{
mdp::{Agent, State},
strategy::{explore::ExplorationStrategy, terminate::TerminationStrategy},
};

const BATCH: usize = 64;

type QNetwork<const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> = (
(Linear<STATE_SIZE, INNER_SIZE>, ReLU),
(Linear<INNER_SIZE, INNER_SIZE>, ReLU),
Linear<INNER_SIZE, ACTION_SIZE>,
);

type QNetworkDevice<const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> = (
(nn::modules::Linear<STATE_SIZE, INNER_SIZE, f32, Cpu>, ReLU),
(nn::modules::Linear<INNER_SIZE, INNER_SIZE, f32, Cpu>, ReLU),
nn::modules::Linear<INNER_SIZE, ACTION_SIZE, f32, Cpu>,
);
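// Note: `QNetwork` above is the dfdx architecture description (two hidden
// layers of INNER_SIZE units with ReLU activations), while `QNetworkDevice`
// is the same network instantiated with f32 parameters on the CPU device,
// i.e. the concrete module that is actually trained.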

/// A `DQNAgentTrainer` can be trained using a certain [Agent](mdp/trait.Agent.html). After
/// training, the `DQNAgentTrainer` contains learned knowledge about the process, and can be queried
/// for this. For example, you can ask the `DQNAgentTrainer` the expected values of all possible
/// actions in a given state.
///
/// The code is partially taken from <https://github.com/coreylowman/dfdx/blob/main/examples/rl-dqn.rs>.
///
pub struct DQNAgentTrainer<
S,
const STATE_SIZE: usize,
const ACTION_SIZE: usize,
const INNER_SIZE: usize,
> where
S: State + Into<[f32; STATE_SIZE]>,
S::A: Into<[f32; ACTION_SIZE]>,
S::A: From<[f32; ACTION_SIZE]>,
{
// discount factor applied to future rewards
gamma: f32,
// online Q-network, updated on every training step
q_network: QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>,
// periodically synchronized copy of `q_network`, used to compute TD targets
target_q_net: QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>,
// SGD optimizer (Nesterov momentum) for the Q-network parameters
sgd: Sgd<QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>, f32, Cpu>,
// dfdx device on which tensors are allocated
dev: Cpu,
phantom: std::marker::PhantomData<S>,
}

impl<S, const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize>
DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE>
where
S: State + Into<[f32; STATE_SIZE]>,
S::A: Into<[f32; ACTION_SIZE]>,
S::A: From<[f32; ACTION_SIZE]>,
{
/// Creates a new `DQNAgentTrainer` with the given parameters.
///
/// # Arguments
///
/// * `gamma` - The discount factor for future rewards.
/// * `learning_rate` - The learning rate for the optimizer.
///
/// # Returns
///
/// A new `DQNAgentTrainer` with the given parameters.
///
pub fn new(
gamma: f32,
learning_rate: f32,
) -> DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
let dev = AutoDevice::default();

// initialize model
let q_net = dev.build_module::<QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE>, f32>();
let target_q_net = q_net.clone();

// initialize optimizer
let sgd = Sgd::new(
&q_net,
SgdConfig {
lr: learning_rate,
momentum: Some(Momentum::Nesterov(0.9)),
weight_decay: None,
},
);

DQNAgentTrainer {
gamma,
q_network: q_net,
target_q_net,
sgd,
dev,
phantom: std::marker::PhantomData,
}
}

/// Returns the Q-values predicted by the target network for every possible
/// action in the given `State`.
pub fn expected_value(&self, state: &S) -> [f32; ACTION_SIZE] {
let state_: [f32; STATE_SIZE] = (state.clone()).into();
let states: Tensor<Rank1<STATE_SIZE>, f32, _> =
self.dev.tensor(state_).normalize::<Axis<0>>(0.001);
let actions = self.target_q_net.forward(states).nans_to(0f32);
actions.array()
}

/// Returns a clone of the entire learned state to be saved or used elsewhere.
pub fn export_learned_values(&self) -> QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
self.learned_values().clone()
}

/// Returns a reference to the learned Q-network.
pub fn learned_values(&self) -> &QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
&self.q_network
}

/// Imports a model, completely replacing any learned progress
pub fn import_model(&mut self, model: QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>) {
self.q_network.clone_from(&model);
self.target_q_net.clone_from(&self.q_network);
}

/// Returns the action with the highest predicted Q-value for the given `State`
/// (the conversion from a Q-value vector to an action is delegated to `S::A::from`).
pub fn best_action(&self, state: &S) -> Option<S::A> {
let target = self.expected_value(state);

Some(target.into())
}

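/// Performs one DQN update on a single batch of transitions.
///
/// `states`, `actions`, `next_states`, `rewards` and `dones` are parallel
/// arrays of length `BATCH`. The target network is synchronized with the
/// Q-network before and after the inner optimization loop.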
pub fn train_dqn(
&mut self,
states: [[f32; STATE_SIZE]; BATCH],
actions: [[f32; ACTION_SIZE]; BATCH],
next_states: [[f32; STATE_SIZE]; BATCH],
rewards: [f32; BATCH],
dones: [bool; BATCH],
) {
self.target_q_net.clone_from(&self.q_network);
let mut grads = self.q_network.alloc_grads();

let dones: Tensor<Rank1<BATCH>, f32, _> =
self.dev.tensor(dones.map(|d| if d { 1f32 } else { 0f32 }));
let rewards = self.dev.tensor(rewards);

// Convert to tensors and normalize the states for better training
let states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
self.dev.tensor(states).normalize::<Axis<1>>(0.001);

// Convert the action encodings to tensors of action indices (argmax per batch entry)
let actions: Tensor<Rank1<BATCH>, usize, _> = self.dev.tensor(actions.map(|a| {
let mut max_idx = 0;
let mut max_val = f32::NEG_INFINITY;
for (i, v) in a.iter().enumerate() {
if *v > max_val {
max_val = *v;
max_idx = i;
}
}
max_idx
}));

// Convert the next states to tensors and normalize them in the same way
let next_states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
self.dev.tensor(next_states).normalize::<Axis<1>>(0.001);

// Run several optimization steps on this batch: predict Q(s, a), build TD targets, minimize the Huber loss
for _step in 0..20 {
let q_values = self.q_network.forward(states.trace(grads));

let action_qs = q_values.select(actions.clone());

// targ_q = R + discount * max(Q(S'))
// curr_q = Q(S)[A]
// loss = huber(curr_q, targ_q, 1)
let next_q_values = self.target_q_net.forward(next_states.clone());
let max_next_q = next_q_values.max::<Rank1<BATCH>, _>();
let target_q = (max_next_q * (-dones.clone() + 1.0)) * self.gamma + rewards.clone();

let loss = huber_loss(action_qs, target_q, 1.0);

grads = loss.backward();

// update weights with optimizer
self.sgd
.update(&mut self.q_network, &grads)
.expect("Unused params");
self.q_network.zero_grads(&mut grads);
}
self.target_q_net.clone_from(&self.q_network);
}

/// Trains this [DQNAgentTrainer] using the given [ExplorationStrategy] and
/// [Agent] until the [TerminationStrategy] decides to stop.
pub fn train(
&mut self,
agent: &mut dyn Agent<S>,
termination_strategy: &mut dyn TerminationStrategy<S>,
exploration_strategy: &dyn ExplorationStrategy<S>,
) {
loop {
// Initialize batch
let mut states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
let mut actions: [[f32; ACTION_SIZE]; BATCH] = [[0.0; ACTION_SIZE]; BATCH];
let mut next_states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
let mut rewards: [f32; BATCH] = [0.0; BATCH];
let mut dones = [false; BATCH];

let mut s_t_next = agent.current_state();

for i in 0..BATCH {
let s_t = agent.current_state().clone();
let action = exploration_strategy.pick_action(agent);

// observe the state (and its reward) reached after taking the action
s_t_next = agent.current_state();
let r_t_next = s_t_next.reward();

states[i] = s_t.into();
actions[i] = action.into();
next_states[i] = (*s_t_next).clone().into();
rewards[i] = r_t_next as f32;

if termination_strategy.should_stop(s_t_next) {
dones[i] = true;
break;
}
}

// train the network
self.train_dqn(states, actions, next_states, rewards, dones);

// terminate if the agent is done
if termination_strategy.should_stop(s_t_next) {
break;
}
}
}
}

impl<S, const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> Default
for DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE>
where
S: State + Into<[f32; STATE_SIZE]>,
S::A: Into<[f32; ACTION_SIZE]>,
S::A: From<[f32; ACTION_SIZE]>,
{
fn default() -> Self {
Self::new(0.99, 1e-3)
}
}
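
The commit's remaining files (the new `src/examples/eucdist_dqn.rs` and the module wiring) are not expanded in this view. To make the intended usage concrete, here is a hedged sketch of how an agent might be hooked up to the trainer; it assumes the module is exposed as `rurel::dqn` behind the `dqn` feature, and the `MyState`, `MyAction` and `MyAgent` types, the reward shape, and all hyperparameters are purely illustrative rather than the contents of the actual example file.

// Hypothetical wiring of an agent to DQNAgentTrainer (illustrative only).
use rurel::dqn::DQNAgentTrainer;
use rurel::mdp::{Agent, State};
use rurel::strategy::explore::RandomExploration;
use rurel::strategy::terminate::FixedIterations;

#[derive(PartialEq, Eq, Hash, Clone)]
struct MyState { x: i32, y: i32 }

#[derive(PartialEq, Eq, Hash, Clone, Debug)]
struct MyAction { dx: i32, dy: i32 }

impl State for MyState {
    type A = MyAction;
    // Reward: negative Euclidean distance to the goal at (10, 10).
    fn reward(&self) -> f64 {
        let (dx, dy) = ((10 - self.x) as f64, (10 - self.y) as f64);
        -(dx * dx + dy * dy).sqrt()
    }
    fn actions(&self) -> Vec<MyAction> {
        vec![
            MyAction { dx: 0, dy: -1 },
            MyAction { dx: 0, dy: 1 },
            MyAction { dx: -1, dy: 0 },
            MyAction { dx: 1, dy: 0 },
        ]
    }
}

// The DQN trainer additionally needs dense f32 encodings of states and actions.
impl From<MyState> for [f32; 2] {
    fn from(s: MyState) -> Self { [s.x as f32, s.y as f32] }
}
impl From<MyAction> for [f32; 4] {
    // One-hot encoding of the four moves.
    fn from(a: MyAction) -> Self {
        match (a.dx, a.dy) {
            (0, -1) => [1.0, 0.0, 0.0, 0.0],
            (0, 1) => [0.0, 1.0, 0.0, 0.0],
            (-1, 0) => [0.0, 0.0, 1.0, 0.0],
            _ => [0.0, 0.0, 0.0, 1.0],
        }
    }
}
impl From<[f32; 4]> for MyAction {
    // Decode a Q-value vector by taking the argmax.
    fn from(q: [f32; 4]) -> Self {
        let best = q
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .unwrap_or(0);
        match best {
            0 => MyAction { dx: 0, dy: -1 },
            1 => MyAction { dx: 0, dy: 1 },
            2 => MyAction { dx: -1, dy: 0 },
            _ => MyAction { dx: 1, dy: 0 },
        }
    }
}

struct MyAgent { state: MyState }

impl Agent<MyState> for MyAgent {
    fn current_state(&self) -> &MyState { &self.state }
    fn take_action(&mut self, action: &MyAction) {
        self.state = MyState { x: self.state.x + action.dx, y: self.state.y + action.dy };
    }
}

fn main() {
    // 2-dimensional state, 4 actions, 16 hidden units per layer.
    let mut trainer = DQNAgentTrainer::<MyState, 2, 4, 16>::new(0.99, 1e-3);
    let mut agent = MyAgent { state: MyState { x: 0, y: 0 } };
    trainer.train(
        &mut agent,
        &mut FixedIterations::new(10_000),
        &RandomExploration::new(),
    );
    // Query the learned Q-values and the greedy action for a state.
    println!("{:?}", trainer.expected_value(&MyState { x: 5, y: 5 }));
    println!("{:?}", trainer.best_action(&MyState { x: 5, y: 5 }));
}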
