Skip to content

Commit

Permalink
make all coroutine promises allocator aware
Browse files Browse the repository at this point in the history
  • Loading branch information
petiaccja committed Feb 25, 2024
1 parent 1814683 commit e2848c3
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 72 deletions.
20 changes: 10 additions & 10 deletions include/asyncpp/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
namespace asyncpp {


template <class T>
template <class T, class Alloc>
class generator;


namespace impl_generator {

template <class T>
struct promise {
template <class T, class Alloc>
struct promise : allocator_aware_promise<Alloc> {
auto get_return_object() noexcept {
return generator<T>(this);
return generator<T, Alloc>(this);
}

constexpr auto initial_suspend() const noexcept {
Expand Down Expand Up @@ -49,10 +49,10 @@ namespace impl_generator {
task_result<T> m_result;
};

template <class T>
template <class T, class Alloc>
class iterator {
public:
using promise_type = promise<T>;
using promise_type = promise<T, Alloc>;
using value_type = std::remove_reference_t<T>;
using difference_type = ptrdiff_t;
using pointer = value_type*;
Expand Down Expand Up @@ -100,16 +100,16 @@ namespace impl_generator {
promise_type* m_promise = nullptr;
};

static_assert(std::input_iterator<iterator<int>>);
static_assert(std::input_iterator<iterator<int, void>>);

} // namespace impl_generator


template <class T>
template <class T, class Alloc = void>
class [[nodiscard]] generator {
public:
using promise_type = impl_generator::promise<T>;
using iterator = impl_generator::iterator<T>;
using promise_type = impl_generator::promise<T, Alloc>;
using iterator = impl_generator::iterator<T, Alloc>;

generator(promise_type* promise) : m_promise(promise) {}
generator() = default;
Expand Down
105 changes: 77 additions & 28 deletions include/asyncpp/promise.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include <atomic>
#include <coroutine>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
Expand Down Expand Up @@ -102,38 +103,86 @@ struct result_promise<void> {
};


namespace impl {
template <class Alloc>
struct allocator_aware_promise {
private:
template <size_t Alignment = alignof(std::max_align_t)>
struct aligned_block {
alignas(Alignment) std::byte memory[Alignment];
};

class leak_checked_promise {
using snapshot_type = std::pair<intptr_t, intptr_t>;
using dealloc_t = void (*)(void*, size_t);

public:
#ifdef ASYNCPP_BUILD_TESTS
leak_checked_promise() noexcept { num_alive.fetch_add(1, std::memory_order_relaxed); }
leak_checked_promise(const leak_checked_promise&) noexcept { num_alive.fetch_add(1, std::memory_order_relaxed); }
leak_checked_promise(leak_checked_promise&&) noexcept = delete;
leak_checked_promise& operator=(const leak_checked_promise&) noexcept { return *this; }
leak_checked_promise& operator=(leak_checked_promise&&) noexcept = delete;
~leak_checked_promise() {
num_alive.fetch_sub(1, std::memory_order_relaxed);
version.fetch_add(1, std::memory_order_relaxed);
}
#endif
static constexpr auto dealloc_offset(size_t size) {
static constexpr auto dealloc_alignment = alignof(dealloc_t);
return (size + dealloc_alignment - 1) / dealloc_alignment * dealloc_alignment;
}

static snapshot_type snapshot() noexcept {
return { num_alive.load(std::memory_order_relaxed), version.load(std::memory_order_relaxed) };
}
template <class Alloc_, class... Args>
requires std::convertible_to<Alloc_, Alloc> || std::is_void_v<Alloc>
static void* allocate(size_t size, std::allocator_arg_t, const Alloc_& alloc, Args&&...) {
static constexpr auto alloc_alignment = alignof(Alloc_);
static constexpr auto promise_alignment = std::max({ alignof(std::max_align_t), alignof(Args)... });
static constexpr auto alignment = std::max(alloc_alignment, promise_alignment);
using block_t = aligned_block<alignment>;
using alloc_t = typename std::allocator_traits<Alloc_>::template rebind_alloc<block_t>;
static_assert(alignof(alloc_t) <= alignof(Alloc_));

static constexpr auto alloc_offset = [](size_t size) {
const auto extended_size = dealloc_offset(size) + sizeof(dealloc_t);
return (extended_size + alloc_alignment - 1) / alloc_alignment * alloc_alignment;
};

static constexpr auto total_size = [](size_t size) {
return alloc_offset(size) + sizeof(alloc_t);
};

static constexpr dealloc_t dealloc = [](void* ptr, size_t size) {
auto& alloc = *reinterpret_cast<alloc_t*>(static_cast<std::byte*>(ptr) + alloc_offset(size));
auto moved = std::move(alloc);
alloc.~alloc_t();
const auto num_blocks = (total_size(size) + sizeof(block_t) - 1) / sizeof(block_t);
std::allocator_traits<alloc_t>::deallocate(moved, static_cast<block_t*>(ptr), num_blocks);
};

auto rebound_alloc = alloc_t(alloc);
const auto num_blocks = (total_size(size) + sizeof(block_t) - 1) / sizeof(block_t);
const auto ptr = std::allocator_traits<alloc_t>::allocate(rebound_alloc, num_blocks);
const auto dealloc_ptr = reinterpret_cast<dealloc_t*>(reinterpret_cast<std::byte*>(ptr) + dealloc_offset(size));
const auto alloc_ptr = reinterpret_cast<alloc_t*>(reinterpret_cast<std::byte*>(ptr) + alloc_offset(size));
new (dealloc_ptr) dealloc_t(dealloc);
new (alloc_ptr) alloc_t(std::move(rebound_alloc));
return ptr;
}

static bool check(snapshot_type s) noexcept {
const auto current = snapshot();
return current.first == s.first && current.second > s.second;
}
public:
template <class Alloc_, class... Args>
requires std::convertible_to<Alloc_, Alloc> || std::is_void_v<Alloc>
void* operator new(size_t size, std::allocator_arg_t, const Alloc_& alloc, Args&&... args) {
return allocate(size, std::allocator_arg, alloc, std::forward<Args>(args)...);
}

private:
inline static std::atomic_intptr_t num_alive = 0;
inline static std::atomic_intptr_t version = 0;
};
template <class Self, class Alloc_, class... Args>
requires std::convertible_to<Alloc_, Alloc> || std::is_void_v<Alloc>
void* operator new(size_t size, Self&, std::allocator_arg_t, const Alloc_& alloc, Args&&... args) {
return allocate(size, std::allocator_arg, alloc, std::forward<Args>(args)...);
}

} // namespace impl
template <class... Args>
requires(... && !std::convertible_to<Args, std::allocator_arg_t>)
void* operator new(size_t size, Args&&... args) {
if constexpr (!std::is_void_v<Alloc>) {
return allocate(size, std::allocator_arg, Alloc{}, std::forward<Args>(args)...);
}
else {
return allocate(size, std::allocator_arg, std::allocator<std::byte>{}, std::forward<Args>(args)...);
}
}

void operator delete(void* ptr, size_t size) {
const auto dealloc_ptr = reinterpret_cast<dealloc_t*>(static_cast<std::byte*>(ptr) + dealloc_offset(size));
(*dealloc_ptr)(ptr, size);
}
};

} // namespace asyncpp
24 changes: 12 additions & 12 deletions include/asyncpp/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace asyncpp {


template <class T>
template <class T, class Alloc>
class stream;


Expand Down Expand Up @@ -60,8 +60,8 @@ namespace impl_stream {
};


template <class T>
struct promise : resumable_promise, schedulable_promise, impl::leak_checked_promise, rc_from_this {
template <class T, class Alloc>
struct promise : resumable_promise, schedulable_promise, rc_from_this, allocator_aware_promise<Alloc> {
struct yield_awaitable {
constexpr bool await_ready() const noexcept { return false; }

Expand All @@ -76,7 +76,7 @@ namespace impl_stream {
};

auto get_return_object() noexcept {
return stream<T>(rc_ptr(this));
return stream<T, Alloc>(rc_ptr(this));
}

constexpr std::suspend_always initial_suspend() const noexcept {
Expand Down Expand Up @@ -124,14 +124,14 @@ namespace impl_stream {
};


template <class T>
template <class T, class Alloc>
struct awaitable {
using base = typename event<std::optional<wrapper_type<T>>>::awaitable;

base m_base;
rc_ptr<promise<T>> m_awaited = nullptr;
rc_ptr<promise<T, Alloc>> m_awaited = nullptr;

awaitable(base base, rc_ptr<promise<T>> awaited) : m_base(base), m_awaited(awaited) {}
awaitable(base base, rc_ptr<promise<T, Alloc>> awaited) : m_base(base), m_awaited(awaited) {}

bool await_ready() noexcept {
assert(m_awaited->has_event());
Expand All @@ -151,21 +151,21 @@ namespace impl_stream {
};


template <class T>
auto promise<T>::await() noexcept {
template <class T, class Alloc>
auto promise<T, Alloc>::await() noexcept {
m_event.emplace();
auto aw = awaitable<T>(m_event->operator co_await(), rc_ptr(this));
auto aw = awaitable<T, Alloc>(m_event->operator co_await(), rc_ptr(this));
resume();
return aw;
}

} // namespace impl_stream


template <class T>
template <class T, class Alloc = void>
class [[nodiscard]] stream {
public:
using promise_type = impl_stream::promise<T>;
using promise_type = impl_stream::promise<T, Alloc>;

stream() = default;
stream(const stream&) = delete;
Expand Down
18 changes: 10 additions & 8 deletions include/asyncpp/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
#include "testing/suspension_point.hpp"

#include <cassert>
#include <fstream>


namespace asyncpp {


namespace impl_task {

template <class T, class Task, class Event>
struct promise : result_promise<T>, resumable_promise, schedulable_promise, rc_from_this {
template <class T, class Alloc, class Task, class Event>
struct promise : result_promise<T>, resumable_promise, schedulable_promise, rc_from_this, allocator_aware_promise<Alloc> {
struct final_awaitable {
constexpr bool await_ready() const noexcept { return false; }
void await_suspend(std::coroutine_handle<promise> handle) const noexcept {
Expand Down Expand Up @@ -80,8 +81,8 @@ namespace impl_task {
}
};

template <class T, class Task, class Event>
auto promise<T, Task, Event>::await(rc_ptr<promise> pr) {
template <class T, class Alloc, class Task, class Event>
auto promise<T, Alloc, Task, Event>::await(rc_ptr<promise> pr) {
assert(pr);
pr->start();
auto base = pr->m_event.operator co_await();
Expand All @@ -91,16 +92,17 @@ namespace impl_task {
} // namespace impl_task


template <class T>
template <class T, class Alloc = void>
class [[nodiscard]] task {
public:
using promise_type = impl_task::promise<T, task, event<T>>;
using promise_type = impl_task::promise<T, Alloc, task, event<T>>;

task() = default;
task(const task& rhs) = delete;
task& operator=(const task& rhs) = delete;
task(task&& rhs) noexcept = default;
task& operator=(task&& rhs) noexcept = default;
~task() = default;
task(rc_ptr<promise_type> promise) : m_promise(std::move(promise)) {}

bool valid() const {
Expand Down Expand Up @@ -134,10 +136,10 @@ class [[nodiscard]] task {
};


template <class T>
template <class T, class Alloc = void>
class [[nodiscard]] shared_task {
public:
using promise_type = impl_task::promise<T, shared_task, broadcast_event<T>>;
using promise_type = impl_task::promise<T, Alloc, shared_task, broadcast_event<T>>;

shared_task() = default;
shared_task(rc_ptr<promise_type> promise) : m_promise(std::move(promise)) {}
Expand Down
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ target_sources(test
test_event.cpp
test_sleep.cpp
testing/test_interleaver.cpp
helper_schedulers.hpp
monitor_task.hpp
monitor_allocator.hpp
)


Expand Down
58 changes: 58 additions & 0 deletions test/monitor_allocator.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#pragma once

#include <atomic>
#include <memory>


namespace impl_monitor_allocator {

struct counters {
std::atomic_size_t num_allocations;
std::atomic_size_t num_deallocations;
};

} // namespace impl_monitor_allocator


template <class T = std::byte>
class monitor_allocator {
public:
using value_type = T;

template <class U>
friend class monitor_allocator;

constexpr monitor_allocator()
: m_counters(std::make_shared<impl_monitor_allocator::counters>()) {}

constexpr monitor_allocator(const monitor_allocator& other) noexcept = default;

template <class U>
constexpr monitor_allocator(const monitor_allocator<U>& other) noexcept
: m_counters(other.m_counters) {}

T* allocate(size_t n) {
m_counters->num_allocations.fetch_add(1, std::memory_order_relaxed);
return std::allocator<T>().allocate(n);
}

void deallocate(T* ptr, size_t n) {
m_counters->num_deallocations.fetch_add(1, std::memory_order_relaxed);
return std::allocator<T>().deallocate(ptr, n);
}

size_t get_num_allocations() const {
return m_counters->num_allocations.load(std::memory_order_relaxed);
}

size_t get_num_deallocations() const {
return m_counters->num_allocations.load(std::memory_order_relaxed);
}

size_t get_num_live_objects() const {
return get_num_allocations() - get_num_deallocations();
}

private:
std::shared_ptr<impl_monitor_allocator::counters> m_counters;
};
Loading

0 comments on commit e2848c3

Please sign in to comment.