diff --git a/include/asyncpp/generator.hpp b/include/asyncpp/generator.hpp index d4ea776..0ca53bd 100644 --- a/include/asyncpp/generator.hpp +++ b/include/asyncpp/generator.hpp @@ -10,16 +10,16 @@ namespace asyncpp { -template +template class generator; namespace impl_generator { - template - struct promise { + template + struct promise : allocator_aware_promise { auto get_return_object() noexcept { - return generator(this); + return generator(this); } constexpr auto initial_suspend() const noexcept { @@ -49,10 +49,10 @@ namespace impl_generator { task_result m_result; }; - template + template class iterator { public: - using promise_type = promise; + using promise_type = promise; using value_type = std::remove_reference_t; using difference_type = ptrdiff_t; using pointer = value_type*; @@ -100,16 +100,16 @@ namespace impl_generator { promise_type* m_promise = nullptr; }; - static_assert(std::input_iterator>); + static_assert(std::input_iterator>); } // namespace impl_generator -template +template class [[nodiscard]] generator { public: - using promise_type = impl_generator::promise; - using iterator = impl_generator::iterator; + using promise_type = impl_generator::promise; + using iterator = impl_generator::iterator; generator(promise_type* promise) : m_promise(promise) {} generator() = default; diff --git a/include/asyncpp/promise.hpp b/include/asyncpp/promise.hpp index 31d0667..c37a938 100644 --- a/include/asyncpp/promise.hpp +++ b/include/asyncpp/promise.hpp @@ -1,7 +1,9 @@ #pragma once -#include +#include #include +#include +#include #include #include #include @@ -102,38 +104,86 @@ struct result_promise { }; -namespace impl { +template +struct allocator_aware_promise { +private: + template + struct aligned_block { + alignas(Alignment) std::byte memory[Alignment]; + }; - class leak_checked_promise { - using snapshot_type = std::pair; + using dealloc_t = void (*)(void*, size_t); - public: -#ifdef ASYNCPP_BUILD_TESTS - leak_checked_promise() noexcept { num_alive.fetch_add(1, std::memory_order_relaxed); } - leak_checked_promise(const leak_checked_promise&) noexcept { num_alive.fetch_add(1, std::memory_order_relaxed); } - leak_checked_promise(leak_checked_promise&&) noexcept = delete; - leak_checked_promise& operator=(const leak_checked_promise&) noexcept { return *this; } - leak_checked_promise& operator=(leak_checked_promise&&) noexcept = delete; - ~leak_checked_promise() { - num_alive.fetch_sub(1, std::memory_order_relaxed); - version.fetch_add(1, std::memory_order_relaxed); - } -#endif + static constexpr auto dealloc_offset(size_t size) { + constexpr auto dealloc_alignment = alignof(dealloc_t); + return (size + dealloc_alignment - 1) / dealloc_alignment * dealloc_alignment; + } - static snapshot_type snapshot() noexcept { - return { num_alive.load(std::memory_order_relaxed), version.load(std::memory_order_relaxed) }; - } + template + requires std::convertible_to || std::is_void_v + static void* allocate(size_t size, std::allocator_arg_t, const Alloc_& alloc, Args&&...) { + static constexpr auto alloc_alignment = alignof(Alloc_); + static constexpr auto promise_alignment = std::max({ alignof(std::max_align_t), alignof(Args)... }); + static constexpr auto alignment = std::max(alloc_alignment, promise_alignment); + using block_t = aligned_block; + using alloc_t = typename std::allocator_traits::template rebind_alloc; + static_assert(alignof(alloc_t) <= alignof(Alloc_)); + + static constexpr auto alloc_offset = [](size_t size) { + const auto extended_size = dealloc_offset(size) + sizeof(dealloc_t); + return (extended_size + alloc_alignment - 1) / alloc_alignment * alloc_alignment; + }; + + static constexpr auto total_size = [](size_t size) { + return alloc_offset(size) + sizeof(alloc_t); + }; + + static constexpr dealloc_t dealloc = [](void* ptr, size_t size) { + auto& alloc = *reinterpret_cast(static_cast(ptr) + alloc_offset(size)); + auto moved = std::move(alloc); + alloc.~alloc_t(); + const auto num_blocks = (total_size(size) + sizeof(block_t) - 1) / sizeof(block_t); + std::allocator_traits::deallocate(moved, static_cast(ptr), num_blocks); + }; + + auto rebound_alloc = alloc_t(alloc); + const auto num_blocks = (total_size(size) + sizeof(block_t) - 1) / sizeof(block_t); + const auto ptr = std::allocator_traits::allocate(rebound_alloc, num_blocks); + const auto dealloc_ptr = reinterpret_cast(reinterpret_cast(ptr) + dealloc_offset(size)); + const auto alloc_ptr = reinterpret_cast(reinterpret_cast(ptr) + alloc_offset(size)); + new (dealloc_ptr) dealloc_t(dealloc); + new (alloc_ptr) alloc_t(std::move(rebound_alloc)); + return ptr; + } - static bool check(snapshot_type s) noexcept { - const auto current = snapshot(); - return current.first == s.first && current.second > s.second; - } +public: + template + requires std::convertible_to || std::is_void_v + void* operator new(size_t size, std::allocator_arg_t, const Alloc_& alloc, Args&&... args) { + return allocate(size, std::allocator_arg, alloc, std::forward(args)...); + } - private: - inline static std::atomic_intptr_t num_alive = 0; - inline static std::atomic_intptr_t version = 0; - }; + template + requires std::convertible_to || std::is_void_v + void* operator new(size_t size, Self&, std::allocator_arg_t, const Alloc_& alloc, Args&&... args) { + return allocate(size, std::allocator_arg, alloc, std::forward(args)...); + } -} // namespace impl + template + requires(... && !std::convertible_to) + void* operator new(size_t size, Args&&... args) { + if constexpr (!std::is_void_v) { + return allocate(size, std::allocator_arg, Alloc{}, std::forward(args)...); + } + else { + return allocate(size, std::allocator_arg, std::allocator{}, std::forward(args)...); + } + } + + void operator delete(void* ptr, size_t size) { + const auto dealloc_ptr = reinterpret_cast(static_cast(ptr) + dealloc_offset(size)); + (*dealloc_ptr)(ptr, size); + } +}; } // namespace asyncpp \ No newline at end of file diff --git a/include/asyncpp/shared_task.hpp b/include/asyncpp/shared_task.hpp deleted file mode 100644 index 0ac8210..0000000 --- a/include/asyncpp/shared_task.hpp +++ /dev/null @@ -1,141 +0,0 @@ -#pragma once - -#include "event.hpp" -#include "memory/rc_ptr.hpp" -#include "promise.hpp" -#include "scheduler.hpp" -#include "testing/suspension_point.hpp" - -#include -#include - - -namespace asyncpp { - - -template -class shared_task; - - -namespace impl_shared_task { - - template - struct promise : result_promise, resumable_promise, schedulable_promise, impl::leak_checked_promise, rc_from_this { - struct final_awaitable { - constexpr bool await_ready() const noexcept { return false; } - void await_suspend(std::coroutine_handle handle) noexcept { - auto& owner = handle.promise(); - owner.m_event.set(owner.m_result); - auto self = std::move(owner.m_self); // owner.m_self.reset() call method on owner after it's been deleted. - self.reset(); - } - constexpr void await_resume() const noexcept {} - }; - - auto get_return_object() { - return shared_task(rc_ptr(this)); - } - - auto initial_suspend() noexcept { - return std::suspend_always{}; - } - - auto final_suspend() noexcept { - return final_awaitable{}; - } - - auto handle() -> std::coroutine_handle<> final { - return std::coroutine_handle::from_promise(*this); - } - - void resume() final { - return m_scheduler ? m_scheduler->schedule(*this) : handle().resume(); - } - - void start() noexcept { - if (!INTERLEAVED(m_started.test_and_set(std::memory_order_relaxed))) { - m_self.reset(this); - resume(); - } - } - - static auto await(rc_ptr pr); - - bool ready() const { - return m_event.ready(); - } - - void destroy() { - handle().destroy(); - } - - private: - std::atomic_flag m_started; - broadcast_event m_event; - rc_ptr m_self; - }; - - - template - struct awaitable : broadcast_event::awaitable { - using base = typename broadcast_event::awaitable; - - rc_ptr> m_awaited = nullptr; - - awaitable(base base, rc_ptr> awaited) : broadcast_event::awaitable(std::move(base)), m_awaited(awaited) { - assert(m_awaited); - } - }; - - - template - auto promise::await(rc_ptr pr) { - assert(pr); - pr->start(); - auto base = pr->m_event.operator co_await(); - return awaitable{ std::move(base), std::move(pr) }; - } - -} // namespace impl_shared_task - - -template -class shared_task { -public: - using promise_type = impl_shared_task::promise; - - shared_task() = default; - shared_task(rc_ptr promise) : m_promise(std::move(promise)) {} - - bool valid() const { - return !!m_promise; - } - - bool ready() const { - assert(valid()); - return m_promise->ready(); - } - - void launch() { - assert(valid()); - m_promise->start(); - } - - void bind(scheduler& scheduler) { - assert(valid()); - if (m_promise) { - m_promise->m_scheduler = &scheduler; - } - } - - auto operator co_await() { - assert(valid()); - return promise_type::await(m_promise); - } - -private: - rc_ptr m_promise; -}; - - -} // namespace asyncpp \ No newline at end of file diff --git a/include/asyncpp/stream.hpp b/include/asyncpp/stream.hpp index 214830f..eebfde9 100644 --- a/include/asyncpp/stream.hpp +++ b/include/asyncpp/stream.hpp @@ -14,7 +14,7 @@ namespace asyncpp { -template +template class stream; @@ -60,8 +60,8 @@ namespace impl_stream { }; - template - struct promise : resumable_promise, schedulable_promise, impl::leak_checked_promise, rc_from_this { + template + struct promise : resumable_promise, schedulable_promise, rc_from_this, allocator_aware_promise { struct yield_awaitable { constexpr bool await_ready() const noexcept { return false; } @@ -76,7 +76,7 @@ namespace impl_stream { }; auto get_return_object() noexcept { - return stream(rc_ptr(this)); + return stream(rc_ptr(this)); } constexpr std::suspend_always initial_suspend() const noexcept { @@ -124,14 +124,14 @@ namespace impl_stream { }; - template + template struct awaitable { using base = typename event>>::awaitable; base m_base; - rc_ptr> m_awaited = nullptr; + rc_ptr> m_awaited = nullptr; - awaitable(base base, rc_ptr> awaited) : m_base(base), m_awaited(awaited) {} + awaitable(base base, rc_ptr> awaited) : m_base(base), m_awaited(awaited) {} bool await_ready() noexcept { assert(m_awaited->has_event()); @@ -151,10 +151,10 @@ namespace impl_stream { }; - template - auto promise::await() noexcept { + template + auto promise::await() noexcept { m_event.emplace(); - auto aw = awaitable(m_event->operator co_await(), rc_ptr(this)); + auto aw = awaitable(m_event->operator co_await(), rc_ptr(this)); resume(); return aw; } @@ -162,10 +162,10 @@ namespace impl_stream { } // namespace impl_stream -template +template class [[nodiscard]] stream { public: - using promise_type = impl_stream::promise; + using promise_type = impl_stream::promise; stream() = default; stream(const stream&) = delete; diff --git a/include/asyncpp/task.hpp b/include/asyncpp/task.hpp index 9782d2f..3da938e 100644 --- a/include/asyncpp/task.hpp +++ b/include/asyncpp/task.hpp @@ -6,23 +6,17 @@ #include "scheduler.hpp" #include "testing/suspension_point.hpp" -#include #include -#include -#include +#include namespace asyncpp { -template -class task; - - namespace impl_task { - template - struct promise : result_promise, resumable_promise, schedulable_promise, impl::leak_checked_promise, rc_from_this { + template + struct promise : result_promise, resumable_promise, schedulable_promise, rc_from_this, allocator_aware_promise { struct final_awaitable { constexpr bool await_ready() const noexcept { return false; } void await_suspend(std::coroutine_handle handle) const noexcept { @@ -35,14 +29,14 @@ namespace impl_task { }; auto get_return_object() { - return task(rc_ptr(this)); + return Task(rc_ptr(this)); } - constexpr auto initial_suspend() noexcept { + constexpr auto initial_suspend() const noexcept { return std::suspend_always{}; } - auto final_suspend() noexcept { + auto final_suspend() const noexcept { return final_awaitable{}; } @@ -54,7 +48,7 @@ namespace impl_task { return m_scheduler ? m_scheduler->schedule(*this) : handle().resume(); } - void start() noexcept { + void start() { if (!INTERLEAVED(m_started.test_and_set(std::memory_order_relaxed))) { m_self.reset(this); resume(); @@ -73,42 +67,42 @@ namespace impl_task { private: std::atomic_flag m_started; - event m_event; + Event m_event; rc_ptr m_self; }; - template - struct awaitable : event::awaitable { - using base = typename event::awaitable; - - rc_ptr> m_awaited = nullptr; + template + struct awaitable : Event::awaitable { + rc_ptr m_awaited = nullptr; - awaitable(base base, rc_ptr> awaited) : event::awaitable(std::move(base)), m_awaited(awaited) { + awaitable(typename Event::awaitable base, rc_ptr awaited) + : Event::awaitable(std::move(base)), m_awaited(awaited) { assert(m_awaited); } }; - template - auto promise::await(rc_ptr pr) { + template + auto promise::await(rc_ptr pr) { assert(pr); pr->start(); auto base = pr->m_event.operator co_await(); - return awaitable{ std::move(base), std::move(pr) }; + return awaitable{ std::move(base), std::move(pr) }; } } // namespace impl_task -template +template class [[nodiscard]] task { public: - using promise_type = impl_task::promise; + using promise_type = impl_task::promise>; task() = default; task(const task& rhs) = delete; task& operator=(const task& rhs) = delete; task(task&& rhs) noexcept = default; task& operator=(task&& rhs) noexcept = default; + ~task() = default; task(rc_ptr promise) : m_promise(std::move(promise)) {} bool valid() const { @@ -142,4 +136,43 @@ class [[nodiscard]] task { }; +template +class [[nodiscard]] shared_task { +public: + using promise_type = impl_task::promise>; + + shared_task() = default; + shared_task(rc_ptr promise) : m_promise(std::move(promise)) {} + + bool valid() const { + return !!m_promise; + } + + bool ready() const { + assert(valid()); + return m_promise->ready(); + } + + void launch() { + assert(valid()); + m_promise->start(); + } + + void bind(scheduler& scheduler) { + assert(valid()); + if (m_promise) { + m_promise->m_scheduler = &scheduler; + } + } + + auto operator co_await() { + assert(valid()); + return promise_type::await(m_promise); + } + +private: + rc_ptr m_promise; +}; + + } // namespace asyncpp \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7e6b8c0..232db81 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,6 +18,9 @@ target_sources(test test_event.cpp test_sleep.cpp testing/test_interleaver.cpp + helper_schedulers.hpp + monitor_task.hpp + monitor_allocator.hpp ) diff --git a/test/monitor_allocator.hpp b/test/monitor_allocator.hpp new file mode 100644 index 0000000..04cfeb7 --- /dev/null +++ b/test/monitor_allocator.hpp @@ -0,0 +1,58 @@ +#pragma once + +#include +#include + + +namespace impl_monitor_allocator { + +struct counters { + std::atomic_size_t num_allocations; + std::atomic_size_t num_deallocations; +}; + +} // namespace impl_monitor_allocator + + +template +class monitor_allocator { +public: + using value_type = T; + + template + friend class monitor_allocator; + + constexpr monitor_allocator() + : m_counters(std::make_shared()) {} + + constexpr monitor_allocator(const monitor_allocator& other) noexcept = default; + + template + constexpr monitor_allocator(const monitor_allocator& other) noexcept + : m_counters(other.m_counters) {} + + T* allocate(size_t n) { + m_counters->num_allocations.fetch_add(1, std::memory_order_relaxed); + return std::allocator().allocate(n); + } + + void deallocate(T* ptr, size_t n) { + m_counters->num_deallocations.fetch_add(1, std::memory_order_relaxed); + return std::allocator().deallocate(ptr, n); + } + + size_t get_num_allocations() const { + return m_counters->num_allocations.load(std::memory_order_relaxed); + } + + size_t get_num_deallocations() const { + return m_counters->num_allocations.load(std::memory_order_relaxed); + } + + size_t get_num_live_objects() const { + return get_num_allocations() - get_num_deallocations(); + } + +private: + std::shared_ptr m_counters; +}; \ No newline at end of file diff --git a/test/test_stream.cpp b/test/test_stream.cpp index 61c6147..3d13713 100644 --- a/test/test_stream.cpp +++ b/test/test_stream.cpp @@ -1,3 +1,4 @@ +#include "monitor_allocator.hpp" #include "monitor_task.hpp" #include @@ -88,18 +89,64 @@ TEST_CASE("Stream: destroy", "[Task]") { static const auto coro = []() -> stream { co_yield 0; }; SECTION("no execution") { - const auto before = impl::leak_checked_promise::snapshot(); - { - auto s = coro(); - } - REQUIRE(impl::leak_checked_promise::check(before)); + auto s = coro(); } SECTION("synced") { - const auto before = impl::leak_checked_promise::snapshot(); - { - auto s = coro(); - void(join(s)); - } - REQUIRE(impl::leak_checked_promise::check(before)); + auto s = coro(); + void(join(s)); } } + + +template +auto allocator_free(std::allocator_arg_t, monitor_allocator<>& alloc) -> Stream { + co_yield alloc; +} + + +template +struct allocator_object { + auto member_coro(std::allocator_arg_t, monitor_allocator<>& alloc) -> Stream { + co_yield alloc; + } +}; + + +TEST_CASE("Task: allocator erased", "[Task]") { + monitor_allocator<> alloc; + using stream_t = stream&>; + + SECTION("free function") { + auto task = allocator_free(std::allocator_arg, alloc); + void(*join(task)); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } + SECTION("member function") { + allocator_object obj; + auto task = obj.member_coro(std::allocator_arg, alloc); + void(*join(task)); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } +} + + +TEST_CASE("Task: allocator explicit", "[Task]") { + monitor_allocator<> alloc; + using stream_t = stream&, monitor_allocator<>>; + + SECTION("free function") { + auto task = allocator_free(std::allocator_arg, alloc); + void(*join(task)); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } + SECTION("member function") { + allocator_object obj; + auto task = obj.member_coro(std::allocator_arg, alloc); + void(*join(task)); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } +} \ No newline at end of file diff --git a/test/test_task.cpp b/test/test_task.cpp index 8bf5bca..c6907fa 100644 --- a/test/test_task.cpp +++ b/test/test_task.cpp @@ -1,7 +1,7 @@ #include "helper_schedulers.hpp" +#include "monitor_allocator.hpp" #include -#include #include #include @@ -118,9 +118,7 @@ TEMPLATE_TEST_CASE("Task: abandon (not started)", "[Task]", task, shared_t static const auto coro = []() -> TestType { co_return; }; - const auto before = impl::leak_checked_promise::snapshot(); static_cast(coro()); - REQUIRE(impl::leak_checked_promise::check(before)); } @@ -167,11 +165,63 @@ TEMPLATE_TEST_CASE("Task: co_await exception", "[Task]", task, shared_task static int value = 42; static const auto coro = []() -> TestType { throw std::runtime_error("test"); - co_return; + co_return; // This statement is necessary! }; static const auto enclosing = []() -> TestType { REQUIRE_THROWS_AS(co_await coro(), std::runtime_error); }; auto task = enclosing(); join(task); +} + + +template +auto allocator_free(std::allocator_arg_t, monitor_allocator<>& alloc) -> Task { + co_return alloc; +} + + +template +struct allocator_object { + auto member_coro(std::allocator_arg_t, monitor_allocator<>& alloc) -> Task { + co_return alloc; + } +}; + + +TEMPLATE_TEST_CASE("Task: allocator erased", "[Task]", task&>, shared_task&>) { + monitor_allocator<> alloc; + + SECTION("free function") { + auto task = allocator_free(std::allocator_arg, alloc); + join(task); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } + SECTION("member function") { + allocator_object obj; + auto task = obj.member_coro(std::allocator_arg, alloc); + join(task); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } +} + + +TEMPLATE_TEST_CASE("Task: allocator explicit", "[Task]", (task&, monitor_allocator<>>), (shared_task&, monitor_allocator<>>)) { + monitor_allocator<> alloc; + + SECTION("free function") { + auto task = allocator_free(std::allocator_arg, alloc); + join(task); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } + SECTION("member function") { + allocator_object obj; + auto task = obj.member_coro(std::allocator_arg, alloc); + join(task); + REQUIRE(alloc.get_num_allocations() == 1); + REQUIRE(alloc.get_num_live_objects() == 0); + } } \ No newline at end of file