Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make tree multithread-friendly #1

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::marker::PhantomData;
use std::ops::{Add, AddAssign, Div};

use ego_tree::{NodeId, NodeMut, Tree};
use noisy_float::types::n64;
use num_traits::{ToPrimitive, Zero};
use rand::{Rng, thread_rng};
use rand::prelude::SliceRandom;
Expand Down Expand Up @@ -181,6 +182,9 @@ for DefaultUctEvaluator
parent_visits: Nat,
&c: &Self::Args,
) -> Num {
if child.n_visits == 0 {
return n64(0f64);
}
uct_value(
parent_visits,
child.sum_rewards.to_f64().unwrap(),
Expand Down
112 changes: 68 additions & 44 deletions src/tree_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,47 @@ use core::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::marker::PhantomData;
use std::ops::{Add, Div};
use std::sync::Mutex;

use ascii_tree::{Tree, write_tree};
use ascii_tree::Tree::{Leaf, Node};
use ascii_tree::{write_tree, Tree};
use ego_tree::NodeId;
use num_traits::{ToPrimitive, Zero};

use crate::aliases::{LazyMctsNode, LazyMctsTree};
use crate::Evaluator;
use crate::traits::{BackPropPolicy, GameTrait, LazyTreePolicy, Playout};
use crate::Evaluator;

/// This is a special MCTS because it doesn't store the state in the node but instead stores the
/// history leading to the node.
#[derive(Clone)]

pub struct LazyMcts<'a, State, TP, PP, BP, EV, AddInfo, Reward>
where
State: GameTrait,
TP: LazyTreePolicy<State, EV, AddInfo, Reward>,
PP: Playout<State>,
BP: BackPropPolicy<Vec<State::Move>, State::Move, Reward, AddInfo, EV::EvalResult>,
EV: Evaluator<State, Reward, AddInfo>,
AddInfo: Clone + Default,
Reward: Clone,
where
State: GameTrait,
TP: LazyTreePolicy<State, EV, AddInfo, Reward>,
PP: Playout<State>,
BP: BackPropPolicy<Vec<State::Move>, State::Move, Reward, AddInfo, EV::EvalResult>,
EV: Evaluator<State, Reward, AddInfo>,
AddInfo: Clone + Default,
Reward: Clone,
{
root_state: &'a State,
tree_policy: PhantomData<TP>,
playout_policy: PhantomData<PP>,
backprop_policy: PhantomData<BP>,
evaluator: PhantomData<EV>,
tree: LazyMctsTree<State, Reward, AddInfo>,
tree: Mutex<LazyMctsTree<State, Reward, AddInfo>>,
}

impl<'a, State, TP, PP, BP, EV, A, R> LazyMcts<'a, State, TP, PP, BP, EV, A, R>
where
State: GameTrait,
TP: LazyTreePolicy<State, EV, A, R>,
PP: Playout<State>,
BP: BackPropPolicy<Vec<State::Move>, State::Move, R, A, EV::EvalResult>,
EV: Evaluator<State, R, A>,
A: Clone + Default,
R: Clone + Div + ToPrimitive + Zero + Add + Display,
where
State: GameTrait,
TP: LazyTreePolicy<State, EV, A, R>,
PP: Playout<State>,
BP: BackPropPolicy<Vec<State::Move>, State::Move, R, A, EV::EvalResult>,
EV: Evaluator<State, R, A>,
A: Clone + Default,
R: Clone + Div + ToPrimitive + Zero + Add + Display,
{
pub fn new(root_state: &'a State) -> Self {
Self::with_capacity(root_state, 0)
Expand All @@ -65,29 +66,33 @@ impl<'a, State, TP, PP, BP, EV, A, R> LazyMcts<'a, State, TP, PP, BP, EV, A, R>
playout_policy: PhantomData,
backprop_policy: PhantomData,
evaluator: PhantomData,
tree,
tree: Mutex::new(tree),
}
}

/// Executes one selection, expansion?, simulation, backpropagation.
pub fn execute(&mut self, evaluation_args: &EV::Args, playout_args: PP::Args) {
let (node_id, state) =
TP::tree_policy(&mut self.tree, self.root_state.clone(), evaluation_args);
pub fn execute(&self, evaluation_args: &EV::Args, playout_args: PP::Args) {
let mut tree = self.tree.lock().unwrap();
let (node_id, state) = TP::tree_policy(&mut tree, self.root_state.clone(), evaluation_args);
drop(tree);

let final_state = PP::playout(state, playout_args);
let eval = EV::evaluate_leaf(final_state, &self.root_state.player_turn());
BP::backprop(&mut self.tree, node_id, eval);

let mut tree = self.tree.lock().unwrap();
BP::backprop(&mut tree, node_id, eval);
}

/// Returns the best move from the root.
pub fn best_move(&self, evaluator_args: &EV::Args) -> State::Move {
let tree = self.tree.lock().unwrap();
let best_child = TP::best_child(
&self.tree,
&tree,
&self.root_state.player_turn(),
self.tree.root().id(),
tree.root().id(),
evaluator_args,
);
self.tree
.get(best_child)
tree.get(best_child)
.unwrap()
.value()
.state
Expand All @@ -97,14 +102,15 @@ impl<'a, State, TP, PP, BP, EV, A, R> LazyMcts<'a, State, TP, PP, BP, EV, A, R>
}

/// Renders the search tree as ASCII art and returns the rendered text.
///
/// # Panics
///
/// Panics if the tree mutex is poisoned or if `write_tree` reports a formatting error.
pub fn write_tree(&self) -> String {
    // Fetch the root id in its own statement so the temporary `MutexGuard` is dropped at the
    // end of that statement. Passing the locking expression directly as the argument of `dfs`
    // keeps the guard alive for the whole call, and `dfs` locks the same non-reentrant mutex.
    let root_id = self.tree.lock().unwrap().root().id();
    let ascii = self.dfs(root_id);
    let mut output = String::new();
    write_tree(&mut output, &ascii).unwrap();
    output
}

fn dfs(&self, node_id: NodeId) -> Tree {
let node = self.tree.get(node_id).unwrap();
let tree = self.tree.lock().unwrap();
let node = tree.get(node_id).unwrap();
if node.has_children() {
let mut nodes = vec![];
for c in node.children() {
Expand All @@ -122,24 +128,42 @@ impl<'a, State, TP, PP, BP, EV, A, R> LazyMcts<'a, State, TP, PP, BP, EV, A, R>
)])
}
}

pub fn tree(&self) -> &LazyMctsTree<State, R, A> {
&self.tree
}
}

