Skip to content

Commit

Permalink
shared mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
petiaccja committed Nov 19, 2023
1 parent 4a08066 commit f6b25e3
Show file tree
Hide file tree
Showing 14 changed files with 817 additions and 166 deletions.
31 changes: 13 additions & 18 deletions include/async++/container/atomic_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "../sync/spinlock.hpp"

#include <atomic>
#include <mutex>


namespace asyncpp {
Expand All @@ -14,6 +15,7 @@ class atomic_queue {
std::lock_guard lk(m_mtx);
const auto prev_front = m_front.load(std::memory_order_relaxed);
element->*prev = prev_front;
element->*next = nullptr;
m_front.store(element, std::memory_order_relaxed);
if (prev_front == nullptr) {
m_back.store(element, std::memory_order_relaxed);
Expand All @@ -24,24 +26,6 @@ class atomic_queue {
return prev_front;
}

bool compare_push(Element*& expected, Element* element) {
std::lock_guard lk(m_mtx);
const auto prev_front = m_front.load(std::memory_order_relaxed);
if (prev_front == expected) {
element->*prev = prev_front;
m_front.store(element, std::memory_order_relaxed);
if (prev_front == nullptr) {
m_back.store(element, std::memory_order_relaxed);
}
else {
prev_front->*next = element;
}
return true;
}
expected = prev_front;
return false;
}

Element* pop() noexcept {
std::lock_guard lk(m_mtx);
const auto prev_back = m_back.load(std::memory_order_relaxed);
Expand All @@ -51,10 +35,21 @@ class atomic_queue {
if (new_back == nullptr) {
m_front.store(nullptr, std::memory_order_relaxed);
}
else {
new_back->*prev = nullptr;
}
}
return prev_back;
}

Element* front() {
return m_front.load(std::memory_order_relaxed);
}

Element* back() {
return m_back.load(std::memory_order_relaxed);
}

bool empty() const noexcept {
return m_back.load(std::memory_order_relaxed) == nullptr;
}
Expand Down
31 changes: 26 additions & 5 deletions include/async++/interleaving/runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <concepts>
#include <functional>
#include <iostream>
#include <regex>


namespace asyncpp::interleaving {
Expand All @@ -31,12 +32,30 @@ struct interleaving_printer {
std::ostream& operator<<(std::ostream& os, const interleaving_printer& il);


generator<interleaving> run_all(std::function<std::any()> fixture, std::vector<std::function<void(std::any&)>> threads, std::vector<std::string_view> names = {});
class filter {
public:
filter() : filter(".*") {}
explicit filter(std::string_view file_regex) : m_files(file_regex.begin(), file_regex.end()) {}

bool operator()(const sequence_point& point) const;

private:
std::regex m_files;
};


generator<interleaving> run_all(std::function<std::any()> fixture,
std::vector<std::function<void(std::any&)>> threads,
std::vector<std::string_view> names = {},
filter filter_ = {});


template <class Fixture, class Input>
requires std::convertible_to<Fixture&, Input>
generator<interleaving> run_all(std::function<Fixture()> fixture, std::vector<std::function<void(Input)>> threads, std::vector<std::string_view> names = {}) {
generator<interleaving> run_all(std::function<Fixture()> fixture,
std::vector<std::function<void(Input)>> threads,
std::vector<std::string_view> names = {},
filter filter_ = {}) {
std::function<std::any()> wrapped_init = [fixture = std::move(fixture)]() -> std::any {
if constexpr (!std::is_void_v<Fixture>) {
return std::any(fixture());
Expand All @@ -51,18 +70,20 @@ generator<interleaving> run_all(std::function<Fixture()> fixture, std::vector<st
});
});

return run_all(std::move(wrapped_init), std::move(wrapped_threads), std::move(names));
return run_all(std::move(wrapped_init), std::move(wrapped_threads), std::move(names), filter_);
}


inline generator<interleaving> run_all(std::vector<std::function<void()>> threads, std::vector<std::string_view> names = {}) {
inline generator<interleaving> run_all(std::vector<std::function<void()>> threads,
std::vector<std::string_view> names = {},
filter filter_ = {}) {
std::vector<std::function<void(std::any&)>> wrapped_threads;
std::ranges::transform(threads, std::back_inserter(wrapped_threads), [](auto& thread) {
return std::function<void(std::any&)>([thread = std::move(thread)](std::any&) {
return thread();
});
});
return run_all([] { return std::any(); }, std::move(wrapped_threads), std::move(names));
return run_all([] { return std::any(); }, std::move(wrapped_threads), std::move(names), filter_);
}

} // namespace asyncpp::interleaving
179 changes: 179 additions & 0 deletions include/async++/lock.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#pragma once

#include <cassert>
#include <coroutine>
#include <optional>
#include <utility>


namespace asyncpp {

template <class Mutex>
class mutex_lock {
friend Mutex;

public:
mutex_lock(mutex_lock&&) = default;
mutex_lock& operator=(mutex_lock&&) = default;
mutex_lock(const mutex_lock&) = delete;
mutex_lock& operator=(const mutex_lock&) = delete;
Mutex& mutex() const noexcept {
return *m_mtx;
}

private:
mutex_lock(Mutex* mtx) : m_mtx(mtx) {}
Mutex* m_mtx = nullptr;
};


template <class Mutex>
class mutex_shared_lock {
friend Mutex;

public:
mutex_shared_lock(mutex_shared_lock&&) = default;
mutex_shared_lock& operator=(mutex_shared_lock&&) = default;
mutex_shared_lock(const mutex_shared_lock&) = delete;
mutex_shared_lock& operator=(const mutex_shared_lock&) = delete;
Mutex& mutex() const noexcept {
return *m_mtx;
}

private:
mutex_shared_lock(Mutex* mtx) : m_mtx(mtx) {}
Mutex* m_mtx = nullptr;
};


template <class Mutex>
class unique_lock {
using mutex_awaitable = std::invoke_result_t<decltype(&Mutex::unique), Mutex*>;
struct awaitable {
unique_lock* m_lock;
mutex_awaitable m_awaitable;

auto await_ready() noexcept {
return m_awaitable.await_ready();
}

template <class Promise>
auto await_suspend(std::coroutine_handle<Promise> enclosing) noexcept {
return m_awaitable.await_suspend(enclosing);
}

void await_resume() noexcept {
m_awaitable.await_resume();
m_lock->m_owned = true;
}
};

public:
unique_lock(Mutex& mtx) noexcept : m_mtx(&mtx) {}
unique_lock(mutex_lock<Mutex>&& lk) noexcept : m_mtx(&lk.mutex()), m_owned(true) {}

bool try_lock() noexcept {
assert(!owns_lock());
m_owned = m_mtx->try_lock();
return m_owned;
}

auto operator co_await() noexcept {
assert(!owns_lock());
return awaitable(this, m_mtx->unique());
}

void unlock() noexcept {
assert(owns_lock());
m_mtx->unlock();
m_owned = false;
}

Mutex& mutex() const noexcept {
return *m_mtx;
}

bool owns_lock() const noexcept {
return m_owned;
}

operator bool() const noexcept {
return owns_lock();
}

private:
Mutex* m_mtx;
bool m_owned = false;
};


template <class Mutex>
unique_lock(mutex_lock<Mutex>&& lk) -> unique_lock<Mutex>;


template <class Mutex>
class shared_lock {
using mutex_awaitable = std::invoke_result_t<decltype(&Mutex::shared), Mutex*>;
struct awaitable {
shared_lock* m_lock;
mutex_awaitable m_awaitable;

auto await_ready() noexcept {
return m_awaitable.await_ready();
}

template <class Promise>
auto await_suspend(std::coroutine_handle<Promise> enclosing) noexcept {
return m_awaitable.await_suspend(enclosing);
}

void await_resume() noexcept {
m_awaitable.await_resume();
m_lock->m_owned = true;
}
};

public:
shared_lock(Mutex& mtx) noexcept : m_mtx(&mtx) {}
shared_lock(mutex_shared_lock<Mutex> lk) noexcept : m_mtx(&lk.mutex()), m_owned(true) {}

bool try_lock() noexcept {
assert(!owns_lock());
m_owned = m_mtx->try_lock_shared();
return m_owned;
}

auto operator co_await() noexcept {
assert(!owns_lock());
return awaitable(this, m_mtx->shared());
}

void unlock() noexcept {
assert(owns_lock());
m_mtx->unlock_shared();
m_owned = false;
}

Mutex& mutex() const noexcept {
return *m_mtx;
}

bool owns_lock() const noexcept {
return m_owned;
}

operator bool() const noexcept {
return owns_lock();
}

private:
Mutex* m_mtx;
bool m_owned = false;
};


template <class Mutex>
shared_lock(mutex_shared_lock<Mutex> lk) -> shared_lock<Mutex>;


} // namespace asyncpp
Loading

0 comments on commit f6b25e3

Please sign in to comment.