From 675626acad6fce6842ac66fb09c21addacf67ec7 Mon Sep 17 00:00:00 2001
From: Christian M
Date: Sun, 9 Jul 2023 07:46:56 +0200
Subject: [PATCH] Added DQN trainer and DQN-based example for eucdist

---
 Cargo.toml                  |  13 +-
 src/dqn.rs                  | 254 ++++++++++++++++++++++++++++++
 src/examples/eucdist_dqn.rs | 210 +++++++++++++++++++++++++++++
 src/lib.rs                  |   2 +
 4 files changed, 477 insertions(+), 2 deletions(-)
 create mode 100644 src/dqn.rs
 create mode 100644 src/examples/eucdist_dqn.rs

diff --git a/Cargo.toml b/Cargo.toml
index 5be7df4..a337d7c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,7 +7,7 @@ documentation = "https://docs.rs/rurel"
 homepage = "https://github.com/milanboers/rurel"
 repository = "https://github.com/milanboers/rurel"
 readme = "README.md"
-keywords = ["reinforcement", "q", "learning"]
+keywords = ["reinforcement", "q", "learning", "dqn"]
 categories = ["science", "algorithms"]
 license = "MPL-2.0"
 edition = "2021"
@@ -15,8 +15,13 @@ edition = "2021"
 [badges]
 travis-ci = { repository = "milanboers/rurel", branch = "master" }
 
+[features]
+default = []
+dqn = ["dfdx"]
+
 [dependencies]
 rand = "0.8"
+dfdx = { version = "0.11.2", optional = true }
 
 [[example]]
 name = "eucdist"
@@ -25,4 +30,8 @@ path = "src/examples/eucdist.rs"
 
 [[example]]
 name = "weightedcoin"
-path = "src/examples/weightedcoin.rs"
\ No newline at end of file
+path = "src/examples/weightedcoin.rs"
+
+[[example]]
+name = "eucdist_dqn"
+path = "src/examples/eucdist_dqn.rs"
diff --git a/src/dqn.rs b/src/dqn.rs
new file mode 100644
index 0000000..931a923
--- /dev/null
+++ b/src/dqn.rs
@@ -0,0 +1,254 @@
+// source: https://raw.githubusercontent.com/coreylowman/dfdx/main/examples/rl-dqn.rs
+use dfdx::{
+    nn,
+    optim::{Momentum, Sgd, SgdConfig},
+    prelude::*,
+};
+
+use crate::{
+    mdp::{Agent, State},
+    strategy::{explore::ExplorationStrategy, terminate::TerminationStrategy},
+};
+
+const BATCH: usize = 64;
+
+type QNetwork<const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> = (
+    (Linear<STATE_SIZE, INNER_SIZE>, ReLU),
+    (Linear<INNER_SIZE, INNER_SIZE>, ReLU),
+    Linear<INNER_SIZE, ACTION_SIZE>,
+);
+
+type QNetworkDevice<const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> = (
+    (nn::modules::Linear<STATE_SIZE, INNER_SIZE, f32, Cpu>, ReLU),
+    (nn::modules::Linear<INNER_SIZE, INNER_SIZE, f32, Cpu>, ReLU),
+    nn::modules::Linear<INNER_SIZE, ACTION_SIZE, f32, Cpu>,
+);
+
+/// A `DQNAgentTrainer` can be trained using a certain [Agent](mdp/trait.Agent.html). After
+/// training, the `DQNAgentTrainer` contains learned knowledge about the process and can be
+/// queried for it. For example, you can ask the `DQNAgentTrainer` for the expected values of
+/// all possible actions in a given state.
+///
+/// The code is partially taken from https://github.com/coreylowman/dfdx/blob/main/examples/rl-dqn.rs.
+///
+pub struct DQNAgentTrainer<
+    S,
+    const STATE_SIZE: usize,
+    const ACTION_SIZE: usize,
+    const INNER_SIZE: usize,
+> where
+    S: State + Into<[f32; STATE_SIZE]>,
+    S::A: Into<[f32; ACTION_SIZE]>,
+    S::A: From<[f32; ACTION_SIZE]>,
+{
+    // discount factor for future rewards
+    gamma: f32,
+    q_network: QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>,
+    target_q_net: QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>,
+    sgd: Sgd<QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>, f32, Cpu>,
+    dev: Cpu,
+    phantom: std::marker::PhantomData<S>,
+}
+
+impl<S, const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize>
+    DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE>
+where
+    S: State + Into<[f32; STATE_SIZE]>,
+    S::A: Into<[f32; ACTION_SIZE]>,
+    S::A: From<[f32; ACTION_SIZE]>,
+{
+    /// Creates a new `DQNAgentTrainer` with the given parameters.
+    ///
+    /// # Arguments
+    ///
+    /// * `gamma` - The discount factor for future rewards.
+    /// * `learning_rate` - The learning rate for the optimizer.
+    ///
+    /// # Returns
+    ///
+    /// A new `DQNAgentTrainer` with the given parameters.
+    ///
+    pub fn new(
+        gamma: f32,
+        learning_rate: f32,
+    ) -> DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
+        let dev = AutoDevice::default();
+
+        // initialize model
+        let q_net = dev.build_module::<QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE>, f32>();
+        let target_q_net = q_net.clone();
+
+        // initialize optimizer
+        let sgd = Sgd::new(
+            &q_net,
+            SgdConfig {
+                lr: learning_rate,
+                momentum: Some(Momentum::Nesterov(0.9)),
+                weight_decay: None,
+            },
+        );
+
+        DQNAgentTrainer {
+            gamma,
+            q_network: q_net,
+            target_q_net,
+            sgd,
+            dev,
+            phantom: std::marker::PhantomData,
+        }
+    }
+
+    /// Returns the learned Q-values for all possible actions in the given `State`.
+    pub fn expected_value(&self, state: &S) -> [f32; ACTION_SIZE] {
+        let state_: [f32; STATE_SIZE] = (state.clone()).into();
+        let states: Tensor<Rank1<STATE_SIZE>, f32, _> =
+            self.dev.tensor(state_).normalize::<Axis<0>>(0.001);
+        let actions = self.target_q_net.forward(states).nans_to(0f32);
+        actions.array()
+    }
+
+    /// Returns a clone of the entire learned state to be saved or used elsewhere.
+    pub fn export_learned_values(&self) -> QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
+        self.learned_values().clone()
+    }
+
+    /// Returns a reference to the learned state.
+    pub fn learned_values(&self) -> &QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
+        &self.q_network
+    }
+
+    /// Imports a model, completely replacing any learned progress
+    pub fn import_model(&mut self, model: QNetworkDevice<STATE_SIZE, ACTION_SIZE, INNER_SIZE>) {
+        self.q_network.clone_from(&model);
+        self.target_q_net.clone_from(&self.q_network);
+    }
+
+    /// Returns the action with the highest expected value for the given `State`.
+    pub fn best_action(&self, state: &S) -> Option<S::A> {
+        let target = self.expected_value(state);
+
+        Some(target.into())
+    }
+
+    /// Runs a few SGD steps on one batch of transitions, fitting the Q-network to the
+    /// Bellman targets computed from the target network.
+    pub fn train_dqn(
+        &mut self,
+        states: [[f32; STATE_SIZE]; BATCH],
+        actions: [[f32; ACTION_SIZE]; BATCH],
+        next_states: [[f32; STATE_SIZE]; BATCH],
+        rewards: [f32; BATCH],
+        dones: [bool; BATCH],
+    ) {
+        self.target_q_net.clone_from(&self.q_network);
+        let mut grads = self.q_network.alloc_grads();
+
+        let dones: Tensor<Rank1<BATCH>, f32, _> =
+            self.dev.tensor(dones.map(|d| if d { 1f32 } else { 0f32 }));
+        let rewards = self.dev.tensor(rewards);
+
+        // Convert to tensors and normalize the states for better training
+        let states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
+            self.dev.tensor(states).normalize::<Axis<1>>(0.001);
+
+        // Convert actions to tensors and get the index of the max action for each batch entry
+        let actions: Tensor<Rank1<BATCH>, usize, _> = self.dev.tensor(actions.map(|a| {
+            let mut max_idx = 0;
+            let mut max_val = 0f32;
+            for (i, v) in a.iter().enumerate() {
+                if *v > max_val {
+                    max_val = *v;
+                    max_idx = i;
+                }
+            }
+            max_idx
+        }));
+
+        // Convert to tensors and normalize the next states for better training
+        let next_states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
+            self.dev.tensor(next_states).normalize::<Axis<1>>(0.001);
+
+        // Fit the Q-network to the Bellman targets for this batch
+        for _step in 0..20 {
+            let q_values = self.q_network.forward(states.trace(grads));
+
+            let action_qs = q_values.select(actions.clone());
+
+            // targ_q = R + discount * max(Q(S'))
+            // curr_q = Q(S)[A]
+            // loss = huber(curr_q, targ_q, 1)
+            let next_q_values = self.target_q_net.forward(next_states.clone());
+            let max_next_q = next_q_values.max::<Rank1<BATCH>, _>();
+            let target_q = (max_next_q * (-dones.clone() + 1.0)) * self.gamma + rewards.clone();
+
+            let loss = huber_loss(action_qs, target_q, 1.0);
+
+            grads = loss.backward();
+
+            // update weights with optimizer
+            self.sgd
+                .update(&mut self.q_network, &grads)
+                .expect("Unused params");
+            self.q_network.zero_grads(&mut grads);
+        }
+        self.target_q_net.clone_from(&self.q_network);
+    }
+
+    /// Trains this [DQNAgentTrainer] using the given [ExplorationStrategy] and
+    /// [Agent] until the [TerminationStrategy] decides to stop.
+    pub fn train(
+        &mut self,
+        agent: &mut dyn Agent<S>,
+        termination_strategy: &mut dyn TerminationStrategy<S>,
+        exploration_strategy: &dyn ExplorationStrategy<S>,
+    ) {
+        loop {
+            // Initialize batch
+            let mut states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
+            let mut actions: [[f32; ACTION_SIZE]; BATCH] = [[0.0; ACTION_SIZE]; BATCH];
+            let mut next_states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
+            let mut rewards: [f32; BATCH] = [0.0; BATCH];
+            let mut dones = [false; BATCH];
+
+            let mut s_t_next = agent.current_state();
+
+            for i in 0..BATCH {
+                let s_t = agent.current_state().clone();
+                let action = exploration_strategy.pick_action(agent);
+
+                // current action value
+                s_t_next = agent.current_state();
+                let r_t_next = s_t_next.reward();
+
+                states[i] = s_t.into();
+                actions[i] = action.into();
+                next_states[i] = (*s_t_next).clone().into();
+                rewards[i] = r_t_next as f32;
+
+                if termination_strategy.should_stop(s_t_next) {
+                    dones[i] = true;
+                    break;
+                }
+            }
+
+            // train the network
+            self.train_dqn(states, actions, next_states, rewards, dones);
+
+            // terminate if the agent is done
+            if termination_strategy.should_stop(s_t_next) {
+                break;
+            }
+        }
+    }
+}
+
+impl<S, const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> Default
+    for DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE>
+where
+    S: State + Into<[f32; STATE_SIZE]>,
+    S::A: Into<[f32; ACTION_SIZE]>,
+    S::A: From<[f32; ACTION_SIZE]>,
+{
+    fn default() -> Self {
+        Self::new(0.99, 1e-3)
+    }
+}
diff --git a/src/examples/eucdist_dqn.rs b/src/examples/eucdist_dqn.rs
new file mode 100644
index 0000000..ec25211
--- /dev/null
+++ b/src/examples/eucdist_dqn.rs
@@ -0,0 +1,210 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#[cfg(feature = "dqn")]
+use rurel::dqn::DQNAgentTrainer;
+use rurel::mdp::{Agent, State};
+
+/// A simple 2D grid world where the agent can move around.
+/// The agent has to reach (10, 10).
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+struct MyState {
+    tx: i32,
+    ty: i32,
+    x: i32,
+    y: i32,
+    maxx: i32,
+    maxy: i32,
+}
+
+// Into float array has to be implemented for the DQN state
+impl Into<[f32; 6]> for MyState {
+    fn into(self) -> [f32; 6] {
+        [
+            self.tx as f32,
+            self.ty as f32,
+            self.x as f32,
+            self.y as f32,
+            self.maxx as f32,
+            self.maxy as f32,
+        ]
+    }
+}
+
+// From float array has to be implemented for the DQN state
+impl From<[f32; 6]> for MyState {
+    fn from(v: [f32; 6]) -> Self {
+        MyState {
+            tx: v[0] as i32,
+            ty: v[1] as i32,
+            x: v[2] as i32,
+            y: v[3] as i32,
+            maxx: v[4] as i32,
+            maxy: v[5] as i32,
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+enum MyAction {
+    Move { dx: i32, dy: i32 },
+}
+
+// Into float array has to be implemented for the action,
+// so that the DQN can use it.
+impl Into<[f32; 4]> for MyAction {
+    fn into(self) -> [f32; 4] {
+        match self {
+            MyAction::Move { dx: -1, dy: 0 } => [1.0, 0.0, 0.0, 0.0],
+            MyAction::Move { dx: 1, dy: 0 } => [0.0, 1.0, 0.0, 0.0],
+            MyAction::Move { dx: 0, dy: -1 } => [0.0, 0.0, 1.0, 0.0],
+            MyAction::Move { dx: 0, dy: 1 } => [0.0, 0.0, 0.0, 1.0],
+            _ => panic!("Invalid action"),
+        }
+    }
+}
+
+// From float array has to be implemented for the action,
+// because the output of the DQN is a float array like [0.1, 0.2, 0.1, 0.1]
+impl From<[f32; 4]> for MyAction {
+    fn from(v: [f32; 4]) -> Self {
+        // Find the index of the maximum value
+        let max_index = v
+            .iter()
+            .enumerate()
+            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
+            .unwrap()
+            .0;
+
+        match max_index {
+            0 => MyAction::Move { dx: -1, dy: 0 },
+            1 => MyAction::Move { dx: 1, dy: 0 },
+            2 => MyAction::Move { dx: 0, dy: -1 },
+            3 => MyAction::Move { dx: 0, dy: 1 },
+            _ => panic!("Invalid action index"),
+        }
+    }
+}
+
+impl State for MyState {
+    type A = MyAction;
+
+    // The reward is the negative distance to the target
+    fn reward(&self) -> f64 {
+        let (tx, ty) = (self.tx, self.ty);
+        let d = (((tx - self.x).pow(2) + (ty - self.y).pow(2)) as f64).sqrt();
+        -d
+    }
+
+    fn actions(&self) -> Vec<MyAction> {
+        vec![
+            MyAction::Move { dx: -1, dy: 0 },
+            MyAction::Move { dx: 1, dy: 0 },
+            MyAction::Move { dx: 0, dy: -1 },
+            MyAction::Move { dx: 0, dy: 1 },
+        ]
+    }
+}
+
+struct MyAgent {
+    state: MyState,
+}
+
+impl Agent<MyState> for MyAgent {
+    fn current_state(&self) -> &MyState {
+        &self.state
+    }
+
+    fn take_action(&mut self, action: &MyAction) {
+        match action {
+            &MyAction::Move { dx, dy } => {
+                self.state = MyState {
+                    x: (((self.state.x + dx) % self.state.maxx) + self.state.maxx)
+                        % self.state.maxx,
+                    y: (((self.state.y + dy) % self.state.maxy) + self.state.maxy)
+                        % self.state.maxy,
+                    ..self.state.clone()
+                };
+            }
+        }
+    }
+}
+
+#[cfg(feature = "dqn")]
+fn main() {
+    use rurel::strategy::explore::RandomExploration;
+    use rurel::strategy::terminate::FixedIterations;
+    let (tx, ty) = (10, 10);
+    let (maxx, maxy) = (21, 21);
+    let initial_state = MyState {
+        tx,
+        ty,
+        x: 0,
+        y: 0,
+        maxx,
+        maxy,
+    };
+
+    // State is encoded as 6 floats, actions as 4 floats; 64 is the hidden layer size.
+    let mut trainer = DQNAgentTrainer::<MyState, 6, 4, 64>::new(0.9, 1e-3);
+    let mut agent = MyAgent {
+        state: initial_state.clone(),
+    };
+    trainer.train(
+        &mut agent,
+        &mut FixedIterations::new(10_000),
+        &RandomExploration::new(),
+    );
+    for j in 0..maxy {
+        for i in 0..maxx {
+            let best_action = trainer
+                .best_action(&MyState {
+                    tx,
+                    ty,
+                    x: i,
+                    y: j,
+                    maxx,
+                    maxy,
+                })
+                .unwrap();
+            match best_action {
+                MyAction::Move { dx: -1, dy: 0 } => print!("<"),
+                MyAction::Move { dx: 1, dy: 0 } => print!(">"),
+                MyAction::Move { dx: 0, dy: -1 } => print!("^"),
+                MyAction::Move { dx: 0, dy: 1 } => print!("v"),
+                _ => print!("-"),
+            };
+        }
+        println!();
+    }
+
+    /*
+    Expected output, pointing towards the target at (10, 10):
+
+    >>>>>vvvvvvvvvv<<<<<<
+    >>>>>vvvvvvvvvv<<<<<<
+    >>>>>vvvvvvvvv<<<<<<<
+    >>>>>>vvvvvvvv<<<<<<<
+    >>>>>>vvvvvvvv<<<<<<<
+    >>>>>>>vvvvvv<<<<<<<<
+    >>>>>>>vvvvvv<<<<<<<<
+    >>>>>>>>vvvv<<<<<<<<<
+    >>>>>>>>>vv<<<<<<<<<<
+    >>>>>>>>>v<<<<<<<<<<<
+    >>>>>>>>>^^<<<<<<<<<<
+    >>>>>>>>^^^^<<<<<<<<<
+    >>>>>>>^^^^^^<<<<<<<<
+    >>>>>^^^^^^^^^<<<<<<<
+    >>>^^^^^^^^^^^^<<<<<<
+    >^^^^^^^^^^^^^^<<<<<<
+    ^^^^^^^^^^^^^^^^<<<<<
+    ^^^^^^^^^^^^^^^^^<<<<
+    ^^^^^^^^^^^^^^^^^^^<<
+    ^^^^^^^^^^^^^^^^^^^^<
+    ^^^^^^^^^^^^^^^^^^^^^
+    */
+}
+
+#[cfg(not(feature = "dqn"))]
+fn main() {
+    panic!("Use the 'dqn' feature to run this example");
+}
diff --git a/src/lib.rs b/src/lib.rs
index 10ab266..11b876f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -89,6 +89,8 @@ use strategy::explore::ExplorationStrategy;
 use strategy::learn::LearningStrategy;
 use strategy::terminate::TerminationStrategy;
 
+#[cfg(feature = "dqn")]
+pub mod dqn;
 pub mod mdp;
 pub mod strategy;
 
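
A minimal usage sketch (illustrative only, not included in the patch) of the new export/import API from src/dqn.rs: train one trainer, export its learned Q-network, and seed a fresh trainer with it. `MyState` and the type parameters `<MyState, 6, 4, 64>` are borrowed from the eucdist_dqn example above, where 64 is just the chosen hidden-layer size.

use rurel::dqn::DQNAgentTrainer;

// `MyState` is the state type defined in src/examples/eucdist_dqn.rs.
fn clone_trained_policy(
    trained: &DQNAgentTrainer<MyState, 6, 4, 64>,
) -> DQNAgentTrainer<MyState, 6, 4, 64> {
    // Owned clone of the learned Q-network (CPU weights).
    let model = trained.export_learned_values();

    // Fresh trainer with the default hyperparameters (gamma = 0.99, learning rate = 1e-3).
    let mut fresh = DQNAgentTrainer::<MyState, 6, 4, 64>::default();

    // Replace the fresh trainer's Q-network and target network with the exported weights.
    fresh.import_model(model);
    fresh
}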