From 4a0806663a121165d803f2095239e248343a6098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Fri, 17 Nov 2023 21:18:02 +0100 Subject: [PATCH] mutex --- include/async++/container/atomic_queue.hpp | 68 ++++++++ include/async++/mutex.hpp | 153 ++++++++++++++++++ include/async++/shared_task.hpp | 10 +- include/async++/task.hpp | 14 +- src/CMakeLists.txt | 1 + src/mutex.cpp | 66 ++++++++ test/CMakeLists.txt | 6 +- .../test_atomic_collection.cpp | 0 test/container/test_atomic_queue.cpp | 37 +++++ test/{ => container}/test_atomic_stack.cpp | 18 +-- test/test_mutex.cpp | 69 ++++++++ test/test_shared_task.cpp | 13 ++ test/test_task.cpp | 13 ++ 13 files changed, 450 insertions(+), 18 deletions(-) create mode 100644 include/async++/container/atomic_queue.hpp create mode 100644 include/async++/mutex.hpp create mode 100644 src/mutex.cpp rename test/{ => container}/test_atomic_collection.cpp (100%) create mode 100644 test/container/test_atomic_queue.cpp rename test/{ => container}/test_atomic_stack.cpp (61%) create mode 100644 test/test_mutex.cpp diff --git a/include/async++/container/atomic_queue.hpp b/include/async++/container/atomic_queue.hpp new file mode 100644 index 0000000..d39a039 --- /dev/null +++ b/include/async++/container/atomic_queue.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include "../sync/spinlock.hpp" + +#include + + +namespace asyncpp { + +template +class atomic_queue { +public: + Element* push(Element* element) noexcept { + std::lock_guard lk(m_mtx); + const auto prev_front = m_front.load(std::memory_order_relaxed); + element->*prev = prev_front; + m_front.store(element, std::memory_order_relaxed); + if (prev_front == nullptr) { + m_back.store(element, std::memory_order_relaxed); + } + else { + prev_front->*next = element; + } + return prev_front; + } + + bool compare_push(Element*& expected, Element* element) { + std::lock_guard lk(m_mtx); + const auto prev_front = m_front.load(std::memory_order_relaxed); + if (prev_front == expected) { + element->*prev = prev_front; + m_front.store(element, std::memory_order_relaxed); + if (prev_front == nullptr) { + m_back.store(element, std::memory_order_relaxed); + } + else { + prev_front->*next = element; + } + return true; + } + expected = prev_front; + return false; + } + + Element* pop() noexcept { + std::lock_guard lk(m_mtx); + const auto prev_back = m_back.load(std::memory_order_relaxed); + if (prev_back != nullptr) { + const auto new_back = prev_back->*next; + m_back.store(new_back, std::memory_order_relaxed); + if (new_back == nullptr) { + m_front.store(nullptr, std::memory_order_relaxed); + } + } + return prev_back; + } + + bool empty() const noexcept { + return m_back.load(std::memory_order_relaxed) == nullptr; + } + +private: + std::atomic m_front; + std::atomic m_back; + mutable spinlock m_mtx; +}; + +} // namespace asyncpp \ No newline at end of file diff --git a/include/async++/mutex.hpp b/include/async++/mutex.hpp new file mode 100644 index 0000000..ac8ff70 --- /dev/null +++ b/include/async++/mutex.hpp @@ -0,0 +1,153 @@ +#pragma once + +#include "container/atomic_queue.hpp" +#include "promise.hpp" +#include "sync/spinlock.hpp" + +#include +#include +#include +#include + + +namespace asyncpp { + +template +class [[nodiscard]] lock { +public: + lock(Mutex* mtx) : m_mtx(mtx) {} + lock(lock&& rhs) : m_mtx(std::exchange(rhs.m_mtx, nullptr)) {} + lock& operator=(lock&& rhs) { + if (m_mtx) { + m_mtx->unlock(); + } + m_mtx = std::exchange(rhs.m_mtx, nullptr); + return *this; + } + ~lock() { + if (m_mtx) { + (m_mtx->*unlock)(); + } + } + Mutex& parent() const noexcept { + return *m_mtx; + } + +private: + Mutex* m_mtx = nullptr; +}; + + +class mutex { + void unlock(); + +public: + using lock = lock; + friend lock; + +private: + struct awaitable { + awaitable* m_next = nullptr; + awaitable* m_prev = nullptr; + + awaitable(mutex* mtx) : m_mtx(mtx) {} + bool await_ready() noexcept; + template Promise> + bool await_suspend(std::coroutine_handle enclosing) noexcept; + lock await_resume() noexcept; + void on_ready(lock lk) noexcept; + + private: + mutex* m_mtx; + impl::resumable_promise* m_enclosing = nullptr; + std::optional m_lk; + }; + +public: + [[nodiscard]] std::optional try_lock() noexcept; + awaitable unique() noexcept; + awaitable operator co_await() noexcept; + +private: + std::optional wait(awaitable* waiting); + +private: + atomic_queue m_queue; + bool m_locked = false; + spinlock m_spinlock; +}; + + +template Promise> +bool mutex::awaitable::await_suspend(std::coroutine_handle enclosing) noexcept { + m_enclosing = &enclosing.promise(); + m_lk = m_mtx->wait(this); + return !m_lk.has_value(); +} + + +template +class unique_lock { + using mutex_awaitable = std::invoke_result_t; + struct awaitable { + unique_lock* m_lock; + mutex_awaitable m_awaitable; + + auto await_ready() noexcept { + return m_awaitable.await_ready(); + } + + template + auto await_suspend(std::coroutine_handle enclosing) noexcept { + return m_awaitable.await_suspend(enclosing); + } + + void await_resume() noexcept { + m_lock->m_lk = m_awaitable.await_resume(); + } + }; + +public: + unique_lock(Mutex& mtx) noexcept : m_mtx(mtx) {} + + template + unique_lock(lock lk) noexcept : m_mtx(lk.parent()), m_lk(std::move(lk)) {} + + bool try_lock() noexcept { + assert(!owns_lock()); + m_lk = m_mtx.try_lock(); + return m_lk.has_value(); + } + + auto operator co_await() noexcept { + assert(!owns_lock()); + return awaitable(this, m_mtx.unique()); + } + + void unlock() noexcept { + assert(owns_lock()); + m_lk = std::nullopt; + } + + Mutex& mutex() const noexcept { + return m_mtx; + } + + bool owns_lock() const noexcept { + return m_lk.has_value(); + } + + operator bool() const noexcept { + return owns_lock(); + } + +private: + Mutex& m_mtx; + std::optional m_lk; +}; + + +template +unique_lock(lock lk) -> unique_lock; + +} // namespace asyncpp \ No newline at end of file diff --git a/include/async++/shared_task.hpp b/include/async++/shared_task.hpp index d7f51ba..88c334f 100644 --- a/include/async++/shared_task.hpp +++ b/include/async++/shared_task.hpp @@ -161,21 +161,21 @@ namespace impl_shared_task { template struct sync_awaitable : chained_awaitable { promise* m_awaited = nullptr; - std::promise&> m_promise; - std::future&> m_future = m_promise.get_future(); + std::promise*> m_promise; + std::future*> m_future = m_promise.get_future(); sync_awaitable(promise* awaited) noexcept : m_awaited(awaited) { m_awaited->acquire(); const bool ready = m_awaited->await(this); if (ready) { - m_promise.set_value(m_awaited->get_result()); + m_promise.set_value(&m_awaited->get_result()); } } ~sync_awaitable() override { m_awaited->release(); } void on_ready() noexcept final { - return m_promise.set_value(m_awaited->get_result()); + return m_promise.set_value(&m_awaited->get_result()); } }; @@ -237,7 +237,7 @@ class shared_task { auto get() const -> typename impl::task_result::reference { assert(valid()); impl_shared_task::sync_awaitable awaitable(m_promise); - return INTERLEAVED_ACQUIRE(awaitable.m_future.get()).get_or_throw(); + return INTERLEAVED_ACQUIRE(awaitable.m_future.get())->get_or_throw(); } auto operator co_await() const { diff --git a/include/async++/task.hpp b/include/async++/task.hpp index c8feebc..7779c7c 100644 --- a/include/async++/task.hpp +++ b/include/async++/task.hpp @@ -137,7 +137,12 @@ namespace impl_task { m_result = m_awaited->get_result(); m_awaited->release(); } - return std::forward(m_result.get_or_throw()); + if constexpr (!std::is_void_v) { + return std::forward(m_result.get_or_throw()); + } + else { + m_result.get_or_throw(); + } } void on_ready() noexcept final { @@ -203,7 +208,12 @@ class [[nodiscard]] task { T get() { assert(valid()); impl_task::sync_awaitable awaitable(std::exchange(m_promise, nullptr)); - return std::forward(INTERLEAVED_ACQUIRE(awaitable.m_future.get()).get_or_throw()); + if constexpr (!std::is_void_v) { + return std::forward(INTERLEAVED_ACQUIRE(awaitable.m_future.get()).get_or_throw()); + } + else { + INTERLEAVED_ACQUIRE(awaitable.m_future.get()).get_or_throw(); + } } auto operator co_await() { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8e8f737..b934d4b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(async++) target_sources(async++ PRIVATE thread_pool.cpp + mutex.cpp interleaving/runner.cpp interleaving/sequencer.cpp interleaving/state_tree.cpp diff --git a/src/mutex.cpp b/src/mutex.cpp new file mode 100644 index 0000000..c90fa66 --- /dev/null +++ b/src/mutex.cpp @@ -0,0 +1,66 @@ +#include + +namespace asyncpp { + +bool mutex::awaitable::await_ready() noexcept { + m_lk = m_mtx->try_lock(); + return m_lk.has_value(); +} + + +mutex::lock mutex::awaitable::await_resume() noexcept { + assert(m_lk); + return std::move(m_lk.value()); +} + + +void mutex::awaitable::on_ready(lock lk) noexcept { + m_lk = std::move(lk); + assert(m_enclosing); + m_enclosing->resume(); +} + + +std::optional mutex::try_lock() noexcept { + std::lock_guard lk(m_spinlock); + if (std::exchange(m_locked, true) == false) { + return lock(this); + } + return std::nullopt; +} + + +mutex::awaitable mutex::unique() noexcept { + return awaitable(this); +} + + +mutex::awaitable mutex::operator co_await() noexcept { + return unique(); +} + + +std::optional mutex::wait(awaitable* waiting) { + std::lock_guard lk(m_spinlock); + const bool acquired = std::exchange(m_locked, true) == false; + if (acquired) { + return lock(this); + } + m_queue.push(waiting); + return std::nullopt; +} + + +void mutex::unlock() { + std::unique_lock lk(m_spinlock); + assert(m_locked); + m_locked = false; + awaitable* const next = m_queue.pop(); + lk.unlock(); + if (next) { + m_locked = true; + next->on_ready(lock(this)); + } +} + +} // namespace asyncpp \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ccefc38..51d6ec2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,10 +7,12 @@ target_sources(test test_stream.cpp test_thread_pool.cpp test_task.cpp - test_atomic_stack.cpp + test_mutex.cpp interleaving/test_runner.cpp test_shared_task.cpp - test_atomic_collection.cpp + container/test_atomic_queue.cpp + container/test_atomic_stack.cpp + container/test_atomic_collection.cpp ) diff --git a/test/test_atomic_collection.cpp b/test/container/test_atomic_collection.cpp similarity index 100% rename from test/test_atomic_collection.cpp rename to test/container/test_atomic_collection.cpp diff --git a/test/container/test_atomic_queue.cpp b/test/container/test_atomic_queue.cpp new file mode 100644 index 0000000..fccaea1 --- /dev/null +++ b/test/container/test_atomic_queue.cpp @@ -0,0 +1,37 @@ +#include + +#include + + +using namespace asyncpp; + + +struct queue_element { + int id = 0; + queue_element* next = nullptr; + queue_element* prev = nullptr; +}; + +using queue_t = atomic_queue; + + +TEST_CASE("Atomic queue: all", "[Atomic queue]") { + queue_element e0{ .id = 0 }; + queue_element e1{ .id = 1 }; + queue_element e2{ .id = 2 }; + queue_element e3{ .id = 3 }; + queue_t queue; + REQUIRE(queue.empty()); + REQUIRE(queue.pop() == nullptr); + REQUIRE(queue.push(&e0) == nullptr); + REQUIRE(queue.push(&e1) == &e0); + REQUIRE(queue.push(&e2) == &e1); + REQUIRE(queue.push(&e3) == &e2); + REQUIRE(!queue.empty()); + REQUIRE(queue.pop() == &e0); + REQUIRE(queue.pop() == &e1); + REQUIRE(queue.pop() == &e2); + REQUIRE(queue.pop() == &e3); + REQUIRE(queue.empty()); + REQUIRE(queue.pop() == nullptr); +} \ No newline at end of file diff --git a/test/test_atomic_stack.cpp b/test/container/test_atomic_stack.cpp similarity index 61% rename from test/test_atomic_stack.cpp rename to test/container/test_atomic_stack.cpp index 5230aee..be60b0f 100644 --- a/test/test_atomic_stack.cpp +++ b/test/container/test_atomic_stack.cpp @@ -6,20 +6,20 @@ using namespace asyncpp; -struct collection_element { +struct queue_element { int id = 0; - collection_element* next = nullptr; + queue_element* next = nullptr; }; -using stack_t = atomic_stack; +using queue_t = atomic_stack; -TEST_CASE("Atomic list: all", "[Atomic list]") { - collection_element e0{ .id = 0 }; - collection_element e1{ .id = 1 }; - collection_element e2{ .id = 2 }; - collection_element e3{ .id = 3 }; - stack_t stack; +TEST_CASE("Atomic stack: all", "[Atomic stack]") { + queue_element e0{ .id = 0 }; + queue_element e1{ .id = 1 }; + queue_element e2{ .id = 2 }; + queue_element e3{ .id = 3 }; + queue_t stack; REQUIRE(stack.empty()); REQUIRE(stack.pop() == nullptr); REQUIRE(stack.push(&e0) == nullptr); diff --git a/test/test_mutex.cpp b/test/test_mutex.cpp new file mode 100644 index 0000000..ea87dc4 --- /dev/null +++ b/test/test_mutex.cpp @@ -0,0 +1,69 @@ +#include +#include + +#include + +using namespace asyncpp; + + +TEST_CASE("Mutex: try lock", "[Mutex]") { + static const auto coro = [](mutex& mtx) -> task { + const auto lock = mtx.try_lock(); + REQUIRE(!mtx.try_lock()); + co_return; + }; + + mutex mtx; + coro(mtx).get(); +} + + +TEST_CASE("Mutex: lock", "[Mutex]") { + static const auto coro = [](mutex& mtx) -> task { + const auto lock = co_await mtx; + REQUIRE(!mtx.try_lock()); + }; + + mutex mtx; + coro(mtx).get(); +} + + +TEST_CASE("Mutex: unique lock try", "[Mutex]") { + static const auto coro = [](mutex& mtx) -> task { + unique_lock lk(mtx); + REQUIRE(!lk.owns_lock()); + REQUIRE(lk.try_lock()); + REQUIRE(lk.owns_lock()); + co_return; + }; + + mutex mtx; + coro(mtx).get(); +} + + +TEST_CASE("Mutex: unique lock await", "[Mutex]") { + static const auto coro = [](mutex& mtx) -> task { + unique_lock lk(mtx); + REQUIRE(!lk.owns_lock()); + co_await lk; + REQUIRE(lk.owns_lock()); + co_return; + }; + + mutex mtx; + coro(mtx).get(); +} + + +TEST_CASE("Mutex: unique lock start locked", "[Mutex]") { + static const auto coro = [](mutex& mtx) -> task { + unique_lock lk(co_await mtx); + REQUIRE(lk.owns_lock()); + co_return; + }; + + mutex mtx; + coro(mtx).get(); +} \ No newline at end of file diff --git a/test/test_shared_task.cpp b/test/test_shared_task.cpp index 6854ea5..13b82a1 100644 --- a/test/test_shared_task.cpp +++ b/test/test_shared_task.cpp @@ -183,3 +183,16 @@ TEST_CASE("Shared task: co_await ref", "[Shared task]") { REQUIRE(result == 42); REQUIRE(&result == &value); } + + +TEST_CASE("Shared task: co_await void", "[Shared task]") { + static int value = 42; + static const auto coro = []() -> shared_task { + co_return; + }; + static const auto enclosing = []() -> shared_task { + co_await coro(); + }; + auto task = enclosing(); + task.get(); +} diff --git a/test/test_task.cpp b/test/test_task.cpp index 977a845..460706f 100644 --- a/test/test_task.cpp +++ b/test/test_task.cpp @@ -175,4 +175,17 @@ TEST_CASE("Task: co_await ref", "[Task]") { auto& result = task.get(); REQUIRE(result == 42); REQUIRE(&result == &value); +} + + +TEST_CASE("Task: co_await/get void", "[Task]") { + static int value = 42; + static const auto coro = [](int& value) -> task { + co_return; + }; + static const auto enclosing = [](int& value) -> task { + co_await coro(value); + }; + auto task = enclosing(value); + task.get(); } \ No newline at end of file