From 211efcdaeba8770a5b6e4b3c7b3bbe7a18127068 Mon Sep 17 00:00:00 2001 From: Milan Boers Date: Sun, 15 Jan 2023 15:29:21 +0100 Subject: [PATCH] Fmt, clippy weightedcoin example --- src/examples/weightedcoin.rs | 105 +++++++++++++++----------- src/strategy/terminate/sink_states.rs | 1 + 2 files changed, 62 insertions(+), 44 deletions(-) diff --git a/src/examples/weightedcoin.rs b/src/examples/weightedcoin.rs index 5c0f7ba..ab82d2c 100644 --- a/src/examples/weightedcoin.rs +++ b/src/examples/weightedcoin.rs @@ -2,79 +2,96 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ - use rurel::mdp::{Agent, State}; - use rurel::strategy::explore::RandomExploration; - use rurel::strategy::learn::QLearning; - use rurel::strategy::terminate::SinkStates; - use rurel::AgentTrainer; +use rurel::mdp::{Agent, State}; +use rurel::strategy::explore::RandomExploration; +use rurel::strategy::learn::QLearning; +use rurel::strategy::terminate::SinkStates; +use rurel::AgentTrainer; -const TARGET : i32 = 100; -const WEIGHT : u8 = 100; //portion of 255 - - #[derive(PartialEq, Eq, Hash, Clone)] - struct CoinState{ - balance : i32 - } +const TARGET: i32 = 100; +const WEIGHT: u8 = 100; //portion of 255 + +#[derive(PartialEq, Eq, Hash, Clone)] +struct CoinState { + balance: i32, +} + +#[derive(PartialEq, Eq, Hash, Clone)] +struct CoinAction { + bet: i32, +} - #[derive(PartialEq, Eq, Hash, Clone)] - struct CoinAction{ - bet : i32 - } - impl State for CoinState { type A = CoinAction; fn reward(&self) -> f64 { - if self.balance>=TARGET {1.0} - else {0.0} + if self.balance >= TARGET { + 1.0 + } else { + 0.0 + } } fn actions(&self) -> Vec { let bet_range = { - if self.balance for CoinAgent{ +struct CoinAgent { + state: CoinState, +} + +impl Agent for CoinAgent { fn current_state(&self) -> &CoinState { - &self.state - } - fn take_action(&mut self, action: &CoinAction) -> () { + &self.state + } + fn take_action(&mut self, action: &CoinAction) { //Update the state to: - self.state = CoinState { balance : - if rand::random::() <= WEIGHT {self.state.balance+action.bet} + self.state = CoinState { + balance: if rand::random::() <= WEIGHT { + self.state.balance + action.bet + } //If the coin is heads, balance + bet - else {self.state.balance-action.bet} - //If the coin is tails, balance - bet + else { + self.state.balance - action.bet + }, //If the coin is tails, balance - bet } - } + } } fn main() { - const TRIALS:usize=1000000; - let mut trainer=AgentTrainer::new(); - for trial in 0..TRIALS{ - let mut agent = CoinAgent {state: CoinState{balance:((1+trial%98) as i32)}}; + const TRIALS: i32 = 100000; + let mut trainer = AgentTrainer::new(); + for trial in 0..TRIALS { + let mut agent = CoinAgent { + state: CoinState { + balance: 1 + trial % 98, + }, + }; trainer.train( &mut agent, - &QLearning::new(0.2,1.0,0.0), - &mut SinkStates{}, - &RandomExploration::new() + &QLearning::new(0.2, 1.0, 0.0), + &mut SinkStates {}, + &RandomExploration::new(), ); } println!("Balance\tBet\tQ-value"); - for balance in 1..TARGET{ - let state = CoinState{balance:balance}; + for balance in 1..TARGET { + let state = CoinState { balance }; let action = trainer.best_action(&state).unwrap(); - println!("{}\t{}\t{}", + println!( + "{}\t{}\t{}", balance, action.bet, - trainer.expected_value(&state,&action).unwrap() + trainer.expected_value(&state, &action).unwrap(), ); } -} \ No newline at end of file +} diff --git a/src/strategy/terminate/sink_states.rs b/src/strategy/terminate/sink_states.rs index 519446d..53434b9 100644 --- a/src/strategy/terminate/sink_states.rs +++ b/src/strategy/terminate/sink_states.rs @@ -9,6 +9,7 @@ use crate::strategy::terminate::TerminationStrategy; /// The termination strategy that ends if it's at a terminal state (no actions) pub struct SinkStates {} + impl TerminationStrategy for SinkStates { fn should_stop(&mut self, state: &S) -> bool { state.actions().is_empty()