impl<State, TP, PP, BP, EV, A, R> Debug for LazyMcts<'_, State, TP, PP, BP, EV, A, R>
where
State: GameTrait,
TP: LazyTreePolicy<State, EV, A, R>,
PP: Playout<State>,
BP: BackPropPolicy<Vec<State::Move>, State::Move, R, A, EV::EvalResult>,
EV: Evaluator<State, R, A>,
EV::EvalResult: Debug,
A: Clone + Default + Debug,
R: Clone + Debug + Div + Add + Zero + ToPrimitive,
where
State: GameTrait,
TP: LazyTreePolicy<State, EV, A, R>,
PP: Playout<State>,
BP: BackPropPolicy<Vec<State::Move>, State::Move, R, A, EV::EvalResult>,
EV: Evaluator<State, R, A>,
EV::EvalResult: Debug,
A: Clone + Default + Debug,
R: Clone + Debug + Div + Add + Zero + ToPrimitive,
{
/// Formats the search by writing the `Debug` representation of its `tree` field.
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    // Write straight into the formatter: no intermediate `String` is allocated, and a
    // formatting error is returned to the caller instead of panicking inside `format!`.
    write!(f, "{:?}", self.tree)
}
}

/// Hand-written `Clone`: the tree is stored behind a `Mutex`, so the copy is produced by locking
/// the source tree, cloning its contents and wrapping the result in a new `Mutex`.
impl<State, TP, PP, BP, EV, A, R> Clone for LazyMcts<'_, State, TP, PP, BP, EV, A, R>
where
    State: GameTrait,
    TP: LazyTreePolicy<State, EV, A, R>,
    PP: Playout<State>,
    BP: BackPropPolicy<Vec<State::Move>, State::Move, R, A, EV::EvalResult>,
    EV: Evaluator<State, R, A>,
    A: Clone + Default,
    R: Clone + Debug + Div + Add + Zero + ToPrimitive,
{
    /// Returns a new search that borrows the same root state and owns a copy of the tree.
    ///
    /// # Panics
    ///
    /// Panics if the tree mutex is poisoned.
    fn clone(&self) -> Self {
        // Snapshot the tree inside its own scope; the guard is released when the scope ends.
        let snapshot = {
            let guard = self.tree.lock().unwrap();
            (*guard).clone()
        };

        Self {
            tree: Mutex::new(snapshot),
            root_state: self.root_state,
            tree_policy: PhantomData,
            playout_policy: PhantomData,
            backprop_policy: PhantomData,
            evaluator: PhantomData,
        }
    }
}