diff --git a/Cargo.lock b/Cargo.lock index 366385b2..f1be37f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,6 +358,25 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "little-sorry" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c956d5f2f5d4fd63a3f5e741fc3997c9e3538a0201af5c0aed935198396654d" +dependencies = [ + "ndarray", + "once_cell", + "rand", + "rand_distr", + "thiserror", +] + [[package]] name = "log" version = "0.4.21" @@ -373,12 +392,35 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -389,6 +431,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + +[[package]] +name = 
"num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -396,6 +456,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -504,6 +565,22 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -576,6 +653,8 @@ dependencies = [ "arbitrary", "criterion", "env_logger", + "little-sorry", + "ndarray", "rand", "serde", "test-log", @@ -602,18 +681,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.198" +version = "1.0.199" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "0c9f6e76df036c77cd94996771fb40db98187f096dd0b9af39c6c6e452ba966a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.199" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "11bd257a6541e141e42ca6d24ae26f7714887b47e89aa739099104c7e4d3b7fc" dependencies = [ "proc-macro2", "quote", @@ -907,11 +986,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = 
"winapi-util" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" dependencies = [ - "winapi", + "windows-sys", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c6418fc7..8ae95897 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ arbitrary = { version = "1.3.2", optional = true, features = ["derive"] } tracing = { version = "0.1.40", optional = true} approx = { version = "0.5.1", optional = true} uuid = {version = "1.8.0", optional = true, features = ["v7"]} +little-sorry = { version = "0.4.0", optional = true} +ndarray = { version = "0.15.6", optional = true} [dev-dependencies] criterion = "0.5.1" @@ -32,7 +34,7 @@ approx = { version = "0.5.1"} default = ["arena", "serde"] uuid = ["dep:uuid"] serde = ["dep:serde", "uuid?/serde"] -arena = ["dep:tracing", "uuid"] +arena = ["dep:tracing", "dep:little-sorry", "dep:ndarray", "uuid"] arena-test-util = ["arena", "dep:approx"] [[bench]] diff --git a/src/arena/cfr/agent.rs b/src/arena/cfr/agent.rs new file mode 100644 index 00000000..99f1a2c4 --- /dev/null +++ b/src/arena/cfr/agent.rs @@ -0,0 +1,267 @@ +use std::{cell::RefCell, rc::Rc}; + +use ndarray::aview1; +use rand::{thread_rng, Rng}; +use uuid::Uuid; + +use crate::arena::{ + action::AgentAction, cfr::ArenaCFRHistorian, game_state::Round, Agent, GameState, Historian, + HoldemSimulationBuilder, +}; + +use super::{ + node::{NodeData, PlayerData}, + EnsureNodeType, PlayerCFRState, EXPERTS, LOWER_MULT, MAX_RAISE_EXPERTS, + MAX_RAISE_PREFLOP_EXPERTS, PREFLOP_EXPERTS, UPPER_MULT, +}; + +pub struct ArenaCFRAgent { + pub id: Uuid, + pub cfr_states: Vec>>, + pub player_idx: usize, + forced_action: Option, +} + +impl ArenaCFRAgent { + pub fn new(cfr_states: Vec>>, player_idx: usize) -> Self { + let id = Uuid::now_v7(); + + Self { + id, + cfr_states, + 
player_idx, + forced_action: None, + } + } + + pub fn new_with_forced_action( + cfr_states: Vec>>, + player_idx: usize, + forced_action: AgentAction, + ) -> Self { + let id = Uuid::now_v7(); + + Self { + id, + cfr_states, + player_idx, + forced_action: Some(forced_action), + } + } + + fn state(&self) -> &Rc> { + &self.cfr_states[self.player_idx] + } + + fn mut_state(&mut self) -> &mut Rc> { + &mut self.cfr_states[self.player_idx] + } + + fn expert_action(&self, game_state: &GameState, expert: usize) -> AgentAction { + match game_state.round { + Round::Preflop => self.preflop_action(game_state, expert), + _ => self.postflop_action(game_state, expert), + } + } + + fn preflop_action(&self, game_state: &GameState, expert: usize) -> AgentAction { + let current_round_bet = game_state.current_round_bet(); + match expert { + 0 => AgentAction::Fold, + 1 => AgentAction::Bet(current_round_bet), + 2 => AgentAction::Bet(current_round_bet + game_state.current_round_min_raise()), + 3 => AgentAction::Bet(self.pot_bet(game_state, 0.5)), + 4 => AgentAction::Bet(self.pot_bet(game_state, 1.0)), + 5 => AgentAction::Bet(current_round_bet + game_state.stacks[self.player_idx]), + _ => panic!("Un-expected expert"), + } + } + + fn postflop_action(&self, game_state: &GameState, expert: usize) -> AgentAction { + let current_round_bet = game_state.current_round_bet(); + match expert { + 0 => AgentAction::Fold, + 1 => AgentAction::Bet(current_round_bet), + 2 => AgentAction::Bet(current_round_bet + game_state.current_round_min_raise()), + 3 => AgentAction::Bet(self.pot_bet(game_state, 0.5)), + 4 => AgentAction::Bet(self.pot_bet(game_state, 0.6666)), + 5 => AgentAction::Bet(self.pot_bet(game_state, 1.0)), + 6 => AgentAction::Bet(self.prepare_shove_bet(game_state)), + 7 => AgentAction::Bet(current_round_bet + game_state.stacks[self.player_idx]), + _ => panic!("Un-expected expert"), + } + } + + fn pot_bet(&self, game_state: &GameState, ratio: f32) -> f32 { + let pot = game_state.total_pot; + let 
current_round_bet = game_state.current_round_bet(); + let lower_raise = LOWER_MULT * ratio * pot; + let upper_raise = UPPER_MULT * ratio * pot; + if lower_raise >= upper_raise { + current_round_bet + } else { + let mut rng = thread_rng(); + current_round_bet + rng.gen_range(lower_raise..upper_raise) + } + } + + fn prepare_shove_bet(&self, game_state: &GameState) -> f32 { + let diff = game_state.stacks[self.player_idx] - game_state.total_pot; + + if diff <= 0.0 { + 0.0 + } else { + let current_round_bet = game_state.current_round_bet(); + let per_player_more = diff / game_state.num_active_players() as f32; + let lower_raise = LOWER_MULT * per_player_more; + let upper_raise = UPPER_MULT * per_player_more; + + if lower_raise >= upper_raise { + current_round_bet + } else { + let mut rng = thread_rng(); + current_round_bet + rng.gen_range(lower_raise..upper_raise) + } + } + } + + fn play_round(&mut self, game_state: &GameState) -> AgentAction { + // First explore all the experts suggestions + self.try_all_experts(game_state); + // Then after that which updates the regrets + // we can get the best action + let state = self.state().borrow(); + let data = &state.get_current_node().unwrap().data; + match &data { + NodeData::Player(player_data) => self.final_action(player_data, game_state), + _ => panic!("Expected a player node"), + } + } + + fn final_action(&self, player_data: &PlayerData, game_state: &GameState) -> AgentAction { + // Get the expert action + let expert = player_data.regrets.next_action(); + self.expert_action(game_state, expert) + } + + fn try_all_experts(&mut self, game_state: &GameState) { + let save_points = self + .cfr_states + .iter() + .map(|state| state.borrow().save_point()) + .collect::>(); + + let mut rewards: Vec = vec![0.0; self.num_possible_experts(game_state)]; + // For every expert try to see what the reward would be + // then update the regret matcher + for &expert in self.possible_experts(game_state) { + let inner_agents: Vec> = self + 
.cfr_states + .iter() + .enumerate() + .map(|(idx, _state)| { + // pass in all the states and the index of the current agent + // This allows each agent to run simulations with how the other + // agents would play + if self.player_idx == idx { + let action = self.expert_action(game_state, expert); + Box::new(ArenaCFRAgent::new_with_forced_action( + self.cfr_states.clone(), + idx, + action, + )) as Box + } else { + Box::new(ArenaCFRAgent::new(self.cfr_states.clone(), idx)) as Box + } + }) + .collect(); + + // The historians to watch + let inner_historians = self + .cfr_states + .iter() + .enumerate() + .map(|(idx, state)| { + Box::new(ArenaCFRHistorian::new(state.clone(), idx)) as Box + }) + .collect(); + + let mut sim = HoldemSimulationBuilder::default() + .agents(inner_agents) + .historians(inner_historians) + .game_state(game_state.clone()) + .panic_on_historian_error(true) + .build() + .unwrap(); + + sim.run(); + + rewards[expert] = sim.game_state.player_winnings[self.player_idx]; + + // Reset the trees + for (state, save_point) in self.cfr_states.iter().zip(save_points.iter()) { + state.borrow_mut().restore_save_point(*save_point); + } + } + self.update_regrets(rewards) + } + + fn update_regrets(&mut self, rewards: Vec) { + let mut state = self.mut_state().borrow_mut(); + let current_node = state.get_mut_current_node().unwrap(); + + match &mut current_node.data { + NodeData::Player(player_data) => { + player_data.regrets.update_regret(aview1(&rewards)).unwrap() + } + _ => panic!("Expected a player node"), + }; + } + + fn possible_experts(&self, game_state: &GameState) -> impl Iterator { + match game_state.round { + Round::Preflop => { + if game_state.round_data.total_raise_count > 4 { + MAX_RAISE_PREFLOP_EXPERTS.iter() + } else { + PREFLOP_EXPERTS.iter() + } + } + _ => { + if game_state.round_data.total_raise_count > 4 { + MAX_RAISE_EXPERTS.iter() + } else { + EXPERTS.iter() + } + } + } + } + + fn num_possible_experts(&self, game_state: &GameState) -> usize { + 
match game_state.round { + Round::Preflop => PREFLOP_EXPERTS.len(), + _ => EXPERTS.len(), + } + } + + fn next_action(&mut self, game_state: &GameState) -> AgentAction { + self.state() + .borrow_mut() + .ensure_current_node(EnsureNodeType::Player(self.player_idx), game_state.round) + .unwrap(); + + // If the agent has been told to explore a path then do that + // and clear the forced action + if let Some(forced_action) = self.forced_action.take() { + forced_action + } else { + self.play_round(game_state) + } + } +} + +impl Agent for ArenaCFRAgent { + fn act(&mut self, _id: &uuid::Uuid, game_state: &GameState) -> AgentAction { + self.next_action(game_state) + } +} diff --git a/src/arena/cfr/historian.rs b/src/arena/cfr/historian.rs new file mode 100644 index 00000000..b4b9cecb --- /dev/null +++ b/src/arena/cfr/historian.rs @@ -0,0 +1,245 @@ +use std::{cell::RefCell, rc::Rc}; + +use crate::arena::{ + action::{Action, AgentAction, DealStartingHandPayload, PlayedActionPayload}, + game_state::Round, + historian::{Historian, HistorianError}, + GameState, +}; + +use super::{ + node::NodeData, + state::{EnsureNodeType, PlayerCFRState}, + LOWER_MULT, UPPER_MULT, +}; + +pub struct ArenaCFRHistorian { + pub state: Rc>, + pub player_idx: usize, +} + +impl ArenaCFRHistorian { + pub fn new(state: Rc>, player_idx: usize) -> ArenaCFRHistorian { + ArenaCFRHistorian { state, player_idx } + } + + fn played_action_to_idx( + &self, + game_state: &GameState, + played_action: &PlayedActionPayload, + ) -> usize { + match game_state.round { + Round::Preflop => self.preflop_action_to_idx(played_action), + _ => self.postflop_action_to_idx(played_action), + } + } + + fn preflop_action_to_idx(&self, played_action: &PlayedActionPayload) -> usize { + // Fold is 0 + // The rest will have to be figured out + match played_action.action { + AgentAction::Fold => 0, + AgentAction::Bet(_) => self.preflop_bet_action_to_idx(played_action), + } + } + + fn postflop_action_to_idx(&self, played_action: 
&PlayedActionPayload) -> usize {
+        // Fold is 0
+        // The rest will have to be figured out
+        match played_action.action {
+            AgentAction::Fold => 0,
+            AgentAction::Bet(_) => self.postflot_bet_action_to_idx(played_action),
+        }
+    }
+
+    // Guess which expert produced this post-flop Bet and return its child
+    // index. Rules are tried top to bottom; the first match wins.
+    //
+    // 1 -> Check (final bet equals the starting bet)
+    // 2 -> Min raise (raise <= starting_min_raise * UPPER_MULT)
+    // 3 -> 1/2 pot, within [LOWER_MULT, UPPER_MULT] of the target
+    // 4 -> 2/3 pot, same tolerance
+    // 5 -> pot size, same tolerance
+    // 6 -> player stack ~= expected pot (setting up a shove)
+    // 7 -> All In (player stack is 0)
+    // 8 -> anything else
+    fn postflot_bet_action_to_idx(&self, played_action: &PlayedActionPayload) -> usize {
+        let raise_amount = played_action.raise_amount();
+
+        // For 1/2 pot
+        let min_one_half = played_action.starting_pot * 0.5 * LOWER_MULT;
+        let max_one_half = played_action.starting_pot * 0.5 * UPPER_MULT;
+
+        // for 2/3 pot
+        let min_two_thirds = played_action.starting_pot * 0.66666 * LOWER_MULT;
+        let max_two_thirds = played_action.starting_pot * 0.66666 * UPPER_MULT;
+
+        // pot
+        let min_pot = played_action.starting_pot * LOWER_MULT;
+        let max_pot = played_action.starting_pot * UPPER_MULT;
+
+        // maybe geo?
+        //
+        // What if the player is setting up for a shove next round
+        let min_per = raise_amount * LOWER_MULT;
+        // If we're setting up for a pot sized shove this is the expected pot
+        // The starting pot plus everyone still left calling
+        let expected_geo =
+            played_action.starting_pot + min_per * played_action.players_active.count() as f32;
+
+        let min_geo_expected = expected_geo * LOWER_MULT;
+        let max_geo_expected = expected_geo * UPPER_MULT;
+
+        if played_action.starting_bet == played_action.final_bet {
+            // Check
+            1
+        } else if raise_amount <= played_action.starting_min_raise * UPPER_MULT {
+            // min raise
+            2
+        } else if raise_amount >= min_one_half && raise_amount <= max_one_half {
+            // about 50% of pot raise (bounds were inverted, so this never matched)
+            3
+        } else if raise_amount >= min_two_thirds && raise_amount <= max_two_thirds {
+            4
+        } else if raise_amount >= min_pot && raise_amount <= max_pot {
+            // About pot sized raise
+            5
+        } else if played_action.player_stack >= min_geo_expected
+            && played_action.player_stack <= max_geo_expected
+        {
+            // Stack is within tolerance of the expected pot: setting up for a shove
+            6
+        } else if played_action.player_stack == 0.0 {
+            // All In
+            7
+        } else {
+            // Dunno random bet ?
+            8
+        }
+    }
+    fn preflop_bet_action_to_idx(&self, played_action: &PlayedActionPayload) -> usize {
+        let raise_amount = played_action.raise_amount();
+
+        // 2/3 pot. NOTE(review): preflop_action's expert 3 bets ~1/2 pot; confirm
+        let min_two_thirds = played_action.starting_pot * 0.66666 * LOWER_MULT;
+        let max_two_thirds = played_action.starting_pot * 0.66666 * UPPER_MULT;
+
+        // pot
+        let min_pot = played_action.starting_pot * LOWER_MULT;
+        let max_pot = played_action.starting_pot * UPPER_MULT;
+
+        if played_action.starting_bet == played_action.final_bet {
+            // Check
+            1
+        } else if raise_amount <= played_action.starting_min_raise * UPPER_MULT {
+            // min raise
+            2
+        } else if raise_amount >= min_two_thirds && raise_amount <= max_two_thirds {
+            3
+        } else if raise_amount >= min_pot && raise_amount <= max_pot {
+            // About pot sized raise
+            4
+        } else {
+            // random I guess
+            5
+        }
+    }
+
+    fn handle_terminal_node(&mut self, game_state: &GameState) {
+        // Utility per player is winnings minus what that player put in
+        let utility: Vec<f32> = game_state
+            .player_winnings
+            .iter()
+            .zip(game_state.player_bet.iter())
+            .map(|(winnings, bet)| winnings - bet)
+            .collect();
+
+        let mut state = self.state.borrow_mut();
+        // We'll store that in the terminal node; any other node kind is left untouched
+        if let Some(terminal_node) = state.get_mut_current_node() {
+            if let NodeData::Terminal(terminal_data) = &mut terminal_node.data {
+                terminal_data.utility = utility;
+            }
+        }
+    }
+
+    fn handle_played_action_payload(
+        &mut self,
+        game_state: &GameState,
+        played_action: PlayedActionPayload,
+    ) -> Result<(), HistorianError> {
+        if played_action.idx == self.player_idx {
+            self.state
+                .try_borrow_mut()?
+                .ensure_current_node(EnsureNodeType::Player(played_action.idx), game_state.round)?;
+        } else {
+            self.state
+                .try_borrow_mut()?
+                .ensure_current_node(EnsureNodeType::Action(played_action.idx), game_state.round)?;
+        }
+
+        // Use the current game state and the played action to get the index of the
+        // action in the current node's children. We will take that path
+        // next.
+ let action_idx = self.played_action_to_idx(game_state, &played_action); + self.state.try_borrow_mut()?.set_next_node(action_idx) + } +} + +impl Historian for ArenaCFRHistorian { + fn record_action( + &mut self, + _id: &uuid::Uuid, + game_state: &GameState, + action: Action, + ) -> Result<(), HistorianError> { + // Record the action in the game tree + match action { + Action::PlayedAction(played_action) => { + self.handle_played_action_payload(game_state, played_action) + } + Action::FailedAction(failed_action) => { + // An agent failed to take an appropriate action + // handle the result + self.handle_played_action_payload(game_state, failed_action.result) + } + Action::DealStartingHand(DealStartingHandPayload { card, idx, .. }) => { + // Only record the action if it's the player's action + // So we can't get information leakage + if idx == self.player_idx { + // This is a chance node + self.state + .try_borrow_mut()? + .ensure_current_node(EnsureNodeType::Chance, game_state.round)?; + self.state + .try_borrow_mut()? + .set_next_node(u8::from(card) as usize) + } else { + Ok(()) + } + } + Action::DealCommunity(card) => { + // This is chance node + self.state + .try_borrow_mut()? + .ensure_current_node(EnsureNodeType::Chance, game_state.round)?; + self.state + .try_borrow_mut()? + .set_next_node(u8::from(card) as usize)?; + Ok(()) + } + Action::RoundAdvance(Round::Complete) => { + // This is a terminal node + self.state + .try_borrow_mut()? + .ensure_current_node(EnsureNodeType::Terminal, Round::Complete)?; + + self.handle_terminal_node(game_state); + Ok(()) + } + // The rest of the actions are ignored (Partial payouts, or sitting down). 
+ Action::GameStart(_) | Action::ForcedBet(_) | Action::PlayerSit(_) => Ok(()), + Action::RoundAdvance(_) | Action::Award(_) => Ok(()), + } + } +} diff --git a/src/arena/cfr/mod.rs b/src/arena/cfr/mod.rs new file mode 100644 index 00000000..1f79e2da --- /dev/null +++ b/src/arena/cfr/mod.rs @@ -0,0 +1,80 @@ +mod agent; +mod historian; +mod node; +mod state; + +pub const PREFLOP_EXPERTS: [usize; 6] = [0, 1, 2, 3, 4, 5]; +pub const EXPERTS: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + +pub const MAX_RAISE_PREFLOP_EXPERTS: [usize; 2] = [0, 5]; +pub const MAX_RAISE_EXPERTS: [usize; 2] = [0, 7]; + +// The ranges that we consider for random bet sizes +pub const LOWER_MULT: f32 = 0.9; +pub const UPPER_MULT: f32 = 1.1; + +pub use agent::*; +pub use historian::*; +pub use state::*; + +#[cfg(test)] +mod tests { + use std::{cell::RefCell, rc::Rc}; + + use crate::arena::{historian::Historian, Agent, GameState, HoldemSimulationBuilder}; + + use super::*; + + #[test] + fn test_crf() { + let num_agents = 2; + let game_state = GameState::new(vec![100.0; num_agents], 10.0, 5.0, 0.0, 0); + // CFR states for each seat + let cfr_states: Vec<_> = (0..num_agents) + .map(|_| Rc::new(RefCell::new(PlayerCFRState::new(game_state.clone())))) + .collect(); + + let save_points: Vec = cfr_states + .iter() + .map(|state| state.borrow().save_point()) + .collect(); + + // Test a lot of simulations to show that we can add on state that overlaps + // however we don't need this to be a fuzz test + for _ in 0..10000 { + let agents: Vec> = cfr_states + .iter() + .enumerate() + .map(|(idx, _state)| { + // pass in all the states and the index of the current agent + // This allows each agent to run simulations with how the other + // agents would play + Box::new(ArenaCFRAgent::new(cfr_states.clone(), idx)) as Box + }) + .collect(); + + // The historians to watch + let historians = cfr_states + .iter() + .enumerate() + .map(|(idx, state)| { + Box::new(ArenaCFRHistorian::new(state.clone(), idx)) as Box + }) + 
.collect(); + + // Build the simulation + let mut sim = HoldemSimulationBuilder::default() + .agents(agents) + .historians(historians) + .game_state(game_state.clone()) + .build() + .unwrap(); + + sim.run(); + + for (idx, state) in cfr_states.iter().enumerate() { + state.borrow_mut().restore_save_point(save_points[idx]); + } + } + } +} diff --git a/src/arena/cfr/node.rs b/src/arena/cfr/node.rs new file mode 100644 index 00000000..fc785a24 --- /dev/null +++ b/src/arena/cfr/node.rs @@ -0,0 +1,76 @@ +use crate::arena::GameState; + +#[derive(Debug)] +pub struct RootData { + pub game_state: GameState, +} + +#[derive(Debug)] +pub struct ActionData { + /// The index of the play that this action is for + pub idx: usize, +} + +#[derive(Debug)] +pub struct PlayerData { + pub idx: usize, + pub regrets: little_sorry::RegretMatcher, +} + +#[derive(Debug)] +pub struct TerminalData { + pub utility: Vec, +} + +// The base node type for Poker CFR +#[derive(Debug)] +pub enum NodeData { + Root(RootData), + Chance, + Action(ActionData), + Player(PlayerData), + Terminal(TerminalData), +} + +impl NodeData { + pub fn is_terminal(&self) -> bool { + matches!(self, NodeData::Terminal(_)) + } + + pub fn is_chance(&self) -> bool { + matches!(self, NodeData::Chance) + } + + pub fn is_player(&self) -> bool { + matches!(self, NodeData::Player(_)) + } + + pub fn is_action(&self) -> bool { + matches!(self, NodeData::Action(_)) + } + + pub fn is_root(&self) -> bool { + matches!(self, NodeData::Root(_)) + } +} + +impl std::fmt::Display for NodeData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NodeData::Root(_) => write!(f, "Root"), + NodeData::Chance => write!(f, "Chance"), + NodeData::Action(_) => write!(f, "Action"), + NodeData::Player(_) => write!(f, "Player"), + NodeData::Terminal(_) => write!(f, "Terminal"), + } + } +} + +#[derive(Debug)] +pub struct Node { + pub idx: usize, + pub data: NodeData, + pub parent: Option, + pub children: Vec>, + pub 
count: Vec, +} diff --git a/src/arena/cfr/state.rs b/src/arena/cfr/state.rs new file mode 100644 index 00000000..e2358d57 --- /dev/null +++ b/src/arena/cfr/state.rs @@ -0,0 +1,249 @@ +use uuid::Uuid; + +use crate::arena::{game_state::Round, historian::HistorianError, GameState}; + +use super::{ + node::{ActionData, Node, NodeData, PlayerData, RootData, TerminalData}, + EXPERTS, PREFLOP_EXPERTS, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CFRSavePoint { + current_node: Option, + previous_node: Option, + child_idx: Option, +} + +pub struct PlayerCFRState { + pub id: Uuid, + current_node: Option, + previous_node: Option, + child_idx: Option, + arena: Vec, +} + +// These are the possible node types that we +// encounter on the way to the next node. +#[derive(Debug, PartialEq, Eq)] +pub enum EnsureNodeType { + // Ensure that it's a player node for this index + Player(usize), + // Endure that it's an action node for this index + Action(usize), + Chance, + Terminal, +} + +impl PlayerCFRState { + pub fn new(game_state: GameState) -> PlayerCFRState { + let arena: Vec = vec![Node { + idx: 0, + data: NodeData::Root(RootData { + game_state: game_state.clone(), + }), + parent: None, + children: vec![None], + + // At some point we probably want to initialize this + // and inline it + count: vec![0], + }]; + + let id = uuid::Uuid::now_v7(); + + PlayerCFRState { + id, + // We don't know what the next node will be yet + current_node: None, + // Root node is the previous node + // And by convention, there's only one child at index 0 + previous_node: Some(0), + child_idx: Some(0), + arena, + } + } + + pub fn get_current_node(&self) -> Option<&Node> { + self.current_node.map(|idx| &self.arena[idx]) + } + pub fn get_mut_current_node(&mut self) -> Option<&mut Node> { + self.current_node.map(|idx| &mut self.arena[idx]) + } + + pub fn get(&self, idx: usize) -> Option<&Node> { + self.arena.get(idx) + } + + pub fn get_mut(&mut self, idx: usize) -> Option<&mut Node> { + 
self.arena.get_mut(idx) + } + + pub fn set_next_node(&mut self, child_idx: usize) -> Result<(), HistorianError> { + if let Some(current_idx) = self.current_node { + // Ensure that there are children to pull from + // And counters to keep track of it. + if self.arena[current_idx].children.len() <= child_idx { + self.arena[current_idx].children.resize(child_idx + 1, None); + self.arena[current_idx].count.resize(child_idx + 1, 0) + } + + self.current_node = self.arena[current_idx].children[child_idx]; + self.previous_node = Some(current_idx); + self.child_idx = Some(child_idx); + + // Increment the count of times we + // have visited the child via this path + self.arena[current_idx].count[child_idx] += 1; + Ok(()) + } else { + Err(HistorianError::CFRNodeNotFound) + } + } + + pub fn num_children(&self, data: &NodeData) -> usize { + match data { + NodeData::Root(_) => 1, + NodeData::Player(_) => 6, + NodeData::Action(_) => 8, + NodeData::Chance => 52, + NodeData::Terminal(_) => 0, + } + } + + pub fn add_current_node(&mut self, data: NodeData) -> usize { + let idx = self.arena.len(); + + let num_children = self.num_children(&data); + + let node = Node { + idx, + data, + parent: self.previous_node, + children: vec![None; num_children], + count: vec![0; num_children], + }; + + // Add the node to the arena + self.arena.push(node); + // The previous node's child at the child index is now the current node + let previous_node = &mut self.arena[self.previous_node.unwrap()]; + + // This the child index that the new node will be at + let path_idx = self.child_idx.unwrap(); + if previous_node.children.len() <= path_idx { + previous_node.children.resize(path_idx + 1, None); + previous_node.count.resize(path_idx + 1, 0) + } + previous_node.children[self.child_idx.unwrap()] = Some(idx); + // And the previously empty current node is now the new node + self.current_node = Some(idx); + + idx + } + + pub fn ensure_current_node( + &mut self, + node_type: EnsureNodeType, + round: Round, + 
) -> Result<(), HistorianError> {
+        if let Some(current_node) = self.get_current_node() {
+            // A node already exists here: check that its data matches the requested type
+            match node_type {
+                EnsureNodeType::Player(idx) => {
+                    if let NodeData::Player(player_data) = &current_node.data {
+                        if player_data.idx == idx {
+                            Ok(())
+                        } else {
+                            Err(HistorianError::CFRUnexpectedNode(
+                                "Expected Player idx does not match".to_string(),
+                            ))
+                        }
+                    } else {
+                        Err(HistorianError::CFRUnexpectedNode(format!(
+                            "Expected Player found #{current_node:?}"
+                        )))
+                    }
+                }
+                EnsureNodeType::Action(idx) => {
+                    if let NodeData::Action(action_data) = &current_node.data {
+                        if action_data.idx == idx {
+                            Ok(())
+                        } else {
+                            Err(HistorianError::CFRUnexpectedNode(
+                                "Expected Action idx does not match".to_string(),
+                            ))
+                        }
+                    } else {
+                        Err(HistorianError::CFRUnexpectedNode(format!(
+                            "Expected Action found #{current_node:?}"
+                        )))
+                    }
+                }
+                EnsureNodeType::Chance => {
+                    if current_node.data.is_chance() {
+                        Ok(())
+                    } else {
+                        Err(HistorianError::CFRUnexpectedNode(format!(
+                            "Expected Chance found #{current_node:?}"
+                        )))
+                    }
+                }
+                EnsureNodeType::Terminal => {
+                    if current_node.data.is_terminal() {
+                        Ok(())
+                    } else {
+                        Err(HistorianError::CFRUnexpectedNode(format!(
+                            "Expected Terminal found #{current_node:?}"
+                        )))
+                    }
+                }
+            }
+        } else {
+            // No node yet: build the default node data for the requested type
+            let data = match node_type {
+                EnsureNodeType::Player(idx) => NodeData::Player(PlayerData {
+                    idx,
+                    regrets: self.build_regret_matcher(round),
+                }),
+                EnsureNodeType::Action(idx) => NodeData::Action(ActionData { idx }),
+                EnsureNodeType::Chance => NodeData::Chance,
+                EnsureNodeType::Terminal => NodeData::Terminal(TerminalData { utility: vec![] }),
+            };
+            // Then add it to the tree, where it becomes the current node
+            self.add_current_node(data);
+            Ok(())
+        }
+    }
+
+    pub fn build_regret_matcher(&self, round: Round) -> little_sorry::RegretMatcher {
+        let num_experts = self.num_experts(round);
+        
little_sorry::RegretMatcher::new(num_experts).unwrap() + } + + pub fn num_experts(&self, round: Round) -> usize { + match round { + Round::Preflop => PREFLOP_EXPERTS.len(), + _ => EXPERTS.len(), + } + } + + pub fn reset(&mut self) { + self.current_node = self.arena[0].children[0]; + self.previous_node = Some(0); + self.child_idx = Some(0); + } + + pub fn save_point(&self) -> CFRSavePoint { + CFRSavePoint { + current_node: self.current_node, + previous_node: self.previous_node, + child_idx: self.child_idx, + } + } + + pub fn restore_save_point(&mut self, save_state: CFRSavePoint) { + self.current_node = save_state.current_node; + self.previous_node = save_state.previous_node; + self.child_idx = save_state.child_idx; + } +} diff --git a/src/arena/historian/mod.rs b/src/arena/historian/mod.rs index 1515de96..4cc64aed 100644 --- a/src/arena/historian/mod.rs +++ b/src/arena/historian/mod.rs @@ -12,6 +12,10 @@ pub enum HistorianError { BorrowMutError(#[from] std::cell::BorrowMutError), #[error("Borrow Error: {0}")] BorrowError(#[from] std::cell::BorrowError), + #[error("Unexpected CFR Node: {0}")] + CFRUnexpectedNode(String), + #[error("Expected Node not found in tree")] + CFRNodeNotFound, } /// Historians are a way for the simulation to record or notify of diff --git a/src/arena/mod.rs b/src/arena/mod.rs index 301f0663..9939dcc8 100644 --- a/src/arena/mod.rs +++ b/src/arena/mod.rs @@ -68,6 +68,7 @@ pub mod action; pub mod agent; +pub mod cfr; pub mod competition; pub mod errors; pub mod game_state; diff --git a/src/arena/simulation.rs b/src/arena/simulation.rs index 028320a9..ec031683 100644 --- a/src/arena/simulation.rs +++ b/src/arena/simulation.rs @@ -442,7 +442,7 @@ impl HoldemSimulation { let action = self.agents[idx].act(&self.id, &self.game_state); event!(parent: &span, Level::TRACE, ?action, idx); - self.run_agent_action(action) + self.run_agent_action(action); } /// Given the action that an agent wants to take, this function will