diff --git a/engine/src/agents/agent.cpp b/engine/src/agents/agent.cpp index 067157b4..9f500719 100644 --- a/engine/src/agents/agent.cpp +++ b/engine/src/agents/agent.cpp @@ -52,9 +52,14 @@ void Agent::set_best_move(size_t moveCounter) } } +void Agent::set_must_wait(bool value) +{ + mustWait = value; +} + Agent::Agent(NeuralNetAPI* net, PlaySettings* playSettings, bool verbose): NeuralNetAPIUser(net), - playSettings(playSettings), verbose(verbose), isRunning(false) + playSettings(playSettings), mustWait(true), verbose(verbose), isRunning(false) { } @@ -73,14 +78,17 @@ Action Agent::get_best_action() void Agent::lock_and_wait() { unique_lock lock(isRunningMutex); - isRunningCondition.wait(lock); + while(mustWait) { + isRunningCondition.wait(lock); + } } void Agent::unlock_and_notify() { // std::lock_guard is deprecated in C++17, therefore we use scoped_lock instead scoped_lock lock(isRunningMutex); - isRunningCondition.notify_all(); + mustWait = false; + isRunningCondition.notify_one(); } void Agent::perform_action() diff --git a/engine/src/agents/agent.h b/engine/src/agents/agent.h index 75174adc..eda0c007 100644 --- a/engine/src/agents/agent.h +++ b/engine/src/agents/agent.h @@ -65,6 +65,8 @@ class Agent : public NeuralNetAPIUser // Reference: https://github.com/dmfrodrigues/GraphViewerCpp/issues/16 condition_variable isRunningCondition; mutex isRunningMutex; + // additional boolean variable to control the condition variable + bool mustWait; bool verbose; // boolean which can be triggered by "stop" from std-in to stop the current search bool isRunning; @@ -119,6 +121,7 @@ class Agent : public NeuralNetAPIUser * @brief unlock_and_notify Unlocks the isRunningMutex and notifies all threads from the isRunningCondition variable. */ void unlock_and_notify(); + void set_must_wait(bool value); }; } diff --git a/engine/src/uci/crazyara.cpp b/engine/src/uci/crazyara.cpp index c5b100b4..3cbdeb33 100644 --- a/engine/src/uci/crazyara.cpp +++ b/engine/src/uci/crazyara.cpp @@ -139,7 +139,7 @@ void CrazyAra::uci_loop(int argc, char *argv[]) else if (token == "match") multimodel_arena(is, "", "", true); else if (token == "tournament") roundrobin(is); -#endif +#endif else cout << "Unknown command: " << cmd << endl; @@ -210,11 +210,13 @@ void CrazyAra::go(StateObj* state, istringstream &is, EvalInfo& evalInfo) if (useRawNetwork) { rawAgent->set_search_settings(state, &searchLimits, &evalInfo); + rawAgent->set_must_wait(true); mainSearchThread = thread(run_agent_thread, rawAgent.get()); rawAgent->lock_and_wait(); // wait for the agent to be initalized to allow then stopping it. } else { mctsAgent->set_search_settings(state, &searchLimits, &evalInfo); + mctsAgent->set_must_wait(true); mainSearchThread = thread(run_agent_thread, mctsAgent.get()); mctsAgent->lock_and_wait(); // wait for the agent to be initalized to allow then stopping it. } @@ -664,7 +666,7 @@ void CrazyAra::set_uci_option(istringstream &is, StateObj& state) } unique_ptr CrazyAra::create_new_mcts_agent(NeuralNetAPI* netSingle, vector>& netBatches, SearchSettings* searchSettings, MCTSAgentType type) -{ +{ switch (type) { case MCTSAgentType::kDefault: return make_unique(netSingle, netBatches, searchSettings, &playSettings); @@ -737,6 +739,7 @@ void CrazyAra::init_search_settings() } searchSettings.reuseTree = Options["Reuse_Tree"]; searchSettings.mctsSolver = Options["MCTS_Solver"]; + searchSettings.useUncertainty = Options["Use_Uncertainty"]; } void CrazyAra::init_play_settings() @@ -780,7 +783,7 @@ std::vector comb(std::vector N, int K) { if (bitmask[i]){ c.append(std::to_string(N[i])+ " "); - } + } } p.push_back(c); } while (std::prev_permutation(bitmask.begin(), bitmask.end()));