Skip to content

Commit

Permalink
feat: add back cfr state
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottneilclark committed Dec 28, 2024
1 parent 0d91025 commit 25d1867
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 14 deletions.
11 changes: 3 additions & 8 deletions src/arena/cfr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
mod node;
mod state;

pub use node::{Node, NodeData};
pub use node::{Node, NodeData, RootData, PlayerData, TerminalData};
pub use state::CFRState;

#[cfg(test)]
mod tests {

use std::vec;
use std::{cell::RefCell, rc::Rc};

use crate::arena::game_state::{Round, RoundData};

Expand Down Expand Up @@ -39,10 +39,5 @@ mod tests {
0.0,
0,
);

// let cfr_states: Vec<_> = (0..num_agents)
// .map(|_|
// Rc::new(RefCell::new(PlayerCFRState::new(game_state.clone()))))
// .collect();
}
}
33 changes: 27 additions & 6 deletions src/arena/cfr/node.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
use crate::arena::GameState;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct RootData {
pub game_state: GameState,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct PlayerData {
pub player_idx: usize,
pub regrets: little_sorry::RegretMatcher,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct TerminalData {
pub utility: f32,
}

// The base node type for Poker CFR
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum NodeData {
Root(RootData),
Chance,
Expand Down Expand Up @@ -54,7 +53,7 @@ impl std::fmt::Display for NodeData {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Node {
pub idx: usize,
pub data: NodeData,
Expand All @@ -74,6 +73,28 @@ impl Node {
}
}

/// Create a new node with the provided index, parent index, and data.
///
/// # Arguments
///
/// * `idx` - The index of the node
/// * `parent` - The index of the parent node
/// * `data` - The data for the node
///
/// # Returns
///
/// A new node with the provided index, parent index, and data.
///
/// # Example
///
/// ```
/// use rs_poker::arena::cfr::{Node, NodeData};
///
/// let idx = 1;
/// let parent = 0;
/// let data = NodeData::Chance;
/// let node = Node::new(idx, parent, data);
/// ```
pub fn new(idx: usize, parent: usize, data: NodeData) -> Self {
Node {
idx: idx,
Expand Down
69 changes: 69 additions & 0 deletions src/arena/cfr/state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use crate::arena::GameState;

use super::{Node, NodeData};



#[derive(Debug, Clone)]
pub struct CFRState {
pub nodes: Vec<Node>,
next_node_idx: usize,
}

impl CFRState {
pub fn new(game_state: GameState) -> Self {
CFRState {
nodes: vec![Node::new_root(game_state)],
next_node_idx: 1,
}
}

pub fn add(&mut self, parent_idx: usize, data: NodeData) -> usize {
let idx = self.next_node_idx;
self.next_node_idx += 1;

let node = Node::new(idx, parent_idx, data);
self.nodes.push(node);

idx
}

pub fn get(&self, idx: usize) -> Option<&Node> {
self.nodes.get(idx)
}

pub fn get_mut(&mut self, idx: usize) -> Option<&mut Node> {
self.nodes.get_mut(idx)
}
}

#[cfg(test)]
mod tests {
use crate::arena::cfr::{NodeData, PlayerData};

use crate::arena::GameState;

use super::CFRState;

#[test]
fn test_add_get_node() {
// Create a
let mut state = CFRState::new(GameState::new_starting(
vec![100.0; 3],
10.0,
5.0,
0.0,
0,
));

let player_idx: usize = state.add(0, NodeData::Player(PlayerData { player_idx: 0 }));

let node = state.get(player_idx).unwrap().clone();
match node.data {
NodeData::Player(data) => {
assert_eq!(data.player_idx, 0);
}
_ => panic!("Expected player data"),
}
}
}

0 comments on commit 25d1867

Please sign in to comment.