diff --git a/benchmark/benchmark_atomic.cpp b/benchmark/benchmark_atomic.cpp
index d21f3b8..0248e55 100644
--- a/benchmark/benchmark_atomic.cpp
+++ b/benchmark/benchmark_atomic.cpp
@@ -1,3 +1,5 @@
+#include <asyncpp/threading/cache.hpp>
+
 #include <celero/Celero.h>
 #include <atomic>
 #include <thread>
@@ -10,11 +12,14 @@
 // achievable and reasonable on the hardware.
 
+using namespace asyncpp;
+
+
 static constexpr size_t base_reps = 4'000'000;
 
 
 BASELINE(atomic_rmw, x1_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps;
     static const auto func = [&counter] {
@@ -29,7 +34,7 @@
 
 
 BENCHMARK(atomic_rmw, x2_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps / 2;
     static const auto func = [&counter] {
@@ -44,7 +49,7 @@
 
 
 BENCHMARK(atomic_rmw, x4_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps / 4;
     static const auto func = [&counter] {
@@ -59,7 +64,7 @@
 
 
 BENCHMARK(atomic_rmw, x8_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps / 8;
     static const auto func = [&counter] {
@@ -74,7 +79,7 @@
 
 
 BASELINE(atomic_read, x1_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps;
     static const auto func = [&counter] {
@@ -89,7 +94,7 @@
 
 
 BENCHMARK(atomic_read, x2_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps / 2;
     static const auto func = [&counter] {
@@ -104,7 +109,7 @@
 
 
 BENCHMARK(atomic_read, x4_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps / 4;
     static const auto func = [&counter] {
@@ -119,7 +124,7 @@
 
 
 BENCHMARK(atomic_read, x8_thread, 30, 1) {
-    alignas(64) std::atomic_size_t counter = 0;
+    alignas(avoid_false_sharing) std::atomic_size_t counter = 0;
     static constexpr size_t reps = base_reps / 8;
     static const auto func = [&counter] {
diff --git a/benchmark/benchmark_task_spawn.cpp b/benchmark/benchmark_task_spawn.cpp
index cdf315e..444f550 100644
--- a/benchmark/benchmark_task_spawn.cpp
+++ b/benchmark/benchmark_task_spawn.cpp
@@ -1,5 +1,6 @@
 #include <asyncpp/task.hpp>
 #include <asyncpp/thread_pool.hpp>
+#include <asyncpp/threading/cache.hpp>
 #include <celero/Celero.h>
 #include <memory_resource>
@@ -56,7 +57,7 @@ struct FixtureStack : celero::TestFixture {
 private:
     static constexpr inline size_t size = 10485760;
     struct block {
-        alignas(64) std::byte content[64];
+        alignas(avoid_false_sharing) std::byte content[avoid_false_sharing];
    };
     std::unique_ptr<block[]> buffer = std::make_unique_for_overwrite<block[]>(size / sizeof(block));
     std::pmr::monotonic_buffer_resource resource;
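The atomic benchmarks divide `base_reps` by the thread count, so every variant performs the same total number of operations and the reported time directly exposes contention overhead on the shared counter. A minimal standalone sketch of the same measurement — not part of the patch, using plain `std::chrono` instead of Celero, with a hard-coded 64 standing in for `asyncpp::avoid_false_sharing`:

```cpp
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    constexpr size_t base_reps = 4'000'000;
    for (size_t num_threads : {1, 2, 4, 8}) {
        alignas(64) std::atomic_size_t counter{0};
        const size_t reps = base_reps / num_threads; // same total work per variant
        const auto start = std::chrono::steady_clock::now();
        {
            std::vector<std::jthread> threads;
            for (size_t t = 0; t < num_threads; ++t) {
                threads.emplace_back([&counter, reps] {
                    for (size_t i = 0; i < reps; ++i) {
                        counter.fetch_add(1, std::memory_order_relaxed);
                    }
                });
            }
        } // jthreads join here
        const std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
        std::printf("%zu thread(s): %.1f ns/op aggregate\n",
                    num_threads, 1e9 * elapsed.count() / base_reps);
    }
}
```

On typical hardware the aggregate cost per operation rises with thread count, since every `fetch_add` bounces the counter's cache line between cores — which is exactly the ceiling these baselines are meant to establish.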
diff --git a/benchmark/benchmark_thread_pool.cpp b/benchmark/benchmark_thread_pool.cpp
index 5d09b0a..14b586f 100644
--- a/benchmark/benchmark_thread_pool.cpp
+++ b/benchmark/benchmark_thread_pool.cpp
@@ -40,7 +40,7 @@ struct FixturePool : celero::TestFixture {
         sync();
     }
 
-    thread_pool_3 pool;
+    thread_pool pool;
     std::atomic_size_t num_running;
     std::vector promises;
 };
@@ -68,7 +68,7 @@
 BASELINE_F(tp_schedule_inside, x1_thread, FixturePool<1>, 30, 1) {
-    static constexpr auto coro = [](std::span promises, thread_pool_3& pool) -> task<void> {
-        for (auto& promise : promises) {
-            pool.schedule(promise);
-        }
+    static constexpr auto coro = [](std::span promises, thread_pool& pool) -> task<void> {
+        for (auto& promise : promises) {
+            pool.schedule(promise);
+        }
@@ -80,7 +80,7 @@
 BENCHMARK_F(tp_schedule_inside, x2_thread, FixturePool<2>, 30, 1) {
-    static constexpr auto coro = [](std::span promises, thread_pool_3& pool) -> task<void> {
-        for (auto& promise : promises) {
-            pool.schedule(promise);
-        }
+    static constexpr auto coro = [](std::span promises, thread_pool& pool) -> task<void> {
+        for (auto& promise : promises) {
+            pool.schedule(promise);
+        }
@@ -95,7 +95,7 @@
 BENCHMARK_F(tp_schedule_inside, x4_thread, FixturePool<4>, 30, 1) {
-    static constexpr auto coro = [](std::span promises, thread_pool_3& pool) -> task<void> {
-        for (auto& promise : promises) {
-            pool.schedule(promise);
-        }
+    static constexpr auto coro = [](std::span promises, thread_pool& pool) -> task<void> {
+        for (auto& promise : promises) {
+            pool.schedule(promise);
+        }
@@ -114,7 +114,7 @@
 BASELINE_F(tp_stealing, x1_thread, FixturePool<1 + 1>, 30, 1) {
-    static constexpr auto coro = [](std::span promises, thread_pool_3& pool) -> task<void> {
-        for (auto& promise : promises) {
-            pool.schedule(promise);
-        }
+    static constexpr auto coro = [](std::span promises, thread_pool& pool) -> task<void> {
+        for (auto& promise : promises) {
+            pool.schedule(promise);
+        }
@@ -125,7 +125,7 @@
 BENCHMARK_F(tp_stealing, x2_thread, FixturePool<2 + 1>, 30, 1) {
-    static constexpr auto coro = [](std::span promises, thread_pool_3& pool) -> task<void> {
-        for (auto& promise : promises) {
-            pool.schedule(promise);
-        }
+    static constexpr auto coro = [](std::span promises, thread_pool& pool) -> task<void> {
+        for (auto& promise : promises) {
+            pool.schedule(promise);
+        }
@@ -136,7 +136,7 @@
 BENCHMARK_F(tp_stealing, x4_thread, FixturePool<4 + 1>, 30, 1) {
-    static constexpr auto coro = [](std::span promises, thread_pool_3& pool) -> task<void> {
-        for (auto& promise : promises) {
-            pool.schedule(promise);
-        }
+    static constexpr auto coro = [](std::span promises, thread_pool& pool) -> task<void> {
+        for (auto& promise : promises) {
+            pool.schedule(promise);
+        }
diff --git a/include/asyncpp/thread_pool.hpp b/include/asyncpp/thread_pool.hpp
index 72ac74d..e4cd187 100644
--- a/include/asyncpp/thread_pool.hpp
+++ b/include/asyncpp/thread_pool.hpp
@@ -4,19 +4,19 @@
 #include "container/atomic_deque.hpp"
 #include "container/atomic_stack.hpp"
 #include "scheduler.hpp"
+#include "threading/cache.hpp"
+#include "threading/spinlock.hpp"
 
 #include <atomic>
 #include <condition_variable>
 #include <cstddef>
-#include <mutex>
 #include <semaphore>
 #include <thread>
 
 
 namespace asyncpp {
 
-
-class thread_pool_3 : public scheduler {
+class thread_pool : public scheduler {
 public:
     struct pack;
 
@@ -24,194 +24,46 @@ class thread_pool_3 : public scheduler {
     public:
         using queue = deque;
 
-        worker()
-            : m_sema(0) {}
-
-        ~worker() {
-            cancel();
-        }
-
-        void insert(schedulable_promise& promise) {
-            std::unique_lock lk(m_spinlock, std::defer_lock);
-            INTERLEAVED_ACQUIRE(lk.lock());
-            const auto previous = m_promises.push_back(&promise);
-            const auto blocked = m_blocked.test(std::memory_order_relaxed);
-            INTERLEAVED(lk.unlock());
-
-            if (!previous && blocked) {
-                INTERLEAVED(m_sema.release());
-            }
-        }
-
-        schedulable_promise* steal_from_this() {
-            std::unique_lock lk(m_spinlock, std::defer_lock);
-            INTERLEAVED(lk.lock());
-            return m_promises.pop_front();
-        }
-
-        schedulable_promise* try_get_promise(pack& pack, size_t& stealing_attempt, bool& exit_loop) {
-            std::unique_lock lk(m_spinlock, std::defer_lock);
-            INTERLEAVED_ACQUIRE(lk.lock());
-            const auto promise = m_promises.front();
-            if (promise) {
-                m_promises.pop_front();
-                return promise;
-            }
-
-            if (stealing_attempt > 0) {
-                INTERLEAVED(lk.unlock());
-                const auto stolen = steal_from_other(pack, stealing_attempt);
-                stealing_attempt = stolen ? pack.workers.size() : stealing_attempt - 1;
-                return stolen;
-            }
-
-            if (INTERLEAVED(m_cancelled.test(std::memory_order_relaxed))) {
-                exit_loop = true;
-            }
-            else {
-                INTERLEAVED(m_blocked.test_and_set(std::memory_order_relaxed));
-                pack.blocked.push(this);
-                pack.num_blocked.fetch_add(1, std::memory_order_relaxed);
-                INTERLEAVED(lk.unlock());
-                INTERLEAVED_ACQUIRE(m_sema.acquire());
-                INTERLEAVED(m_blocked.clear());
-                stealing_attempt = pack.workers.size();
-            }
-            return nullptr;
-        }
-
-        schedulable_promise* steal_from_other(pack& pack, size_t& stealing_attempt) const {
-            const size_t pack_size = pack.workers.size();
-            const size_t my_index = this - pack.workers.data();
-            const size_t victim_index = (my_index + stealing_attempt) % pack_size;
-            return pack.workers[victim_index].steal_from_this();
-        }
-
-        void start(pack& pack) {
-            m_thread = std::jthread([this, &pack] {
-                run(pack);
-            });
-        }
-
-        void cancel() {
-            std::unique_lock lk(m_spinlock, std::defer_lock);
-            INTERLEAVED_ACQUIRE(lk.lock());
-            INTERLEAVED(m_cancelled.test_and_set(std::memory_order_relaxed));
-            const auto blocked = INTERLEAVED(m_blocked.test(std::memory_order_relaxed));
-            lk.unlock();
-            if (blocked) {
-                INTERLEAVED(m_sema.release());
-            }
-        }
+        worker();
+        ~worker();
+
+        void insert(schedulable_promise& promise);
+        schedulable_promise* steal_from_this();
+        schedulable_promise* try_get_promise(pack& pack, size_t& stealing_attempt, bool& exit_loop);
+        schedulable_promise* steal_from_other(pack& pack, size_t& stealing_attempt) const;
+        void start(pack& pack);
+        void cancel();
 
     private:
-        void run(pack& pack) {
-            m_local = this;
-            size_t stealing_attempt = pack.workers.size();
-            bool exit_loop = false;
-            while (!exit_loop) {
-                const auto promise = try_get_promise(pack, stealing_attempt, exit_loop);
-                if (promise) {
-                    promise->resume_now();
-                }
-            }
-        }
+        void run(pack& pack);
 
     public:
         worker* m_next = nullptr;
 
     private:
-        alignas(64) spinlock m_spinlock;
-        alignas(64) queue m_promises;
-        alignas(64) std::atomic_flag m_blocked;
-        alignas(64) std::binary_semaphore m_sema;
-        alignas(64) std::jthread m_thread;
-        alignas(64) std::atomic_flag m_cancelled;
+        alignas(avoid_false_sharing) spinlock m_spinlock;
+        alignas(avoid_false_sharing) queue m_promises;
+        alignas(avoid_false_sharing) std::atomic_flag m_blocked;
+        alignas(avoid_false_sharing) std::binary_semaphore m_sema;
+        alignas(avoid_false_sharing) std::jthread m_thread;
+        alignas(avoid_false_sharing) std::atomic_flag m_cancelled;
     };
 
     struct pack {
-        alignas(64) std::vector<worker> workers;
-        alignas(64) atomic_stack<worker, &worker::m_next> blocked;
-        alignas(64) std::atomic_size_t num_blocked = 0;
+        alignas(avoid_false_sharing) std::vector<worker> workers;
+        alignas(avoid_false_sharing) atomic_stack<worker, &worker::m_next> blocked;
+        alignas(avoid_false_sharing) std::atomic_size_t num_blocked = 0;
     };
 
 public:
-    thread_pool_3(size_t num_threads = 1)
-        : m_pack(std::vector<worker>(num_threads)), m_next_in_schedule(0) {
-        for (auto& worker : m_pack.workers) {
-            worker.start(m_pack);
-        }
-    }
-
-    void schedule(schedulable_promise& promise) override {
-        size_t num_blocked = m_pack.num_blocked.load(std::memory_order_relaxed);
-        const auto blocked = num_blocked > 0 ? m_pack.blocked.pop() : nullptr;
-        if (blocked) {
-            blocked->insert(promise);
-            m_pack.num_blocked.fetch_sub(1, std::memory_order_relaxed);
-        }
-        else if (m_local) {
-            m_local->insert(promise);
-        }
-        else {
-            const auto selected = m_next_in_schedule.fetch_add(1, std::memory_order_relaxed) % m_pack.workers.size();
-            m_pack.workers[selected].insert(promise);
-        }
-    }
-
-private:
-    alignas(64) pack m_pack;
-    alignas(64) std::atomic_ptrdiff_t m_next_in_schedule;
-    inline static thread_local worker* m_local = nullptr;
-};
-
-
-class thread_pool : public scheduler {
-public:
-    struct worker {
-        worker* m_next = nullptr;
-
-        std::jthread thread;
-        atomic_stack worklist;
-    };
-
     thread_pool(size_t num_threads = 1);
     thread_pool(thread_pool&&) = delete;
     thread_pool operator=(thread_pool&&) = delete;
     ~thread_pool();
 
     void schedule(schedulable_promise& promise) override;
 
-    static void schedule(schedulable_promise& item,
-                         atomic_stack& global_worklist,
-                         std::condition_variable& global_notification,
-                         std::mutex& global_mutex,
-                         std::atomic_size_t& num_waiting,
-                         worker* local = nullptr);
-
-    static schedulable_promise* steal(std::span<worker> workers);
-
-    static void execute(worker& local,
-                        atomic_stack& global_worklist,
-                        std::condition_variable& global_notification,
-                        std::mutex& global_mutex,
-                        std::atomic_flag& terminate,
-                        std::atomic_size_t& num_waiting,
-                        std::span<worker> workers);
-
 private:
-    std::condition_variable m_global_notification;
-    std::mutex m_global_mutex;
-    atomic_stack m_global_worklist;
-    std::vector<worker> m_workers;
-    std::atomic_flag m_terminate;
-    std::atomic_size_t m_num_waiting = 0;
-
-    inline static thread_local worker* local = nullptr;
+    alignas(avoid_false_sharing) pack m_pack;
+    alignas(avoid_false_sharing) std::atomic_ptrdiff_t m_next_in_schedule;
+    inline static thread_local worker* m_local = nullptr;
 };
 
-
 } // namespace asyncpp
\ No newline at end of file
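The retained scheduler steals by rotation: `steal_from_other` picks its victim as `(my_index + stealing_attempt) % pack_size`, and `try_get_promise` counts `stealing_attempt` down from `pack.workers.size()`, resetting it after every successful steal. A small illustration (not library code) of the probe order this arithmetic produces — the first probe at `stealing_attempt == pack_size` lands back on the worker itself, after which every peer is visited once before the worker decides to block:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    constexpr std::size_t pack_size = 4;
    constexpr std::size_t my_index = 1;
    // Mirrors the countdown in try_get_promise when every steal fails.
    for (std::size_t stealing_attempt = pack_size; stealing_attempt > 0; --stealing_attempt) {
        const std::size_t victim_index = (my_index + stealing_attempt) % pack_size;
        std::printf("attempt %zu -> victim %zu\n", stealing_attempt, victim_index);
    }
}
```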
diff --git a/include/asyncpp/threading/cache.hpp b/include/asyncpp/threading/cache.hpp
new file mode 100644
index 0000000..8adbc2f
--- /dev/null
+++ b/include/asyncpp/threading/cache.hpp
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <cstddef>
+#include <new>
+
+namespace asyncpp {
+
+#ifdef __cpp_lib_hardware_interference_size
+inline constexpr size_t avoid_false_sharing = std::hardware_destructive_interference_size;
+inline constexpr size_t promote_true_sharing = std::hardware_constructive_interference_size;
+#else
+inline constexpr size_t avoid_false_sharing = 64;
+inline constexpr size_t promote_true_sharing = 64;
+#endif
+
+} // namespace asyncpp
\ No newline at end of file
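Note that `<new>` is needed here in any case: it is the header that defines both the `std::hardware_*_interference_size` constants and the `__cpp_lib_hardware_interference_size` feature-test macro, so without it the `#ifdef` would always fall through to the 64-byte default. A usage sketch for the new constants (the include path follows the benchmark diffs above; the struct itself is hypothetical): padding each hot member to `avoid_false_sharing` gives every writer its own cache line, which is what the former hard-coded `alignas(64)` literals approximated.

```cpp
#include <asyncpp/threading/cache.hpp>

#include <atomic>
#include <thread>

struct padded_counters {
    // Each counter occupies its own destructive-interference region.
    alignas(asyncpp::avoid_false_sharing) std::atomic_size_t a{0};
    alignas(asyncpp::avoid_false_sharing) std::atomic_size_t b{0};
};

static_assert(sizeof(padded_counters) >= 2 * asyncpp::avoid_false_sharing);

int main() {
    padded_counters counters;
    // The two writers never invalidate each other's cache line.
    std::jthread t1([&] { for (int i = 0; i < 1'000'000; ++i) counters.a.fetch_add(1, std::memory_order_relaxed); });
    std::jthread t2([&] { for (int i = 0; i < 1'000'000; ++i) counters.b.fetch_add(1, std::memory_order_relaxed); });
}
```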
diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
index 26f6f95..b786ce4 100644
--- a/src/thread_pool.cpp
+++ b/src/thread_pool.cpp
@@ -5,92 +5,130 @@
 namespace asyncpp {
 
-thread_pool::thread_pool(size_t num_threads)
-    : m_workers(num_threads) {
-    for (auto& w : m_workers) {
-        w.thread = std::jthread([this, &w] {
-            local = &w;
-            execute(w, m_global_worklist, m_global_notification, m_global_mutex, m_terminate, m_num_waiting, m_workers);
-        });
-    }
+thread_pool::worker::worker()
+    : m_sema(0) {}
+
+
+thread_pool::worker::~worker() {
+    cancel();
 }
 
 
-thread_pool::~thread_pool() {
-    std::lock_guard lk(m_global_mutex);
-    m_terminate.test_and_set();
-    m_global_notification.notify_all();
+void thread_pool::worker::insert(schedulable_promise& promise) {
+    std::unique_lock lk(m_spinlock, std::defer_lock);
+    INTERLEAVED_ACQUIRE(lk.lock());
+    const auto previous = m_promises.push_back(&promise);
+    const auto blocked = m_blocked.test(std::memory_order_relaxed);
+    INTERLEAVED(lk.unlock());
+
+    if (!previous && blocked) {
+        INTERLEAVED(m_sema.release());
+    }
 }
 
 
-void thread_pool::schedule(schedulable_promise& promise) {
-    schedule(promise, m_global_worklist, m_global_notification, m_global_mutex, m_num_waiting, local);
+schedulable_promise* thread_pool::worker::steal_from_this() {
+    std::unique_lock lk(m_spinlock, std::defer_lock);
+    INTERLEAVED(lk.lock());
+    return m_promises.pop_front();
 }
 
 
-void thread_pool::schedule(schedulable_promise& item,
-                           atomic_stack& global_worklist,
-                           std::condition_variable& global_notification,
-                           std::mutex& global_mutex,
-                           std::atomic_size_t& num_waiting,
-                           worker* local) {
-    if (local) {
-        const auto prev_item = INTERLEAVED(local->worklist.push(&item));
-        if (prev_item != nullptr) {
-            if (num_waiting.load(std::memory_order_relaxed) > 0) {
-                global_notification.notify_one();
-            }
-        }
+schedulable_promise* thread_pool::worker::try_get_promise(pack& pack, size_t& stealing_attempt, bool& exit_loop) {
+    std::unique_lock lk(m_spinlock, std::defer_lock);
+    INTERLEAVED_ACQUIRE(lk.lock());
+    const auto promise = m_promises.front();
+    if (promise) {
+        m_promises.pop_front();
+        return promise;
+    }
+
+    if (stealing_attempt > 0) {
+        INTERLEAVED(lk.unlock());
+        const auto stolen = steal_from_other(pack, stealing_attempt);
+        stealing_attempt = stolen ? pack.workers.size() : stealing_attempt - 1;
+        return stolen;
+    }
+
+    if (INTERLEAVED(m_cancelled.test(std::memory_order_relaxed))) {
+        exit_loop = true;
     }
     else {
-        std::unique_lock lk(global_mutex, std::defer_lock);
-        INTERLEAVED_ACQUIRE(lk.lock());
-        INTERLEAVED(global_worklist.push(&item));
-        INTERLEAVED(global_notification.notify_one());
+        INTERLEAVED(m_blocked.test_and_set(std::memory_order_relaxed));
+        pack.blocked.push(this);
+        pack.num_blocked.fetch_add(1, std::memory_order_relaxed);
+        INTERLEAVED(lk.unlock());
+        INTERLEAVED_ACQUIRE(m_sema.acquire());
+        INTERLEAVED(m_blocked.clear());
+        stealing_attempt = pack.workers.size();
     }
+    return nullptr;
 }
 
 
-schedulable_promise* thread_pool::steal(std::span<worker> workers) {
-    for (auto& w : workers) {
-        if (const auto item = INTERLEAVED(w.worklist.pop())) {
-            return item;
-        }
-    }
-    return nullptr;
+schedulable_promise* thread_pool::worker::steal_from_other(pack& pack, size_t& stealing_attempt) const {
+    const size_t pack_size = pack.workers.size();
+    const size_t my_index = this - pack.workers.data();
+    const size_t victim_index = (my_index + stealing_attempt) % pack_size;
+    return pack.workers[victim_index].steal_from_this();
+}
+
+
+void thread_pool::worker::start(pack& pack) {
+    m_thread = std::jthread([this, &pack] {
+        run(pack);
+    });
+}
+
+
+void thread_pool::worker::cancel() {
+    std::unique_lock lk(m_spinlock, std::defer_lock);
+    INTERLEAVED_ACQUIRE(lk.lock());
+    INTERLEAVED(m_cancelled.test_and_set(std::memory_order_relaxed));
+    const auto blocked = INTERLEAVED(m_blocked.test(std::memory_order_relaxed));
+    lk.unlock();
+    if (blocked) {
+        INTERLEAVED(m_sema.release());
+    }
 }
 
 
-void thread_pool::execute(worker& local,
-                          atomic_stack& global_worklist,
-                          std::condition_variable& global_notification,
-                          std::mutex& global_mutex,
-                          std::atomic_flag& terminate,
-                          std::atomic_size_t& num_waiting,
-                          std::span<worker> workers) {
-    do {
-        if (const auto item = INTERLEAVED(local.worklist.pop())) {
-            item->resume_now();
-            continue;
-        }
-        else if (const auto item = INTERLEAVED(global_worklist.pop())) {
-            local.worklist.push(item);
-            continue;
-        }
-        else if (const auto item = steal(workers)) {
-            local.worklist.push(item);
-            continue;
-        }
-        else {
-            std::unique_lock lk(global_mutex, std::defer_lock);
-            INTERLEAVED_ACQUIRE(lk.lock());
-            if (!INTERLEAVED(terminate.test()) && INTERLEAVED(global_worklist.empty())) {
-                num_waiting.fetch_add(1, std::memory_order_relaxed);
-                INTERLEAVED_ACQUIRE(global_notification.wait(lk));
-                num_waiting.fetch_sub(1, std::memory_order_relaxed);
-            }
-        }
-    } while (!INTERLEAVED(terminate.test()));
+void thread_pool::worker::run(pack& pack) {
+    m_local = this;
+    size_t stealing_attempt = pack.workers.size();
+    bool exit_loop = false;
+    while (!exit_loop) {
+        const auto promise = try_get_promise(pack, stealing_attempt, exit_loop);
+        if (promise) {
+            promise->resume_now();
+        }
+    }
+}
+
+
+thread_pool::thread_pool(size_t num_threads)
+    : m_pack(std::vector<worker>(num_threads)), m_next_in_schedule(0) {
+    for (auto& worker : m_pack.workers) {
+        worker.start(m_pack);
+    }
+}
+
+
+void thread_pool::schedule(schedulable_promise& promise) {
+    size_t num_blocked = m_pack.num_blocked.load(std::memory_order_relaxed);
+    const auto blocked = num_blocked > 0 ? m_pack.blocked.pop() : nullptr;
+    if (blocked) {
+        blocked->insert(promise);
+        m_pack.num_blocked.fetch_sub(1, std::memory_order_relaxed);
+    }
+    else if (m_local) {
+        m_local->insert(promise);
+    }
+    else {
+        const auto selected = m_next_in_schedule.fetch_add(1, std::memory_order_relaxed) % m_pack.workers.size();
+        m_pack.workers[selected].insert(promise);
+    }
 }
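The sleep/wake protocol these definitions implement is worth spelling out: `insert` releases the semaphore only when the deque transitions from empty (`!previous`) *and* the owner has published `m_blocked`, while `try_get_promise` sets the flag and re-checks the queue under the same spinlock before sleeping. Because both sides touch the flag and the queue under one lock, a producer either observes the sleeper and posts the semaphore, or the sleeper observes the new item and never blocks — no wakeup is lost. A reduced standalone model of that handshake (illustration only; `std::mutex` and `std::deque<int>` stand in for the project's spinlock and intrusive deque):

```cpp
#include <atomic>
#include <deque>
#include <mutex>
#include <semaphore>
#include <thread>

std::mutex lock;               // stands in for worker::m_spinlock
std::deque<int> queue;         // stands in for worker::m_promises
std::atomic_flag blocked;      // consumer is (about to go) asleep
std::binary_semaphore sema{0}; // wakes the blocked consumer

void insert(int item) {
    std::unique_lock lk(lock);
    const bool was_empty = queue.empty(); // mirrors `!previous`
    queue.push_back(item);
    const bool sleeper = blocked.test(std::memory_order_relaxed);
    lk.unlock();
    if (was_empty && sleeper) {
        sema.release();
    }
}

int take() {
    for (;;) {
        std::unique_lock lk(lock);
        if (!queue.empty()) {
            const int item = queue.front();
            queue.pop_front();
            return item;
        }
        blocked.test_and_set(std::memory_order_relaxed); // publish intent under the lock
        lk.unlock();
        sema.acquire(); // sleep until insert() posts
        blocked.clear();
    }
}

int main() {
    std::jthread consumer([] { (void)take(); });
    insert(42);
}
```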
diff --git a/test/test_thread_pool.cpp b/test/test_thread_pool.cpp
index 144cb8d..6f55c58 100644
--- a/test/test_thread_pool.cpp
+++ b/test/test_thread_pool.cpp
@@ -30,7 +30,7 @@ struct test_promise : schedulable_promise {
 
 TEST_CASE("Thread pool 3: insert - try_get_promise interleave", "[Thread pool 3]") {
     struct scenario : testing::validated_scenario {
-        thread_pool_3::pack pack{ .workers = std::vector<thread_pool_3::worker>(1) };
+        thread_pool::pack pack{ .workers = std::vector<thread_pool::worker>(1) };
         test_promise promise;
         bool exit_loop = false;
         size_t stealing_attempt = 0;
@@ -61,7 +61,7 @@ TEST_CASE("Thread pool 3: insert - try_get_promise interleave", "[Thread pool 3]") {
 
 TEST_CASE("Thread pool 3: cancel - try_get_promise interleave", "[Thread pool 3]") {
     struct scenario : testing::validated_scenario {
-        thread_pool_3::pack pack{ .workers = std::vector<thread_pool_3::worker>(1) };
+        thread_pool::pack pack{ .workers = std::vector<thread_pool::worker>(1) };
         bool exit_loop = false;
         size_t stealing_attempt = 0;
         schedulable_promise* result = nullptr;
@@ -91,7 +91,7 @@ TEST_CASE("Thread pool 3: cancel - try_get_promise interleave", "[Thread pool 3]") {
 
 TEST_CASE("Thread pool 3: steal - try_get_promise interleave", "[Thread pool 3]") {
     struct scenario : testing::validated_scenario {
-        thread_pool_3::pack pack{ .workers = std::vector<thread_pool_3::worker>(1) };
+        thread_pool::pack pack{ .workers = std::vector<thread_pool::worker>(1) };
         test_promise promise;
         bool exit_loop = false;
         size_t stealing_attempt = 0;
@@ -122,101 +122,6 @@ TEST_CASE("Thread pool 3: steal - try_get_promise interleave", "[Thread pool 3]") {
 
 
 TEST_CASE("Thread pool 3: smoke test - schedule tasks", "[Scheduler]") {
-    thread_pool_3 sched(num_threads);
-
-    static const auto coro = [&sched](auto self, int depth) -> task<int64_t> {
-        if (depth <= 0) {
-            co_return 1;
-        }
-        std::array<task<int64_t>, branching> children;
-        std::ranges::generate(children, [&] { return launch(self(self, depth - 1), sched); });
-        int64_t sum = 0;
-        for (auto& tk : children) {
-            sum += co_await tk;
-        }
-        co_return sum;
-    };
-
-    const auto count = int64_t(std::pow(branching, depth));
-    const auto result = join(bind(coro(coro, depth), sched));
-    REQUIRE(result == count);
-}
== count); -} - - -TEST_CASE("Thread pool: schedule worklist selection", "[Thread pool]") { - std::condition_variable global_notification; - std::mutex global_mutex; - atomic_stack global_worklist; - std::atomic_size_t num_waiting; - std::vector workers(1); - - test_promise promise; - - SECTION("has local worker") { - thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, num_waiting, &workers[0]); - REQUIRE(workers[0].worklist.pop() == &promise); - REQUIRE(global_worklist.empty()); - } - SECTION("no local worker") { - thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, num_waiting, &workers[0]); - REQUIRE(workers[0].worklist.pop() == &promise); - } -} - - -TEST_CASE("Thread pool: steal from workers", "[Thread pool]") { - std::vector workers(4); - - test_promise promise; - - SECTION("no work items") { - REQUIRE(nullptr == thread_pool::steal(workers)); - } - SECTION("1 work item") { - workers[2].worklist.push(&promise); - REQUIRE(&promise == thread_pool::steal(workers)); - } -} - - -TEST_CASE("Thread pool: ensure execution", "[Thread pool]") { - // This test makes sure that no matter the interleaving, a scheduled promise - // will be picked up and executed by a worker thread. - - struct scenario : testing::validated_scenario { - std::condition_variable global_notification; - std::mutex global_mutex; - atomic_stack global_worklist; - std::atomic_size_t num_waiting; - std::vector workers; - std::atomic_flag terminate; - test_promise promise; - - scenario() : workers(1) {} - - void schedule() { - thread_pool::schedule(promise, global_worklist, global_notification, global_mutex, num_waiting); - std::unique_lock lk(global_mutex, std::defer_lock); - INTERLEAVED_ACQUIRE(lk.lock()); - INTERLEAVED(terminate.test_and_set()); - INTERLEAVED(global_notification.notify_all()); - } - - void execute() { - thread_pool::execute(workers[0], global_worklist, global_notification, global_mutex, terminate, num_waiting, std::span(workers)); - } - - void validate(const testing::path& p) override { - INFO(p.dump()); - REQUIRE(promise.num_queried.load() > 0); - } - }; - - INTERLEAVED_RUN(scenario, THREAD("schedule", &scenario::schedule), THREAD("execute", &scenario::execute)); -} - - -TEST_CASE("Thread pool: smoke test - schedule tasks", "[Scheduler]") { thread_pool sched(num_threads); static const auto coro = [&sched](auto self, int depth) -> task { @@ -235,4 +140,4 @@ TEST_CASE("Thread pool: smoke test - schedule tasks", "[Scheduler]") { const auto count = int64_t(std::pow(branching, depth)); const auto result = join(bind(coro(coro, depth), sched)); REQUIRE(result == count); -} \ No newline at end of file +}