Skip to content

Commit

Permalink
Fix potential crash (#103)
Browse files Browse the repository at this point in the history
avoid overwriting NodeData for tablebase nodes
delete inCheck from Node
  • Loading branch information
QueensGambit authored May 9, 2021
1 parent eaa6465 commit 157a870
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 20 deletions.
5 changes: 2 additions & 3 deletions engine/src/agents/mctsagent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,10 @@ shared_ptr<Node> MCTSAgent::get_root_node_from_tree(StateObj *state)
void MCTSAgent::create_new_root_node(StateObj* state)
{
info_string("create new tree");
// TODO: Make sure that "inCheck=False" does not cause issues
#ifdef MCTS_STORE_STATES
rootNode = make_shared<Node>(state->clone(), false, searchSettings);
rootNode = make_shared<Node>(state->clone(), searchSettings);
#else
rootNode = make_shared<Node>(state, false, searchSettings);
rootNode = make_shared<Node>(state, searchSettings);
#endif
#ifdef SEARCH_UCT
unique_ptr<StateObj> newState = unique_ptr<StateObj>(state->clone());
Expand Down
2 changes: 1 addition & 1 deletion engine/src/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ const string engineName = "MultiAra";
const string engineName = "ClassicAra";
#endif

const string engineVersion = "0.9.2.post1-Dev";
const string engineVersion = "0.9.2.post2";
const string engineAuthors = "Johannes Czech, Moritz Willig, Alena Beyer et al.";

#define LOSS_VALUE -1
Expand Down
11 changes: 6 additions & 5 deletions engine/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void Node::set_auxiliary_outputs(const float* auxiliaryOutputs)
}
#endif

Node::Node(StateObj* state, bool inCheck, const SearchSettings* searchSettings):
Node::Node(StateObj* state, const SearchSettings* searchSettings):
legalActions(state->legal_actions()),
key(state->hash_key()),
valueSum(0),
Expand All @@ -96,7 +96,7 @@ Node::Node(StateObj* state, bool inCheck, const SearchSettings* searchSettings):
sorted(false)
{
// specify the number of direct child nodes of this node
check_for_terminal(state, inCheck);
check_for_terminal(state);
#ifdef MCTS_TB_SUPPORT
if (searchSettings->useTablebase && !isTerminal) {
check_for_tablebase_wdl(state);
Expand Down Expand Up @@ -556,7 +556,9 @@ size_t Node::get_number_child_nodes() const
void Node::prepare_node_for_visits()
{
sort_moves_by_probabilities();
init_node_data();
if (d == nullptr) { // mark_tablebase() initializes the NodeData
init_node_data();
}
#ifdef MCTS_STORE_STATES
state->prepare_action();
#endif
Expand Down Expand Up @@ -751,7 +753,7 @@ void Node::mark_as_terminal()
d->noVisitIdx = 0;
}

void Node::check_for_terminal(StateObj* pos, bool inCheck)
void Node::check_for_terminal(StateObj* pos)
{
float customValue;
TerminalType terminalType = pos->is_terminal(get_number_child_nodes(), customValue);
Expand Down Expand Up @@ -802,7 +804,6 @@ void Node::check_for_tablebase_wdl(StateObj* state)
void Node::mark_as_tablebase()
{
init_node_data();
fully_expand_node();
isTablebase = true;
}
#endif
Expand Down
7 changes: 2 additions & 5 deletions engine/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,10 @@ class Node
public:
/**
* @brief Node Primary constructor which is used when expanding a node during search
* @param parentNode Pointer to parent node
* @param move Move which led to current board state
* @param State Corresponding state object
* @param searchSettings Pointer to the searchSettings
*/
Node(StateObj *state,
bool inCheck,
const SearchSettings* searchSettings);

/**
Expand Down Expand Up @@ -477,9 +475,8 @@ class Node
/**
* @brief check_for_terminal Checks if the given board position is a terminal node and updates isTerminal
* @param state Current board position for this node
* @param inCheck Boolean indicating if the king is in check
*/
void check_for_terminal(StateObj* state, bool inCheck);
void check_for_terminal(StateObj* state);

#ifdef MCTS_TB_SUPPORT
/**
Expand Down
16 changes: 11 additions & 5 deletions engine/src/searchthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void SearchThread::set_is_running(bool value)
isRunning = value;
}

NodeBackup SearchThread::add_new_node_to_tree(StateObj* newState, Node* parentNode, ChildIdx childIdx, bool inCheck)
NodeBackup SearchThread::add_new_node_to_tree(StateObj* newState, Node* parentNode, ChildIdx childIdx)
{
if(searchSettings->useMCGS) {
mapWithMutex->mtx.lock();
Expand All @@ -107,7 +107,7 @@ NodeBackup SearchThread::add_new_node_to_tree(StateObj* newState, Node* parentNo
mapWithMutex->mtx.unlock();
}
assert(parentNode != nullptr);
shared_ptr<Node> newNode = make_shared<Node>(newState, inCheck, searchSettings);
shared_ptr<Node> newNode = make_shared<Node>(newState, searchSettings);
// connect the Node to the parent
parentNode->add_new_child_node(newNode, childIdx);
return NODE_NEW_NODE;
Expand Down Expand Up @@ -210,13 +210,12 @@ Node* SearchThread::get_new_child_to_evaluate(ChildIdx& childIdx, NodeDescriptio
newState->do_action(action);
}
#endif
const bool inCheck = newState->gives_check(currentNode->get_action(childIdx));
newState->do_action(currentNode->get_action(childIdx));
currentNode->increment_no_visit_idx();
#ifdef MCTS_STORE_STATES
description.type = add_new_node_to_tree(newState, currentNode, childIdx, inCheck);
description.type = add_new_node_to_tree(newState, currentNode, childIdx);
#else
description.type = add_new_node_to_tree(newState.get(), currentNode, childIdx, inCheck);
description.type = add_new_node_to_tree(newState.get(), currentNode, childIdx);
#endif
currentNode->unlock();

Expand Down Expand Up @@ -306,6 +305,10 @@ void SearchThread::set_nn_results_to_child_nodes()
{
size_t batchIdx = 0;
for (auto node: *newNodes) {
if (node == nullptr) {
info_string("nullptr newNode");
continue;
}
if (!node->is_terminal()) {
fill_nn_results(batchIdx, net->is_policy_map(), valueOutputs, probOutputs, auxiliaryOutputs, node, tbHits, newNodeSideToMove->get_element(batchIdx), searchSettings, rootNode->is_tablebase());
}
Expand Down Expand Up @@ -417,6 +420,9 @@ void run_search_thread(SearchThread *t)
void SearchThread::backup_values(FixedVector<Node*>& nodes, vector<Trajectory>& trajectories) {
for (size_t idx = 0; idx < nodes.size(); ++idx) {
Node* node = nodes.get_element(idx);
if (node == nullptr) {
continue;
}
#ifdef MCTS_TB_SUPPORT
const bool solveForTerminal = searchSettings->mctsSolver && node->is_tablebase();
backup_value<false>(node->get_value(), searchSettings->virtualLoss, trajectories[idx], solveForTerminal);
Expand Down
2 changes: 1 addition & 1 deletion engine/src/searchthread.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class SearchThread : NeuralNetAPIUser
* @param inCheck Defines if the current position sets a player in check
* @return Returns NODE_TRANSPOSITION if a tranpsosition node was added and NODE_NEW_NODE otherwise
*/
NodeBackup add_new_node_to_tree(StateObj* newPos, Node* parentNode, ChildIdx childIdx, bool inCheck);
NodeBackup add_new_node_to_tree(StateObj* newPos, Node* parentNode, ChildIdx childIdx);

/**
* @brief reset_tb_hits Sets the number of table hits to 0
Expand Down

0 comments on commit 157a870

Please sign in to comment.