Skip to content

Commit

Permalink
Update DQN to heap-initialize state arrays
Browse files: browse the repository at this point in the history
  • Loading branch information
Lucky4Luuk authored Jan 17, 2024
1 parent 5bbaeb6 commit 3dc9ab5
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ where

pub fn train_dqn(
&mut self,
states: [[f32; STATE_SIZE]; BATCH],
states: Box<[[f32; STATE_SIZE]; BATCH]>,
actions: [[f32; ACTION_SIZE]; BATCH],
next_states: [[f32; STATE_SIZE]; BATCH],
next_states: Box<[[f32; STATE_SIZE]; BATCH]>,
rewards: [f32; BATCH],
dones: [bool; BATCH],
) {
Expand All @@ -148,7 +148,7 @@ where

// Convert to tensors and normalize the states for better training
let states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
self.dev.tensor(states).normalize::<Axis<1>>(0.001);
self.dev.tensor(*states).normalize::<Axis<1>>(0.001);

// Convert actions to tensors and get the max action for each batch
let actions: Tensor<Rank1<BATCH>, usize, _> = self.dev.tensor(actions.map(|a| {
Expand All @@ -165,7 +165,7 @@ where

// Convert to tensors and normalize the states for better training
let next_states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
self.dev.tensor(next_states).normalize::<Axis<1>>(0.001);
self.dev.tensor(*next_states).normalize::<Axis<1>>(0.001);

// Compute the estimated Q-value for the action
for _step in 0..20 {
Expand Down Expand Up @@ -203,9 +203,21 @@ where
) {
loop {
// Initialize batch
let mut states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
let mut states: Box<[[f32; STATE_SIZE]; BATCH]> = {
let b = vec![0.0; STATE_SIZE].into_boxed_slice();
let big = unsafe { Box::from_raw(Box::into_raw(b) as *mut [f32; STATE_SIZE]) };

let b = vec![*big; BATCH].into_boxed_slice();
unsafe { Box::from_raw(Box::into_raw(b) as *mut [[f32; STATE_SIZE]; BATCH]) }
};
let mut actions: [[f32; ACTION_SIZE]; BATCH] = [[0.0; ACTION_SIZE]; BATCH];
let mut next_states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
let mut next_states: Box<[[f32; STATE_SIZE]; BATCH]> = {
let b = vec![0.0; STATE_SIZE].into_boxed_slice();
let big = unsafe { Box::from_raw(Box::into_raw(b) as *mut [f32; STATE_SIZE]) };

let b = vec![*big; BATCH].into_boxed_slice();
unsafe { Box::from_raw(Box::into_raw(b) as *mut [[f32; STATE_SIZE]; BATCH]) }
};
let mut rewards: [f32; BATCH] = [0.0; BATCH];
let mut dones = [false; BATCH];

Expand Down

0 comments on commit 3dc9ab5

Please sign in to comment.