Skip to content

Commit

Permalink
Fix some compilation, clippy, fmt issues
Browse files Browse the repository at this point in the history
  • Loading branch information
milanboers committed Nov 30, 2023
1 parent 2a84249 commit 931c01d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ type QNetworkDevice<const STATE_SIZE: usize, const ACTION_SIZE: usize, const INN
/// training, the `DQNAgentTrainer` contains learned knowledge about the process, and can be queried
/// for this. For example, you can ask the `DQNAgentTrainer` the expected values of all possible
/// actions in a given state.
///
///
/// The code is partially taken from https://github.com/coreylowman/dfdx/blob/main/examples/rl-dqn.rs.
///
///
pub struct DQNAgentTrainer<
S,
const STATE_SIZE: usize,
Expand Down
26 changes: 13 additions & 13 deletions src/examples/eucdist_dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@ struct MyState {
}

// Into float array has to be implemented for the DQN state
impl Into<[f32; 6]> for MyState {
fn into(self) -> [f32; 6] {
impl From<MyState> for [f32; 6] {
fn from(val: MyState) -> Self {
[
self.tx as f32,
self.ty as f32,
self.x as f32,
self.y as f32,
self.maxx as f32,
self.maxy as f32,
val.tx as f32,
val.ty as f32,
val.x as f32,
val.y as f32,
val.maxx as f32,
val.maxy as f32,
]
}
}

// From float array has to be implemented for the DQN state
impl From<[f32; 4]> for MyState {
fn from(v: [f32; 4]) -> Self {
impl From<[f32; 6]> for MyState {
fn from(v: [f32; 6]) -> Self {
MyState {
tx: v[0] as i32,
ty: v[1] as i32,
Expand All @@ -54,9 +54,9 @@ enum MyAction {

// Into float array has to be implemented for the action,
// so that the DQN can use it.
impl Into<[f32; 4]> for MyAction {
fn into(self) -> [f32; 4] {
match self {
impl From<MyAction> for [f32; 4] {
fn from(val: MyAction) -> Self {
match val {
MyAction::Move { dx: -1, dy: 0 } => [1.0, 0.0, 0.0, 0.0],
MyAction::Move { dx: 1, dy: 0 } => [0.0, 1.0, 0.0, 0.0],
MyAction::Move { dx: 0, dy: -1 } => [0.0, 0.0, 1.0, 0.0],
Expand Down
5 changes: 1 addition & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,7 @@ where
learning_strategy.value(&self.q.get(s_t_next), &old_value, r_t_next)
};

self.q
.entry(s_t)
.or_insert_with(HashMap::new)
.insert(action, v);
self.q.entry(s_t).or_default().insert(action, v);

if termination_strategy.should_stop(s_t_next) {
break;
Expand Down

0 comments on commit 931c01d

Please sign in to comment.