diff --git a/include/async++/container/atomic_queue.hpp b/include/async++/container/atomic_queue.hpp
index d39a039..63d2034 100644
--- a/include/async++/container/atomic_queue.hpp
+++ b/include/async++/container/atomic_queue.hpp
@@ -3,6 +3,7 @@
 #include "../sync/spinlock.hpp"
 
 #include
+#include
 
 namespace asyncpp {
@@ -14,6 +15,7 @@ class atomic_queue {
         std::lock_guard lk(m_mtx);
         const auto prev_front = m_front.load(std::memory_order_relaxed);
         element->*prev = prev_front;
+        element->*next = nullptr;
         m_front.store(element, std::memory_order_relaxed);
         if (prev_front == nullptr) {
             m_back.store(element, std::memory_order_relaxed);
@@ -24,24 +26,6 @@ class atomic_queue {
         return prev_front;
     }
 
-    bool compare_push(Element*& expected, Element* element) {
-        std::lock_guard lk(m_mtx);
-        const auto prev_front = m_front.load(std::memory_order_relaxed);
-        if (prev_front == expected) {
-            element->*prev = prev_front;
-            m_front.store(element, std::memory_order_relaxed);
-            if (prev_front == nullptr) {
-                m_back.store(element, std::memory_order_relaxed);
-            }
-            else {
-                prev_front->*next = element;
-            }
-            return true;
-        }
-        expected = prev_front;
-        return false;
-    }
-
     Element* pop() noexcept {
         std::lock_guard lk(m_mtx);
         const auto prev_back = m_back.load(std::memory_order_relaxed);
@@ -51,10 +35,21 @@ class atomic_queue {
             if (new_back == nullptr) {
                 m_front.store(nullptr, std::memory_order_relaxed);
             }
+            else {
+                new_back->*prev = nullptr;
+            }
         }
         return prev_back;
     }
 
+    Element* front() {
+        return m_front.load(std::memory_order_relaxed);
+    }
+
+    Element* back() {
+        return m_back.load(std::memory_order_relaxed);
+    }
+
     bool empty() const noexcept {
         return m_back.load(std::memory_order_relaxed) == nullptr;
     }
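Reviewer note: the atomic_queue changes make the intrusive doubly linked queue self-cleaning. push() now clears the new element's next pointer and pop() clears the new back's prev pointer, so a node that is popped and later re-pushed does not carry stale links, and the new front()/back() accessors let callers peek without popping. A minimal usage sketch follows; the waiter element type and the template-argument order (element type, prev member, next member) are assumptions based on how mutex declares its queue further down:

    #include <async++/container/atomic_queue.hpp>

    struct waiter {
        waiter* m_prev = nullptr;
        waiter* m_next = nullptr;
    };

    asyncpp::atomic_queue<waiter, &waiter::m_prev, &waiter::m_next> queue;

    waiter a, b;
    queue.push(&a);                 // first element, becomes both front and back
    queue.push(&b);                 // newest element, becomes the front
    waiter* oldest = queue.back();  // peek without popping: &a is next to pop
    waiter* popped = queue.pop();   // FIFO: returns &a, detaching it; the new
                                    // back (&b) gets its prev reset to nullptr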
diff --git a/include/async++/interleaving/runner.hpp b/include/async++/interleaving/runner.hpp
index 039a509..9455b97 100644
--- a/include/async++/interleaving/runner.hpp
+++ b/include/async++/interleaving/runner.hpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <regex>
 
 namespace asyncpp::interleaving {
@@ -31,12 +32,30 @@ struct interleaving_printer {
 
 std::ostream& operator<<(std::ostream& os, const interleaving_printer& il);
 
-generator<interleaving> run_all(std::function<std::any()> fixture, std::vector<std::function<void(std::any&)>> threads, std::vector<std::string> names = {});
+class filter {
+public:
+    filter() : filter(".*") {}
+    explicit filter(std::string_view file_regex) : m_files(file_regex.begin(), file_regex.end()) {}
+
+    bool operator()(const sequence_point& point) const;
+
+private:
+    std::regex m_files;
+};
+
+
+generator<interleaving> run_all(std::function<std::any()> fixture,
+                                std::vector<std::function<void(std::any&)>> threads,
+                                std::vector<std::string> names = {},
+                                filter filter_ = {});
 
 template <class Fixture>
     requires std::convertible_to<Fixture, std::any>
-generator<interleaving> run_all(std::function<Fixture()> fixture, std::vector<std::function<void(Fixture&)>> threads, std::vector<std::string> names = {}) {
+generator<interleaving> run_all(std::function<Fixture()> fixture,
+                                std::vector<std::function<void(Fixture&)>> threads,
+                                std::vector<std::string> names = {},
+                                filter filter_ = {}) {
     std::function<std::any()> wrapped_init = [fixture = std::move(fixture)]() -> std::any {
         if constexpr (!std::is_void_v<Fixture>) {
             return std::any(fixture());
         }
@@ -51,18 +70,20 @@ generator<interleaving> run_all(std::function<Fixture()> fixture, std::vector<s
     });
-    return run_all(std::move(wrapped_init), std::move(wrapped_threads), std::move(names));
+    return run_all(std::move(wrapped_init), std::move(wrapped_threads), std::move(names), filter_);
 }
 
 
-inline generator<interleaving> run_all(std::vector<std::function<void()>> threads, std::vector<std::string> names = {}) {
+inline generator<interleaving> run_all(std::vector<std::function<void()>> threads,
+                                       std::vector<std::string> names = {},
+                                       filter filter_ = {}) {
     std::vector<std::function<void(std::any&)>> wrapped_threads;
     std::ranges::transform(threads, std::back_inserter(wrapped_threads), [](auto& thread) {
         return std::function([thread = std::move(thread)](std::any&) { return thread(); });
     });
-    return run_all([] { return std::any(); }, std::move(wrapped_threads), std::move(names));
+    return run_all([] { return std::any(); }, std::move(wrapped_threads), std::move(names), filter_);
 }
 
 } // namespace asyncpp::interleaving
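Reviewer note: a sketch of how the new filter hook is meant to be used; the thread bodies and the file pattern are hypothetical, and the range-based iteration assumes generator<interleaving> is iterable as the existing tests suggest. Restricting sequence points to files matching a regex keeps the interleaving search space small when only one component is under test:

    using namespace asyncpp::interleaving;

    std::vector<std::function<void()>> threads = {
        [] { /* thread A: exercise the component under test */ },
        [] { /* thread B: exercise the component under test */ },
    };
    // Only sequence points recorded in files matching the regex take part
    // in the interleaving enumeration; everything else runs straight through.
    for (const auto& il : run_all(threads, { "A", "B" }, filter(".*mutex.*"))) {
        // examine each generated interleaving
    }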
false; +}; + + +template +shared_lock(mutex_shared_lock lk) -> shared_lock; + + +} // namespace asyncpp \ No newline at end of file diff --git a/include/async++/mutex.hpp b/include/async++/mutex.hpp index ac8ff70..abea84d 100644 --- a/include/async++/mutex.hpp +++ b/include/async++/mutex.hpp @@ -1,51 +1,17 @@ #pragma once #include "container/atomic_queue.hpp" +#include "lock.hpp" #include "promise.hpp" #include "sync/spinlock.hpp" -#include #include -#include -#include +#include namespace asyncpp { -template -class [[nodiscard]] lock { -public: - lock(Mutex* mtx) : m_mtx(mtx) {} - lock(lock&& rhs) : m_mtx(std::exchange(rhs.m_mtx, nullptr)) {} - lock& operator=(lock&& rhs) { - if (m_mtx) { - m_mtx->unlock(); - } - m_mtx = std::exchange(rhs.m_mtx, nullptr); - return *this; - } - ~lock() { - if (m_mtx) { - (m_mtx->*unlock)(); - } - } - Mutex& parent() const noexcept { - return *m_mtx; - } - -private: - Mutex* m_mtx = nullptr; -}; - - class mutex { - void unlock(); - -public: - using lock = lock; - friend lock; - -private: struct awaitable { awaitable* m_next = nullptr; awaitable* m_prev = nullptr; @@ -54,22 +20,21 @@ class mutex { bool await_ready() noexcept; template Promise> bool await_suspend(std::coroutine_handle enclosing) noexcept; - lock await_resume() noexcept; - void on_ready(lock lk) noexcept; + mutex_lock await_resume() noexcept; + void on_ready() noexcept; private: mutex* m_mtx; impl::resumable_promise* m_enclosing = nullptr; - std::optional m_lk; }; + bool lock_enqueue(awaitable* waiting); + public: - [[nodiscard]] std::optional try_lock() noexcept; + bool try_lock() noexcept; awaitable unique() noexcept; awaitable operator co_await() noexcept; - -private: - std::optional wait(awaitable* waiting); + void unlock(); private: atomic_queue m_queue; @@ -81,73 +46,8 @@ class mutex { template Promise> bool mutex::awaitable::await_suspend(std::coroutine_handle enclosing) noexcept { m_enclosing = &enclosing.promise(); - m_lk = m_mtx->wait(this); - return !m_lk.has_value(); + const bool ready = m_mtx->lock_enqueue(this); + return !ready; } - -template -class unique_lock { - using mutex_awaitable = std::invoke_result_t; - struct awaitable { - unique_lock* m_lock; - mutex_awaitable m_awaitable; - - auto await_ready() noexcept { - return m_awaitable.await_ready(); - } - - template - auto await_suspend(std::coroutine_handle enclosing) noexcept { - return m_awaitable.await_suspend(enclosing); - } - - void await_resume() noexcept { - m_lock->m_lk = m_awaitable.await_resume(); - } - }; - -public: - unique_lock(Mutex& mtx) noexcept : m_mtx(mtx) {} - - template - unique_lock(lock lk) noexcept : m_mtx(lk.parent()), m_lk(std::move(lk)) {} - - bool try_lock() noexcept { - assert(!owns_lock()); - m_lk = m_mtx.try_lock(); - return m_lk.has_value(); - } - - auto operator co_await() noexcept { - assert(!owns_lock()); - return awaitable(this, m_mtx.unique()); - } - - void unlock() noexcept { - assert(owns_lock()); - m_lk = std::nullopt; - } - - Mutex& mutex() const noexcept { - return m_mtx; - } - - bool owns_lock() const noexcept { - return m_lk.has_value(); - } - - operator bool() const noexcept { - return owns_lock(); - } - -private: - Mutex& m_mtx; - std::optional m_lk; -}; - - -template -unique_lock(lock lk) -> unique_lock; - } // namespace asyncpp \ No newline at end of file diff --git a/include/async++/shared_mutex.hpp b/include/async++/shared_mutex.hpp new file mode 100644 index 0000000..05c1f00 --- /dev/null +++ b/include/async++/shared_mutex.hpp @@ -0,0 +1,93 @@ +#pragma once + +#include 
"container/atomic_queue.hpp" +#include "lock.hpp" +#include "promise.hpp" +#include "sync/spinlock.hpp" + +#include +#include + + +namespace asyncpp { + +class shared_mutex { + struct basic_awaitable { + basic_awaitable* m_next = nullptr; + basic_awaitable* m_prev = nullptr; + + basic_awaitable(shared_mutex* mtx) : m_mtx(mtx) {} + virtual ~basic_awaitable() = default; + virtual void on_ready() noexcept = 0; + virtual bool is_shared() const noexcept = 0; + + protected: + shared_mutex* m_mtx; + }; + + struct awaitable : basic_awaitable { + using basic_awaitable::basic_awaitable; + + bool await_ready() noexcept; + template Promise> + bool await_suspend(std::coroutine_handle enclosing) noexcept; + mutex_lock await_resume() noexcept; + void on_ready() noexcept final; + bool is_shared() const noexcept final; + + private: + impl::resumable_promise* m_enclosing = nullptr; + }; + + struct shared_awaitable : basic_awaitable { + using basic_awaitable::basic_awaitable; + + bool await_ready() noexcept; + template Promise> + bool await_suspend(std::coroutine_handle enclosing) noexcept; + mutex_shared_lock await_resume() noexcept; + void on_ready() noexcept final; + bool is_shared() const noexcept final; + + private: + impl::resumable_promise* m_enclosing = nullptr; + }; + + bool lock_enqueue(awaitable* waiting); + bool lock_enqueue_shared(shared_awaitable* waiting); + +public: + bool try_lock() noexcept; + bool try_lock_shared() noexcept; + awaitable unique() noexcept; + shared_awaitable shared() noexcept; + void unlock(); + void unlock_shared(); + + +private: + using queue_t = atomic_queue; + queue_t m_queue; + intptr_t m_locked = 0; + intptr_t m_unique_waiting = false; + spinlock m_spinlock; +}; + + +template Promise> +bool shared_mutex::awaitable::await_suspend(std::coroutine_handle enclosing) noexcept { + m_enclosing = &enclosing.promise(); + const bool ready = m_mtx->lock_enqueue(this); + return !ready; +} + + +template Promise> +bool shared_mutex::shared_awaitable::await_suspend(std::coroutine_handle enclosing) noexcept { + m_enclosing = &enclosing.promise(); + const bool ready = m_mtx->lock_enqueue_shared(this); + return !ready; +} + + +} // namespace asyncpp \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b934d4b..348f928 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,6 +4,7 @@ target_sources(async++ PRIVATE thread_pool.cpp mutex.cpp + shared_mutex.cpp interleaving/runner.cpp interleaving/sequencer.cpp interleaving/state_tree.cpp diff --git a/src/interleaving/runner.cpp b/src/interleaving/runner.cpp index 7473a45..543672f 100644 --- a/src/interleaving/runner.cpp +++ b/src/interleaving/runner.cpp @@ -27,6 +27,7 @@ namespace impl_sp { sequence_point initial_point{ .acquire = false, .name = "", .file = __FILE__, .line = __LINE__ }; sequence_point final_point{ .acquire = false, .name = "", .file = __FILE__, .line = __LINE__ }; thread_local std::shared_ptr local_sequencer; + thread_local filter local_filter; template @@ -43,7 +44,9 @@ namespace impl_sp { void wait(sequence_point& sp) { if (local_sequencer) { - local_sequencer->wait(sp); + if (local_filter(sp)) { + local_sequencer->wait(sp); + } } } @@ -117,7 +120,8 @@ namespace impl_sp { } - void task_thread_func(std::shared_ptr seq, std::function func) { + void task_thread_func(std::shared_ptr seq, std::function func, filter filter_) { + local_filter = std::move(filter_); local_sequencer = seq; seq->wait(initial_point); func(); @@ -189,7 +193,7 @@ namespace impl_sp { } // namespace impl_sp 
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b934d4b..348f928 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -4,6 +4,7 @@ target_sources(async++
     PRIVATE
         thread_pool.cpp
        mutex.cpp
+        shared_mutex.cpp
         interleaving/runner.cpp
         interleaving/sequencer.cpp
         interleaving/state_tree.cpp
diff --git a/src/interleaving/runner.cpp b/src/interleaving/runner.cpp
index 7473a45..543672f 100644
--- a/src/interleaving/runner.cpp
+++ b/src/interleaving/runner.cpp
@@ -27,6 +27,7 @@ namespace impl_sp {
     sequence_point initial_point{ .acquire = false, .name = "", .file = __FILE__, .line = __LINE__ };
     sequence_point final_point{ .acquire = false, .name = "", .file = __FILE__, .line = __LINE__ };
     thread_local std::shared_ptr<sequencer> local_sequencer;
+    thread_local filter local_filter;
 
 
     template
@@ -43,7 +44,9 @@ namespace impl_sp {
     void wait(sequence_point& sp) {
         if (local_sequencer) {
-            local_sequencer->wait(sp);
+            if (local_filter(sp)) {
+                local_sequencer->wait(sp);
+            }
         }
     }
@@ -117,7 +120,8 @@ namespace impl_sp {
     }
 
 
-    void task_thread_func(std::shared_ptr<sequencer> seq, std::function<void()> func) {
+    void task_thread_func(std::shared_ptr<sequencer> seq, std::function<void()> func, filter filter_) {
+        local_filter = std::move(filter_);
         local_sequencer = seq;
         seq->wait(initial_point);
         func();
@@ -189,7 +193,7 @@ namespace impl_sp {
 } // namespace impl_sp
 
 
-generator<interleaving> run_all(std::function<std::any()> fixture, std::vector<std::function<void(std::any&)>> threads, std::vector<std::string> names) {
+generator<interleaving> run_all(std::function<std::any()> fixture, std::vector<std::function<void(std::any&)>> threads, std::vector<std::string> names, filter filter_) {
     using namespace impl_sp;
 
     std::vector<std::shared_ptr<sequencer>> sequencers;
@@ -211,8 +215,9 @@ generator<interleaving> run_all(std::function<std::any()> fixture, std::vector<
     std::vector<std::jthread> os_threads;
     auto fixture_val = fixture();
     for (size_t i = 0; i < threads.size(); ++i) {
-        os_threads.emplace_back([exec_thread = sequencers[i], &tsk = threads[i], &fixture_val] {
-            task_thread_func(exec_thread, [&] { tsk(fixture_val); });
+        os_threads.emplace_back([exec_thread = sequencers[i], &tsk = threads[i], &fixture_val, &filter_] {
+            task_thread_func(
+                exec_thread, [&] { tsk(fixture_val); }, filter_);
         });
     }
     interleaving_ = control_thread_func(sequencers, root);
@@ -270,4 +275,9 @@ std::ostream& operator<<(std::ostream& os, const interleaving_printer& il) {
 }
 
 
+bool filter::operator()(const sequence_point& point) const {
+    return std::regex_search(point.file.begin(), point.file.end(), m_files);
+}
+
+
 } // namespace asyncpp::interleaving
\ No newline at end of file
diff --git a/src/mutex.cpp b/src/mutex.cpp
index c90fa66..2209da7 100644
--- a/src/mutex.cpp
+++ b/src/mutex.cpp
@@ -1,37 +1,34 @@
 #include <async++/mutex.hpp>
 
+#include <mutex>
+
 
 namespace asyncpp {
 
 bool mutex::awaitable::await_ready() noexcept {
-    m_lk = m_mtx->try_lock();
-    return m_lk.has_value();
+    return m_mtx->try_lock();
 }
 
 
-mutex::lock mutex::awaitable::await_resume() noexcept {
-    assert(m_lk);
-    return std::move(m_lk.value());
+mutex_lock<mutex> mutex::awaitable::await_resume() noexcept {
+    return { m_mtx };
 }
 
 
-void mutex::awaitable::on_ready(lock lk) noexcept {
-    m_lk = std::move(lk);
+void mutex::awaitable::on_ready() noexcept {
     assert(m_enclosing);
     m_enclosing->resume();
 }
 
 
-std::optional<mutex::lock> mutex::try_lock() noexcept {
+bool mutex::try_lock() noexcept {
     std::lock_guard lk(m_spinlock);
-    if (std::exchange(m_locked, true) == false) {
-        return lock(this);
-    }
-    return std::nullopt;
+    return std::exchange(m_locked, true) == false;
 }
 
 
 mutex::awaitable mutex::unique() noexcept {
-    return awaitable(this);
+    return { this };
 }
 
 
@@ -40,14 +37,14 @@ mutex::awaitable mutex::operator co_await() noexcept {
 }
 
 
-std::optional<mutex::lock> mutex::wait(awaitable* waiting) {
+bool mutex::lock_enqueue(awaitable* waiting) {
     std::lock_guard lk(m_spinlock);
     const bool acquired = std::exchange(m_locked, true) == false;
     if (acquired) {
-        return lock(this);
+        return true;
     }
     m_queue.push(waiting);
-    return std::nullopt;
+    return false;
 }
 
 
@@ -55,11 +52,13 @@ void mutex::unlock() {
     std::unique_lock lk(m_spinlock);
     assert(m_locked);
     m_locked = false;
-    awaitable* const next = m_queue.pop();
-    lk.unlock();
+    const auto next = m_queue.pop();
     if (next) {
         m_locked = true;
-        next->on_ready(lock(this));
+    }
+    lk.unlock();
+    if (next) {
+        next->on_ready();
     }
 }
diff --git a/src/shared_mutex.cpp b/src/shared_mutex.cpp
new file mode 100644
index 0000000..45ae094
--- /dev/null
+++ b/src/shared_mutex.cpp
@@ -0,0 +1,140 @@
+#include <async++/shared_mutex.hpp>
+
+#include <mutex>
+
+
+namespace asyncpp {
+
+bool shared_mutex::awaitable::await_ready() noexcept {
+    return m_mtx->try_lock();
+}
+
+
+mutex_lock<shared_mutex> shared_mutex::awaitable::await_resume() noexcept {
+    return { m_mtx };
+}
+
+
+void shared_mutex::awaitable::on_ready() noexcept {
+    assert(m_enclosing);
+    m_enclosing->resume();
+}
+
+
+bool shared_mutex::awaitable::is_shared() const noexcept {
+    return false;
+}
+
+
+bool shared_mutex::shared_awaitable::await_ready() noexcept {
+    return m_mtx->try_lock_shared();
+}
+
+
+mutex_shared_lock<shared_mutex> shared_mutex::shared_awaitable::await_resume() noexcept {
+    return { m_mtx };
+}
+
+
+void shared_mutex::shared_awaitable::on_ready() noexcept {
+    assert(m_enclosing);
+    m_enclosing->resume();
+}
+
+
+bool shared_mutex::shared_awaitable::is_shared() const noexcept {
+    return true;
+}
+
+
+bool shared_mutex::try_lock() noexcept {
+    std::lock_guard lk(m_spinlock);
+    if (m_locked == 0) {
+        --m_locked;
+        return true;
+    }
+    return false;
+}
+
+
+bool shared_mutex::try_lock_shared() noexcept {
+    std::lock_guard lk(m_spinlock);
+    if (m_locked >= 0 && m_unique_waiting == 0) {
+        ++m_locked;
+        return true;
+    }
+    return false;
+}
+
+
+shared_mutex::awaitable shared_mutex::unique() noexcept {
+    return { this };
+}
+
+
+shared_mutex::shared_awaitable shared_mutex::shared() noexcept {
+    return { this };
+}
+
+
+bool shared_mutex::lock_enqueue(awaitable* waiting) {
+    std::lock_guard lk(m_spinlock);
+    if (m_locked == 0) {
+        --m_locked;
+        return true;
+    }
+    m_queue.push(waiting);
+    ++m_unique_waiting;
+    return false;
+}
+
+
+bool shared_mutex::lock_enqueue_shared(shared_awaitable* waiting) {
+    std::lock_guard lk(m_spinlock);
+    if (m_locked >= 0 && m_unique_waiting == 0) {
+        ++m_locked;
+        return true;
+    }
+    m_queue.push(waiting);
+    return false;
+}
+
+
+void shared_mutex::unlock() {
+    std::unique_lock lk(m_spinlock);
+    assert(m_locked == -1);
+    ++m_locked;
+    queue_t next_list;
+    basic_awaitable* next;
+    do {
+        next = m_queue.pop();
+        if (next) {
+            m_locked += next->is_shared() ? +1 : -1;
+            m_unique_waiting -= intptr_t(!next->is_shared());
+            next_list.push(next);
+        }
+    } while (next && next->is_shared() && !m_queue.empty() && m_queue.back()->is_shared() && m_locked >= 0);
+    lk.unlock();
+    while ((next = next_list.pop()) != nullptr) {
+        next->on_ready();
+    }
+}
+
+
+void shared_mutex::unlock_shared() {
+    std::unique_lock lk(m_spinlock);
+    assert(m_locked > 0);
+    --m_locked;
+    if (m_locked == 0) {
+        const auto next = m_queue.pop();
+        if (next) {
+            assert(!next->is_shared()); // Shared ones would have been continued immediately.
+            --m_locked;
+            --m_unique_waiting;
+            lk.unlock();
+            next->on_ready();
+        }
+    }
+}
+
+} // namespace asyncpp
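Reviewer note: for orientation, m_locked encodes the ownership state: 0 means free, -1 means one writer holds the lock, n > 0 means n readers hold it; m_unique_waiting counts queued writers. A walk-through of unlock()'s batched wake-up under a hypothetical queue (the back of the queue is the next element to wake):

    // queue, next-to-wake first: S1, S2, U1, S3; a writer holds the lock (m_locked == -1)
    // unlock():
    //   ++m_locked                -> m_locked == 0
    //   pop S1 (shared)           -> m_locked == 1; back() is S2, shared -> continue
    //   pop S2 (shared)           -> m_locked == 2; back() is U1, unique -> stop
    //   resume S1 and S2 only after releasing the spinlock
    // U1 is resumed later by unlock_shared(): when the last reader drops
    // m_locked to 0, it pops U1, sets m_locked to -1, decrements
    // m_unique_waiting, and resumes it; S3 stays queued behind U1.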
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 51d6ec2..a959e00 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -8,6 +8,7 @@ target_sources(test
         test_thread_pool.cpp
         test_task.cpp
         test_mutex.cpp
+        test_shared_mutex.cpp
         interleaving/test_runner.cpp
         test_shared_task.cpp
         container/test_atomic_queue.cpp
diff --git a/test/test_mutex.cpp b/test/test_mutex.cpp
index ea87dc4..999c47e 100644
--- a/test/test_mutex.cpp
+++ b/test/test_mutex.cpp
@@ -1,3 +1,4 @@
+#include <async++/lock.hpp>
 #include <async++/mutex.hpp>
 #include <async++/task.hpp>
 
@@ -8,7 +9,7 @@ using namespace asyncpp;
 
 TEST_CASE("Mutex: try lock", "[Mutex]") {
     static const auto coro = [](mutex& mtx) -> task<void> {
-        const auto lock = mtx.try_lock();
+        REQUIRE(mtx.try_lock());
         REQUIRE(!mtx.try_lock());
         co_return;
     };
@@ -20,7 +21,7 @@ TEST_CASE("Mutex: try lock", "[Mutex]") {
 
 TEST_CASE("Mutex: lock", "[Mutex]") {
     static const auto coro = [](mutex& mtx) -> task<void> {
-        const auto lock = co_await mtx;
+        co_await mtx;
         REQUIRE(!mtx.try_lock());
     };
 
@@ -29,6 +30,18 @@ TEST_CASE("Mutex: lock", "[Mutex]") {
 }
 
 
+TEST_CASE("Mutex: unlock", "[Mutex]") {
+    static const auto coro = [](mutex& mtx) -> task<void> {
+        co_await mtx;
+        mtx.unlock();
+        REQUIRE(mtx.try_lock());
+    };
+
+    mutex mtx;
+    coro(mtx).get();
+}
+
+
 TEST_CASE("Mutex: unique lock try", "[Mutex]") {
     static const auto coro = [](mutex& mtx) -> task<void> {
         unique_lock lk(mtx);
@@ -66,4 +79,41 @@ TEST_CASE("Mutex: unique lock start locked", "[Mutex]") {
 
     mutex mtx;
     coro(mtx).get();
+}
+
+
+TEST_CASE("Mutex: unique lock unlock", "[Mutex]") {
+    static const auto coro = [](mutex& mtx) -> task<void> {
+        unique_lock lk(co_await mtx);
+        lk.unlock();
+        REQUIRE(!lk.owns_lock());
+        co_return;
+    };
+
+    mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Mutex: resume awaiting", "[Mutex]") {
+    static const auto awaiter = [](mutex& mtx, std::vector<int>& sequence, int id) -> task<void> {
+        co_await mtx;
+        sequence.push_back(id);
+        mtx.unlock();
+    };
+    static const auto main = [](mutex& mtx, std::vector<int>& sequence) -> task<void> {
+        auto t1 = awaiter(mtx, sequence, 1);
+        auto t2 = awaiter(mtx, sequence, 2);
+
+        co_await mtx;
+        sequence.push_back(0);
+        t1.launch();
+        t2.launch();
+        mtx.unlock();
+    };
+
+    mutex mtx;
+    std::vector<int> sequence;
+    main(mtx, sequence).get();
+    REQUIRE(sequence == std::vector{ 0, 1, 2 });
 }
\ No newline at end of file
diff --git a/test/test_shared_mutex.cpp b/test/test_shared_mutex.cpp
new file mode 100644
index 0000000..252a995
--- /dev/null
+++ b/test/test_shared_mutex.cpp
@@ -0,0 +1,262 @@
+#include <async++/lock.hpp>
+#include <async++/shared_mutex.hpp>
+#include <async++/task.hpp>
+
+#include <catch2/catch_test_macros.hpp>
+
+using namespace asyncpp;
+
+
+TEST_CASE("Shared mutex: try lock", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        REQUIRE(mtx.try_lock());
+        REQUIRE(!mtx.try_lock());
+        REQUIRE(!mtx.try_lock_shared());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: lock", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        co_await mtx.unique();
+        REQUIRE(!mtx.try_lock());
+        REQUIRE(!mtx.try_lock_shared());
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: unlock", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        co_await mtx.unique();
+        mtx.unlock();
+        REQUIRE(mtx.try_lock());
+        mtx.unlock();
+        REQUIRE(mtx.try_lock_shared());
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: try lock shared", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        REQUIRE(mtx.try_lock_shared());
+        REQUIRE(!mtx.try_lock());
+        REQUIRE(mtx.try_lock_shared());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: lock shared", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        co_await mtx.shared();
+        REQUIRE(!mtx.try_lock());
+        REQUIRE(mtx.try_lock_shared());
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: unlock shared", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        co_await mtx.shared();
+        mtx.unlock_shared();
+        REQUIRE(mtx.try_lock());
+        mtx.unlock();
+        REQUIRE(mtx.try_lock_shared());
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: unique lock try", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        unique_lock lk(mtx);
+        REQUIRE(!lk.owns_lock());
+        REQUIRE(lk.try_lock());
+        REQUIRE(lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: unique lock await", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        unique_lock lk(mtx);
+        REQUIRE(!lk.owns_lock());
+        co_await lk;
+        REQUIRE(lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: unique lock start locked", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        unique_lock lk(co_await mtx.unique());
+        REQUIRE(lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: unique lock unlock", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        unique_lock lk(co_await mtx.unique());
+        lk.unlock();
+        REQUIRE(!lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: shared lock try", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        shared_lock lk(mtx);
+        REQUIRE(!lk.owns_lock());
+        REQUIRE(lk.try_lock());
+        REQUIRE(lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: shared lock await", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        shared_lock lk(mtx);
+        REQUIRE(!lk.owns_lock());
+        co_await lk;
+        REQUIRE(lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: shared lock start locked", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        shared_lock lk(co_await mtx.shared());
+        REQUIRE(lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: shared lock unlock", "[Shared mutex]") {
+    static const auto coro = [](shared_mutex& mtx) -> task<void> {
+        shared_lock lk(co_await mtx.shared());
+        REQUIRE(lk.owns_lock());
+        lk.unlock();
+        REQUIRE(!lk.owns_lock());
+        co_return;
+    };
+
+    shared_mutex mtx;
+    coro(mtx).get();
+}
+
+
+TEST_CASE("Shared mutex: resume awaiting", "[Shared mutex]") {
+    static const auto awaiter = [](shared_mutex& mtx, std::vector<int>& sequence, int id) -> task<void> {
+        co_await mtx.unique();
+        sequence.push_back(id);
+        mtx.unlock();
+    };
+    static const auto shared_awaiter = [](shared_mutex& mtx, std::vector<int>& sequence, int id) -> task<void> {
+        co_await mtx.shared();
+        sequence.push_back(id);
+        mtx.unlock_shared();
+    };
+    static const auto main = [](shared_mutex& mtx, std::vector<int>& sequence) -> task<void> {
+        auto s1 = shared_awaiter(mtx, sequence, 1);
+        auto s2 = shared_awaiter(mtx, sequence, 2);
+        auto u1 = awaiter(mtx, sequence, -1);
+        auto u2 = awaiter(mtx, sequence, -2);
+
+        co_await mtx.unique();
+        sequence.push_back(0);
+        s1.launch();
+        s2.launch();
+        u1.launch();
+        u2.launch();
+        mtx.unlock();
+    };
+
+    shared_mutex mtx;
+    std::vector<int> sequence;
+    main(mtx, sequence).get();
+    REQUIRE(sequence == std::vector{ 0, 1, 2, -1, -2 });
+}
+
+
+TEST_CASE("Shared mutex: unique starvation", "[Shared mutex]") {
+    static const auto awaiter = [](shared_mutex& mtx, std::vector<int>& sequence, int id) -> task<void> {
+        co_await mtx.unique();
+        sequence.push_back(id);
+        mtx.unlock();
+    };
+    static const auto shared_awaiter = [](shared_mutex& mtx, std::vector<int>& sequence, int id) -> task<void> {
+        co_await mtx.shared();
+        sequence.push_back(id);
+        mtx.unlock_shared();
+    };
+    static const auto main = [](shared_mutex& mtx, std::vector<int>& sequence) -> task<void> {
+        auto s1 = shared_awaiter(mtx, sequence, 1);
+        auto s2 = shared_awaiter(mtx, sequence, 2);
+        auto s3 = shared_awaiter(mtx, sequence, 3);
+        auto s4 = shared_awaiter(mtx, sequence, 4);
+        auto u1 = awaiter(mtx, sequence, -1);
+
+        co_await mtx.shared();
+        sequence.push_back(0);
+        s1.launch();
+        s2.launch();
+        u1.launch();
+        s3.launch();
+        mtx.unlock_shared();
+        co_await mtx.shared();
+        sequence.push_back(0);
+        s4.launch();
+    };
+
+    shared_mutex mtx;
+    std::vector<int> sequence;
+    main(mtx, sequence).get();
+    REQUIRE(sequence == std::vector{ 0, 1, 2, -1, 3, 0, 4 });
+}
\ No newline at end of file
diff --git a/test/test_shared_task.cpp b/test/test_shared_task.cpp
index 13b82a1..de14bda 100644
--- a/test/test_shared_task.cpp
+++ b/test/test_shared_task.cpp
@@ -47,7 +47,7 @@ TEST_CASE("Shared task: interleave sync", "[Shared task]") {
         INFO((interleaving::interleaving_printer{ il, true }));
         REQUIRE(tester);
     }
-    REQUIRE(count > 0);
+    REQUIRE(count >= 3);
 }
 
 
@@ -97,7 +97,7 @@ TEST_CASE("Shared task: interleaving co_await", "[Shared task]") {
         INFO((interleaving::interleaving_printer{ il, true }));
         REQUIRE(tester);
     }
-    REQUIRE(count > 0);
+    REQUIRE(count >= 3);
 }
 
 
@@ -135,7 +135,7 @@ TEST_CASE("Shared task: interleaving abandon", "[Shared task]") {
         INFO((interleaving::interleaving_printer{ il, true }));
         REQUIRE(tester);
     }
-    REQUIRE(count > 0);
+    REQUIRE(count >= 3);
 }
 
 
@@ -195,4 +195,4 @@ TEST_CASE("Shared task: co_await void", "[Shared task]") {
     };
     auto task = enclosing();
     task.get();
-}
+}
\ No newline at end of file
diff --git a/test/test_task.cpp b/test/test_task.cpp
index 460706f..af5be2b 100644
--- a/test/test_task.cpp
+++ b/test/test_task.cpp
@@ -44,7 +44,7 @@ TEST_CASE("Task: interleaving sync", "[Task]") {
         INFO((interleaving::interleaving_printer{ il, true }));
         REQUIRE(tester);
     }
-    REQUIRE(count > 0);
+    REQUIRE(count >= 3);
 }
 
 
@@ -94,7 +94,7 @@ TEST_CASE("Task: interleaving co_await", "[Task]") {
         INFO((interleaving::interleaving_printer{ il, true }));
         REQUIRE(tester);
     }
-    REQUIRE(count > 0);
+    REQUIRE(count >= 3);
 }
 
 
@@ -132,7 +132,7 @@ TEST_CASE("Task: interleaving abandon", "[Task]") {
         INFO((interleaving::interleaving_printer{ il, true }));
         REQUIRE(tester);
    }
-    REQUIRE(count > 0);
+    REQUIRE(count >= 3);
 }