diff --git a/src/dqn.rs b/src/dqn.rs index d668cf0..4a0b85d 100644 --- a/src/dqn.rs +++ b/src/dqn.rs @@ -133,9 +133,9 @@ where pub fn train_dqn( &mut self, - states: [[f32; STATE_SIZE]; BATCH], + states: Box<[[f32; STATE_SIZE]; BATCH]>, actions: [[f32; ACTION_SIZE]; BATCH], - next_states: [[f32; STATE_SIZE]; BATCH], + next_states: Box<[[f32; STATE_SIZE]; BATCH]>, rewards: [f32; BATCH], dones: [bool; BATCH], ) { @@ -148,7 +148,7 @@ where // Convert to tensors and normalize the states for better training let states: Tensor, f32, _> = - self.dev.tensor(states).normalize::>(0.001); + self.dev.tensor(*states).normalize::>(0.001); // Convert actions to tensors and get the max action for each batch let actions: Tensor, usize, _> = self.dev.tensor(actions.map(|a| { @@ -165,7 +165,7 @@ where // Convert to tensors and normalize the states for better training let next_states: Tensor, f32, _> = - self.dev.tensor(next_states).normalize::>(0.001); + self.dev.tensor(*next_states).normalize::>(0.001); // Compute the estimated Q-value for the action for _step in 0..20 { @@ -203,9 +203,21 @@ where ) { loop { // Initialize batch - let mut states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH]; + let mut states: Box<[[f32; STATE_SIZE]; BATCH]> = { + let b = vec![0.0; STATE_SIZE].into_boxed_slice(); + let big = unsafe { Box::from_raw(Box::into_raw(b) as *mut [f32; STATE_SIZE]) }; + + let b = vec![*big; BATCH].into_boxed_slice(); + unsafe { Box::from_raw(Box::into_raw(b) as *mut [[f32; STATE_SIZE]; BATCH]) } + }; let mut actions: [[f32; ACTION_SIZE]; BATCH] = [[0.0; ACTION_SIZE]; BATCH]; - let mut next_states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH]; + let mut next_states: Box<[[f32; STATE_SIZE]; BATCH]> = { + let b = vec![0.0; STATE_SIZE].into_boxed_slice(); + let big = unsafe { Box::from_raw(Box::into_raw(b) as *mut [f32; STATE_SIZE]) }; + + let b = vec![*big; BATCH].into_boxed_slice(); + unsafe { Box::from_raw(Box::into_raw(b) as *mut [[f32; STATE_SIZE]; BATCH]) } + }; let mut rewards: [f32; BATCH] = [0.0; BATCH]; let mut dones = [false; BATCH];