diff --git a/include/asyncpp/thread_pool.hpp b/include/asyncpp/thread_pool.hpp index fe6dc1d..a2cc78c 100644 --- a/include/asyncpp/thread_pool.hpp +++ b/include/asyncpp/thread_pool.hpp @@ -34,6 +34,7 @@ class thread_pool : public scheduler { static void schedule(schedulable_promise& item, atomic_stack& global_worklist, std::condition_variable& global_notification, + std::mutex& global_mutex, worker* local = nullptr); static schedulable_promise* steal(std::span workers); @@ -41,11 +42,13 @@ class thread_pool : public scheduler { static void execute(worker& local, atomic_stack& global_worklist, std::condition_variable& global_notification, + std::mutex& global_mutex, std::atomic_flag& terminate, std::span workers); private: std::condition_variable m_global_notification; + std::mutex m_global_mutex; atomic_stack m_global_worklist; std::vector m_workers; std::atomic_flag m_terminate; diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp index b27cc22..59dfb63 100644 --- a/src/thread_pool.cpp +++ b/src/thread_pool.cpp @@ -10,7 +10,7 @@ thread_pool::thread_pool(size_t num_threads) for (auto& w : m_workers) { w.thread = std::jthread([this, &w] { local = &w; - execute(w, m_global_worklist, m_global_notification, m_terminate, m_workers); + execute(w, m_global_worklist, m_global_notification, m_global_mutex, m_terminate, m_workers); }); } } @@ -23,13 +23,14 @@ thread_pool::~thread_pool() { void thread_pool::schedule(schedulable_promise& promise) { - schedule(promise, m_global_worklist, m_global_notification, local); + schedule(promise, m_global_worklist, m_global_notification, m_global_mutex, local); } void thread_pool::schedule(schedulable_promise& item, atomic_stack& global_worklist, std::condition_variable& global_notification, + std::mutex& global_mutex, worker* local) { if (local) { const auto prev_item = INTERLEAVED(local->worklist.push(&item)); @@ -38,6 +39,8 @@ void thread_pool::schedule(schedulable_promise& item, } } else { + std::unique_lock lk(global_mutex, std::defer_lock); + INTERLEAVED_ACQUIRE(lk.lock()); INTERLEAVED(global_worklist.push(&item)); INTERLEAVED(global_notification.notify_one()); } @@ -57,16 +60,17 @@ schedulable_promise* thread_pool::steal(std::span workers) { void thread_pool::execute(worker& local, atomic_stack& global_worklist, std::condition_variable& global_notification, + std::mutex& global_mutex, std::atomic_flag& terminate, std::span workers) { - std::mutex mtx; do { const auto item = INTERLEAVED(local.worklist.pop()); if (item != nullptr) { item->handle().resume(); } else { - std::unique_lock lk(mtx); + std::unique_lock lk(global_mutex, std::defer_lock); + INTERLEAVED_ACQUIRE(lk.lock()); global_notification.wait(lk, [&] { const auto global = INTERLEAVED(global_worklist.pop()); if (global) { diff --git a/test/test_thread_pool.cpp b/test/test_thread_pool.cpp index cf13925..1f3bb9b 100644 --- a/test/test_thread_pool.cpp +++ b/test/test_thread_pool.cpp @@ -28,18 +28,19 @@ struct test_promise : schedulable_promise { TEST_CASE("Thread pool: schedule worklist selection", "[Thread pool]") { std::condition_variable global_notification; + std::mutex global_mutex; atomic_stack global_worklist; std::vector workers(1); test_promise promise; SECTION("has local worker") { - thread_pool::schedule(promise, global_worklist, global_notification, &workers[0]); + thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, &workers[0]); REQUIRE(workers[0].worklist.pop() == &promise); REQUIRE(global_worklist.empty()); } SECTION("no local worker") { - thread_pool::schedule(promise, global_worklist, global_notification, &workers[0]); + thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, &workers[0]); REQUIRE(workers[0].worklist.pop() == &promise); } } @@ -66,6 +67,7 @@ TEST_CASE("Thread pool: ensure execution", "[Thread pool]") { struct scenario : testing::validated_scenario { std::condition_variable global_notification; + std::mutex global_mutex; atomic_stack global_worklist; std::vector workers; std::atomic_flag terminate; @@ -74,13 +76,13 @@ TEST_CASE("Thread pool: ensure execution", "[Thread pool]") { scenario() : workers(1) {} void schedule() { - thread_pool::schedule(promise, global_worklist, global_notification); + thread_pool::schedule(promise, global_worklist, global_notification, global_mutex); INTERLEAVED(terminate.test_and_set()); global_notification.notify_all(); } void execute() { - thread_pool::execute(workers[0], global_worklist, global_notification, terminate, std::span(workers)); + thread_pool::execute(workers[0], global_worklist, global_notification, global_mutex, terminate, std::span(workers)); } void validate(const testing::path& p) override {