diff --git a/engine/src/agents/mctsagent.cpp b/engine/src/agents/mctsagent.cpp index 30796e0c..ea909c89 100644 --- a/engine/src/agents/mctsagent.cpp +++ b/engine/src/agents/mctsagent.cpp @@ -166,11 +166,10 @@ shared_ptr MCTSAgent::get_root_node_from_tree(StateObj *state) void MCTSAgent::create_new_root_node(StateObj* state) { info_string("create new tree"); - // TODO: Make sure that "inCheck=False" does not cause issues #ifdef MCTS_STORE_STATES - rootNode = make_shared(state->clone(), false, searchSettings); + rootNode = make_shared(state->clone(), searchSettings); #else - rootNode = make_shared(state, false, searchSettings); + rootNode = make_shared(state, searchSettings); #endif #ifdef SEARCH_UCT unique_ptr newState = unique_ptr(state->clone()); diff --git a/engine/src/constants.h b/engine/src/constants.h index 1b8d4f44..847459c5 100644 --- a/engine/src/constants.h +++ b/engine/src/constants.h @@ -53,7 +53,7 @@ const string engineName = "MultiAra"; const string engineName = "ClassicAra"; #endif -const string engineVersion = "0.9.2.post1-Dev"; +const string engineVersion = "0.9.2.post2"; const string engineAuthors = "Johannes Czech, Moritz Willig, Alena Beyer et al."; #define LOSS_VALUE -1 diff --git a/engine/src/node.cpp b/engine/src/node.cpp index 1dc80d70..7c786baa 100644 --- a/engine/src/node.cpp +++ b/engine/src/node.cpp @@ -79,7 +79,7 @@ void Node::set_auxiliary_outputs(const float* auxiliaryOutputs) } #endif -Node::Node(StateObj* state, bool inCheck, const SearchSettings* searchSettings): +Node::Node(StateObj* state, const SearchSettings* searchSettings): legalActions(state->legal_actions()), key(state->hash_key()), valueSum(0), @@ -96,7 +96,7 @@ Node::Node(StateObj* state, bool inCheck, const SearchSettings* searchSettings): sorted(false) { // specify the number of direct child nodes of this node - check_for_terminal(state, inCheck); + check_for_terminal(state); #ifdef MCTS_TB_SUPPORT if (searchSettings->useTablebase && !isTerminal) { check_for_tablebase_wdl(state); @@ -556,7 +556,9 @@ size_t Node::get_number_child_nodes() const void Node::prepare_node_for_visits() { sort_moves_by_probabilities(); - init_node_data(); + if (d == nullptr) { // mark_tablebase() initializes the NodeData + init_node_data(); + } #ifdef MCTS_STORE_STATES state->prepare_action(); #endif @@ -751,7 +753,7 @@ void Node::mark_as_terminal() d->noVisitIdx = 0; } -void Node::check_for_terminal(StateObj* pos, bool inCheck) +void Node::check_for_terminal(StateObj* pos) { float customValue; TerminalType terminalType = pos->is_terminal(get_number_child_nodes(), customValue); @@ -802,7 +804,6 @@ void Node::check_for_tablebase_wdl(StateObj* state) void Node::mark_as_tablebase() { init_node_data(); - fully_expand_node(); isTablebase = true; } #endif diff --git a/engine/src/node.h b/engine/src/node.h index 2620c3a7..1ebafb57 100644 --- a/engine/src/node.h +++ b/engine/src/node.h @@ -85,12 +85,10 @@ class Node public: /** * @brief Node Primary constructor which is used when expanding a node during search - * @param parentNode Pointer to parent node - * @param move Move which led to current board state + * @param State Corresponding state object * @param searchSettings Pointer to the searchSettings */ Node(StateObj *state, - bool inCheck, const SearchSettings* searchSettings); /** @@ -477,9 +475,8 @@ class Node /** * @brief check_for_terminal Checks if the given board position is a terminal node and updates isTerminal * @param state Current board position for this node - * @param inCheck Boolean indicating if the king is in check */ - void check_for_terminal(StateObj* state, bool inCheck); + void check_for_terminal(StateObj* state); #ifdef MCTS_TB_SUPPORT /** diff --git a/engine/src/searchthread.cpp b/engine/src/searchthread.cpp index e2adaa52..e8ed7980 100644 --- a/engine/src/searchthread.cpp +++ b/engine/src/searchthread.cpp @@ -80,7 +80,7 @@ void SearchThread::set_is_running(bool value) isRunning = value; } -NodeBackup SearchThread::add_new_node_to_tree(StateObj* newState, Node* parentNode, ChildIdx childIdx, bool inCheck) +NodeBackup SearchThread::add_new_node_to_tree(StateObj* newState, Node* parentNode, ChildIdx childIdx) { if(searchSettings->useMCGS) { mapWithMutex->mtx.lock(); @@ -107,7 +107,7 @@ NodeBackup SearchThread::add_new_node_to_tree(StateObj* newState, Node* parentNo mapWithMutex->mtx.unlock(); } assert(parentNode != nullptr); - shared_ptr newNode = make_shared(newState, inCheck, searchSettings); + shared_ptr newNode = make_shared(newState, searchSettings); // connect the Node to the parent parentNode->add_new_child_node(newNode, childIdx); return NODE_NEW_NODE; @@ -210,13 +210,12 @@ Node* SearchThread::get_new_child_to_evaluate(ChildIdx& childIdx, NodeDescriptio newState->do_action(action); } #endif - const bool inCheck = newState->gives_check(currentNode->get_action(childIdx)); newState->do_action(currentNode->get_action(childIdx)); currentNode->increment_no_visit_idx(); #ifdef MCTS_STORE_STATES - description.type = add_new_node_to_tree(newState, currentNode, childIdx, inCheck); + description.type = add_new_node_to_tree(newState, currentNode, childIdx); #else - description.type = add_new_node_to_tree(newState.get(), currentNode, childIdx, inCheck); + description.type = add_new_node_to_tree(newState.get(), currentNode, childIdx); #endif currentNode->unlock(); @@ -306,6 +305,10 @@ void SearchThread::set_nn_results_to_child_nodes() { size_t batchIdx = 0; for (auto node: *newNodes) { + if (node == nullptr) { + info_string("nullptr newNode"); + continue; + } if (!node->is_terminal()) { fill_nn_results(batchIdx, net->is_policy_map(), valueOutputs, probOutputs, auxiliaryOutputs, node, tbHits, newNodeSideToMove->get_element(batchIdx), searchSettings, rootNode->is_tablebase()); } @@ -417,6 +420,9 @@ void run_search_thread(SearchThread *t) void SearchThread::backup_values(FixedVector& nodes, vector& trajectories) { for (size_t idx = 0; idx < nodes.size(); ++idx) { Node* node = nodes.get_element(idx); + if (node == nullptr) { + continue; + } #ifdef MCTS_TB_SUPPORT const bool solveForTerminal = searchSettings->mctsSolver && node->is_tablebase(); backup_value(node->get_value(), searchSettings->virtualLoss, trajectories[idx], solveForTerminal); diff --git a/engine/src/searchthread.h b/engine/src/searchthread.h index 9072b90a..2d59b8c0 100644 --- a/engine/src/searchthread.h +++ b/engine/src/searchthread.h @@ -146,7 +146,7 @@ class SearchThread : NeuralNetAPIUser * @param inCheck Defines if the current position sets a player in check * @return Returns NODE_TRANSPOSITION if a tranpsosition node was added and NODE_NEW_NODE otherwise */ - NodeBackup add_new_node_to_tree(StateObj* newPos, Node* parentNode, ChildIdx childIdx, bool inCheck); + NodeBackup add_new_node_to_tree(StateObj* newPos, Node* parentNode, ChildIdx childIdx); /** * @brief reset_tb_hits Sets the number of table hits to 0