Skip to content

Commit

Permalink
Add mustWait boolean to control the condition variable
Browse files Browse the repository at this point in the history
  • Loading branch information
QueensGambit committed May 16, 2023
1 parent 9ca5001 commit 9f0f077
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
14 changes: 11 additions & 3 deletions engine/src/agents/agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ void Agent::set_best_move(size_t moveCounter)
}
}

void Agent::set_must_wait(bool value)
{
mustWait = value;
}

Agent::Agent(NeuralNetAPI* net, PlaySettings* playSettings, bool verbose):
NeuralNetAPIUser(net),
playSettings(playSettings), verbose(verbose), isRunning(false)
playSettings(playSettings), mustWait(true), verbose(verbose), isRunning(false)
{
}

Expand All @@ -73,14 +78,17 @@ Action Agent::get_best_action()
void Agent::lock_and_wait()
{
unique_lock<mutex> lock(isRunningMutex);
isRunningCondition.wait(lock);
while(mustWait) {
isRunningCondition.wait(lock);
}
}

void Agent::unlock_and_notify()
{
// std::lock_guard is deprecated in C++17, therefore we use scoped_lock instead
scoped_lock<mutex> lock(isRunningMutex);
isRunningCondition.notify_all();
mustWait = false;
isRunningCondition.notify_one();
}

void Agent::perform_action()
Expand Down
3 changes: 3 additions & 0 deletions engine/src/agents/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class Agent : public NeuralNetAPIUser
// Reference: https://github.com/dmfrodrigues/GraphViewerCpp/issues/16
condition_variable isRunningCondition;
mutex isRunningMutex;
// additional boolean variable to control the condition variable
bool mustWait;
bool verbose;
// boolean which can be triggered by "stop" from std-in to stop the current search
bool isRunning;
Expand Down Expand Up @@ -119,6 +121,7 @@ class Agent : public NeuralNetAPIUser
* @brief unlock_and_notify Unlocks the isRunningMutex and notifies all threads from the isRunningCondition variable.
*/
void unlock_and_notify();
void set_must_wait(bool value);
};
}

Expand Down
9 changes: 6 additions & 3 deletions engine/src/uci/crazyara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void CrazyAra::uci_loop(int argc, char *argv[])

else if (token == "match") multimodel_arena(is, "", "", true);
else if (token == "tournament") roundrobin(is);
#endif
#endif
else
cout << "Unknown command: " << cmd << endl;

Expand Down Expand Up @@ -210,11 +210,13 @@ void CrazyAra::go(StateObj* state, istringstream &is, EvalInfo& evalInfo)

if (useRawNetwork) {
rawAgent->set_search_settings(state, &searchLimits, &evalInfo);
rawAgent->set_must_wait(true);
mainSearchThread = thread(run_agent_thread, rawAgent.get());
rawAgent->lock_and_wait(); // wait for the agent to be initalized to allow then stopping it.
}
else {
mctsAgent->set_search_settings(state, &searchLimits, &evalInfo);
mctsAgent->set_must_wait(true);
mainSearchThread = thread(run_agent_thread, mctsAgent.get());
mctsAgent->lock_and_wait(); // wait for the agent to be initalized to allow then stopping it.
}
Expand Down Expand Up @@ -664,7 +666,7 @@ void CrazyAra::set_uci_option(istringstream &is, StateObj& state)
}

unique_ptr<MCTSAgent> CrazyAra::create_new_mcts_agent(NeuralNetAPI* netSingle, vector<unique_ptr<NeuralNetAPI>>& netBatches, SearchSettings* searchSettings, MCTSAgentType type)
{
{
switch (type) {
case MCTSAgentType::kDefault:
return make_unique<MCTSAgent>(netSingle, netBatches, searchSettings, &playSettings);
Expand Down Expand Up @@ -737,6 +739,7 @@ void CrazyAra::init_search_settings()
}
searchSettings.reuseTree = Options["Reuse_Tree"];
searchSettings.mctsSolver = Options["MCTS_Solver"];
searchSettings.useUncertainty = Options["Use_Uncertainty"];
}

void CrazyAra::init_play_settings()
Expand Down Expand Up @@ -780,7 +783,7 @@ std::vector<std::string> comb(std::vector<int> N, int K)
{
if (bitmask[i]){
c.append(std::to_string(N[i])+ " ");
}
}
}
p.push_back(c);
} while (std::prev_permutation(bitmask.begin(), bitmask.end()));
Expand Down

0 comments on commit 9f0f077

Please sign in to comment.