fix: protect stop condition by mutex
petiaccja committed Feb 20, 2024
1 parent 0b0702a · commit fc4d6c2
Showing 3 changed files with 17 additions and 8 deletions.
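
For context on the fix: before this commit, thread_pool::execute() waited on the condition variable with a mutex that was local to the function, while thread_pool::schedule() pushed to the global worklist and called notify_one() without ever locking that mutex. A notification issued in the window between a worker's predicate check and its call to wait() could therefore be lost, leaving the worker asleep even though work was pending or shutdown had been requested. The commit introduces a single shared m_global_mutex so that the push-and-notify (and terminate-and-notify) sequence is serialized with the predicate check. The sketch below is a simplified, illustrative stand-in, not the pool's code: it uses a plain std::deque and a bool where the pool uses a lock-free atomic_stack and a std::atomic_flag, and the names (work_queue, stop, schedule, request_stop, worker_loop) are made up for the example.

    #include <condition_variable>
    #include <deque>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <thread>

    std::mutex global_mutex;                       // shared by scheduler and worker
    std::condition_variable global_notification;
    std::deque<std::function<void()>> work_queue;  // stand-in for the pool's atomic_stack
    bool stop = false;                             // stand-in for the terminate flag

    // Scheduler side: mutate the state the predicate reads while holding the same
    // mutex the waiter uses, then notify. Without the lock (the pre-fix pattern),
    // the notification can fire between the worker's predicate check and its
    // wait() and be lost.
    void schedule(std::function<void()> item) {
        {
            std::lock_guard lk(global_mutex);
            work_queue.push_back(std::move(item));
        }
        global_notification.notify_one();
    }

    void request_stop() {
        {
            std::lock_guard lk(global_mutex);
            stop = true;
        }
        global_notification.notify_all();
    }

    // Worker side: the predicate runs under global_mutex, so it cannot interleave
    // badly with schedule() or request_stop().
    void worker_loop() {
        for (;;) {
            std::function<void()> item;
            {
                std::unique_lock lk(global_mutex);
                global_notification.wait(lk, [] { return stop || !work_queue.empty(); });
                if (work_queue.empty())
                    return;  // stop requested and nothing left to run
                item = std::move(work_queue.front());
                work_queue.pop_front();
            }
            item();  // resume the work outside the lock
        }
    }

    int main() {
        std::jthread worker(worker_loop);
        schedule([] { std::cout << "task ran\n"; });
        request_stop();
    }

In the diff, the same role is played by the std::unique_lock lk(global_mutex, std::defer_lock); lk.lock(); added to thread_pool::schedule() and by execute() waiting on the shared global_mutex instead of a function-local one.
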
3 changes: 3 additions & 0 deletions include/asyncpp/thread_pool.hpp
@@ -34,18 +34,21 @@ class thread_pool : public scheduler {
static void schedule(schedulable_promise& item,
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next>& global_worklist,
std::condition_variable& global_notification,
std::mutex& global_mutex,
worker* local = nullptr);

static schedulable_promise* steal(std::span<worker> workers);

static void execute(worker& local,
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next>& global_worklist,
std::condition_variable& global_notification,
std::mutex& global_mutex,
std::atomic_flag& terminate,
std::span<worker> workers);

private:
std::condition_variable m_global_notification;
std::mutex m_global_mutex;
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next> m_global_worklist;
std::vector<worker> m_workers;
std::atomic_flag m_terminate;
12 changes: 8 additions & 4 deletions src/thread_pool.cpp
@@ -10,7 +10,7 @@ thread_pool::thread_pool(size_t num_threads)
for (auto& w : m_workers) {
w.thread = std::jthread([this, &w] {
local = &w;
execute(w, m_global_worklist, m_global_notification, m_terminate, m_workers);
execute(w, m_global_worklist, m_global_notification, m_global_mutex, m_terminate, m_workers);
});
}
}
@@ -23,13 +23,14 @@ thread_pool::~thread_pool() {


void thread_pool::schedule(schedulable_promise& promise) {
schedule(promise, m_global_worklist, m_global_notification, local);
schedule(promise, m_global_worklist, m_global_notification, m_global_mutex, local);
}


void thread_pool::schedule(schedulable_promise& item,
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next>& global_worklist,
std::condition_variable& global_notification,
std::mutex& global_mutex,
worker* local) {
if (local) {
const auto prev_item = INTERLEAVED(local->worklist.push(&item));
@@ -38,6 +39,8 @@ void thread_pool::schedule(schedulable_promise& item,
}
}
else {
std::unique_lock lk(global_mutex, std::defer_lock);
INTERLEAVED_ACQUIRE(lk.lock());
INTERLEAVED(global_worklist.push(&item));
INTERLEAVED(global_notification.notify_one());
}
@@ -57,16 +60,17 @@ schedulable_promise* thread_pool::steal(std::span<worker> workers) {
void thread_pool::execute(worker& local,
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next>& global_worklist,
std::condition_variable& global_notification,
std::mutex& global_mutex,
std::atomic_flag& terminate,
std::span<worker> workers) {
std::mutex mtx;
do {
const auto item = INTERLEAVED(local.worklist.pop());
if (item != nullptr) {
item->handle().resume();
}
else {
std::unique_lock lk(mtx);
std::unique_lock lk(global_mutex, std::defer_lock);
INTERLEAVED_ACQUIRE(lk.lock());
global_notification.wait(lk, [&] {
const auto global = INTERLEAVED(global_worklist.pop());
if (global) {
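
A note on the execute() hunk above: the function-local std::mutex mtx disappears because a mutex that only the waiter ever locks cannot serialize anything with the notifier; the wait now uses the shared global_mutex. The lock is built with std::defer_lock and acquired through an explicit lk.lock() call, apparently so that the INTERLEAVED_ACQUIRE hook can wrap the acquisition as a distinct step for the project's interleaving tests (an inference from the diff, not something documented here). Functionally the two forms are equivalent, as the sketch below illustrates; wait_site is a made-up name.

    #include <mutex>

    std::mutex global_mutex;

    void wait_site() {
        // Locked at construction: the acquisition happens inside the constructor,
        // so instrumentation cannot wrap it as a separate step.
        {
            std::unique_lock lk(global_mutex);
        }
        // Deferred: construction does not lock; the explicit lock() call is a
        // separate statement that a macro such as INTERLEAVED_ACQUIRE can wrap.
        std::unique_lock lk(global_mutex, std::defer_lock);
        lk.lock();
    }   // lk releases the mutex on destruction

    int main() {
        wait_site();
    }
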
10 changes: 6 additions & 4 deletions test/test_thread_pool.cpp
@@ -28,18 +28,19 @@ struct test_promise : schedulable_promise {

TEST_CASE("Thread pool: schedule worklist selection", "[Thread pool]") {
std::condition_variable global_notification;
std::mutex global_mutex;
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next> global_worklist;
std::vector<thread_pool::worker> workers(1);

test_promise promise;

SECTION("has local worker") {
thread_pool::schedule(promise, global_worklist, global_notification, &workers[0]);
thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, &workers[0]);
REQUIRE(workers[0].worklist.pop() == &promise);
REQUIRE(global_worklist.empty());
}
SECTION("no local worker") {
thread_pool::schedule(promise, global_worklist, global_notification, &workers[0]);
thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, &workers[0]);
REQUIRE(workers[0].worklist.pop() == &promise);
}
}
@@ -66,6 +67,7 @@ TEST_CASE("Thread pool: ensure execution", "[Thread pool]") {

struct scenario : testing::validated_scenario {
std::condition_variable global_notification;
std::mutex global_mutex;
atomic_stack<schedulable_promise, &schedulable_promise::m_scheduler_next> global_worklist;
std::vector<thread_pool::worker> workers;
std::atomic_flag terminate;
@@ -74,13 +76,13 @@ TEST_CASE("Thread pool: ensure execution", "[Thread pool]") {
scenario() : workers(1) {}

void schedule() {
thread_pool::schedule(promise, global_worklist, global_notification);
thread_pool::schedule(promise, global_worklist, global_notification, global_mutex);
INTERLEAVED(terminate.test_and_set());
global_notification.notify_all();
}

void execute() {
thread_pool::execute(workers[0], global_worklist, global_notification, terminate, std::span(workers));
thread_pool::execute(workers[0], global_worklist, global_notification, global_mutex, terminate, std::span(workers));
}

void validate(const testing::path& p) override {
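
The test changes mirror the new signatures: the scenario now owns a global_mutex and hands the same object to both schedule() and execute(), so the validated interleavings cover the window in which terminate is set and notify_all() fires while the worker sits between its predicate check and its wait. Below is a rough plain-standard-library analogue of that scenario, not the project's test: it uses std::thread, a std::vector and a bool in place of atomic_stack, schedulable_promise and std::atomic_flag, has no INTERLEAVED instrumentation, and all names are illustrative. It repeatedly races a scheduling thread against a waiting worker and checks that the item always runs and the worker always terminates.

    #include <cassert>
    #include <condition_variable>
    #include <mutex>
    #include <thread>
    #include <vector>

    int main() {
        for (int iteration = 0; iteration < 1000; ++iteration) {
            std::mutex global_mutex;
            std::condition_variable global_notification;
            std::vector<int> worklist;   // stand-in for the promise stack
            bool terminate = false;
            bool executed = false;

            // Worker: analogous to thread_pool::execute() after the fix -- the
            // predicate is evaluated under the same mutex the scheduler locks.
            std::thread worker([&] {
                for (;;) {
                    std::unique_lock lk(global_mutex);
                    global_notification.wait(lk, [&] { return terminate || !worklist.empty(); });
                    if (!worklist.empty()) {
                        worklist.pop_back();
                        executed = true;
                        continue;
                    }
                    if (terminate)
                        return;
                }
            });

            // Scheduler: analogous to scenario::schedule() in the test -- push an
            // item and notify, then request termination and notify, each change
            // made while holding global_mutex so the worker cannot miss it.
            {
                std::lock_guard lk(global_mutex);
                worklist.push_back(1);
            }
            global_notification.notify_one();
            {
                std::lock_guard lk(global_mutex);
                terminate = true;
            }
            global_notification.notify_all();

            worker.join();
            assert(executed && "the scheduled item must run before shutdown");
        }
    }
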
