Skip to content

Commit

Permalink
feat: add counterfactual regret minimization node and state
Browse files Browse the repository at this point in the history
Summary:
- Add CFR game state tree that uses a vec as an allocation arena.
- Add a base cfr agent that can be filled out
- Add a historian that will follow the game for an agent filling out the
  cfr tree.

Test Plan:
- Added a test for CFRState
- Added doc tests for state
  • Loading branch information
elliottneilclark committed Jan 2, 2025
1 parent 3bdf181 commit 614fdc9
Show file tree
Hide file tree
Showing 9 changed files with 686 additions and 3 deletions.
96 changes: 96 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ arbitrary = { version = "~1.4.1", optional = true, features = ["derive"] }
tracing = { version = "~0.1.41", optional = true}
approx = { version = "~0.5.1", optional = true}
uuid = {version = "~1.11.0", optional = true, features = ["v7"]}
little-sorry = { version = "~0.5.0", optional = true}
ndarray = { version = "~0.16.1", optional = true}

[dev-dependencies]
criterion = "0.5.1"
Expand All @@ -33,7 +35,7 @@ approx = { version = "0.5.1"}
default = ["arena", "serde"]
uuid = ["dep:uuid"]
serde = ["dep:serde", "dep:serde_json", "uuid?/serde"]
arena = ["dep:tracing", "uuid"]
arena = ["dep:tracing", "dep:little-sorry", "dep:ndarray", "uuid"]
arena-test-util = ["arena", "dep:approx"]

[[bench]]
Expand Down
45 changes: 45 additions & 0 deletions src/arena/cfr/action_generator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::arena::{
action::{Action, AgentAction},
GameState,
};

use super::{CFRState, TraversalState};

pub trait ActionGenerator {
fn new(cfr_state: CFRState, traversal_state: TraversalState) -> Self;

fn action_to_idx(&self, action: &Action) -> usize;

fn gen_action(&self, game_state: &GameState) -> AgentAction;

fn num_possible_actions(&self) -> usize;
}

pub struct CFRActionGenerator {
cfr_state: CFRState,
traversal_state: TraversalState,
}

impl ActionGenerator for CFRActionGenerator {
fn action_to_idx(&self, _action: &Action) -> usize {
todo!()
}

fn gen_action(&self, _game_state: &GameState) -> AgentAction {
todo!()
}

fn new(cfr_state: CFRState, traversal_state: TraversalState) -> Self {
CFRActionGenerator {
cfr_state,
traversal_state,
}
}

fn num_possible_actions(&self) -> usize {
// TODO: Implement this. It has to always be less
// than 52 since we use the same children array
// for all nodes including chance nodes.
8
}
}
96 changes: 96 additions & 0 deletions src/arena/cfr/agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use crate::arena::{
action::{Action, AgentAction},
Agent, GameState, Historian, HistorianError,
};

use super::{state::CFRState, state::TraversalState};

pub struct CFRAgent {
pub traversal_state: TraversalState,
pub cfr_state: CFRState,
}

pub struct CFRHistorian {
pub current_traversal_state: TraversalState,
pub cfr_state: CFRState,
}

impl CFRHistorian {
pub fn new(traversal_state: TraversalState, cfr_state: CFRState) -> Self {
CFRHistorian {
current_traversal_state: traversal_state,
cfr_state,
}
}
}

impl Historian for CFRHistorian {
fn record_action(
&mut self,
_id: &uuid::Uuid,
_game_state: &GameState,
action: Action,
) -> Result<(), HistorianError> {
match action {
// These are all assumed from game start and encoded in the root node.
Action::GameStart(_) | Action::ForcedBet(_) | Action::PlayerSit(_) => Ok(()),
// We don't encode round advance in the tree because it never changes the outcome.
Action::RoundAdvance(_) => Ok(()),
// Rather than use award since it can be for a side pot we use the final award ammount
// in the terminal node.
Action::Award(_) => Ok(()),
Action::DealStartingHand(_deal_starting_hand_payload) => todo!(),
Action::PlayedAction(_played_action_payload) => todo!(),
Action::FailedAction(_failed_action_payload) => todo!(),
Action::DealCommunity(_card) => todo!(),
}
}
}

impl CFRAgent {
pub fn new(
cfr_state: CFRState,
node_idx: usize,
chosen_child: usize,
player_idx: usize,
) -> Self {
CFRAgent {
cfr_state,
traversal_state: TraversalState::new(node_idx, chosen_child, player_idx),
}
}

pub fn historian(&self) -> CFRHistorian {
CFRHistorian::new(self.traversal_state.clone(), self.cfr_state.clone())
}
}

impl Agent for CFRAgent {
fn act(
&mut self,
_id: &uuid::Uuid,
_game_state: &GameState,
) -> crate::arena::action::AgentAction {
AgentAction::Fold
}
}

#[cfg(test)]
mod tests {
use crate::arena::game_state;

use super::*;

#[test]
fn test_create_agent() {
let game_state = game_state::GameState::new_starting(vec![100.0; 3], 10.0, 5.0, 0.0, 0);
let cfr_state = CFRState::new(game_state);
let _ = CFRAgent::new(
cfr_state.clone(),
// we are still at root so 0
0,
0,
0,
);
}
}
Loading

0 comments on commit 614fdc9

Please sign in to comment.