Skip to content

Commit

Permalink
join any awaitable (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
petiaccja authored Nov 19, 2023
1 parent 085298b commit e0d29ba
Show file tree
Hide file tree
Showing 15 changed files with 234 additions and 322 deletions.
48 changes: 48 additions & 0 deletions include/async++/concepts.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once

#include <concepts>


namespace asyncpp {

// clang-format off

template <class T>
concept directly_awaitable = requires(std::remove_reference_t<T>& t) {
{ t.await_ready() } -> std::convertible_to<bool>;
{ t.await_resume() };
};


template <class T>
concept indirectly_awaitable = requires(std::remove_reference_t<T>& t) {
{ t.operator co_await() } -> directly_awaitable;
};


template <class T>
concept awaitable = directly_awaitable<T> || indirectly_awaitable<T>;

// clang-format on

template <class T>
struct await_result {};


template <directly_awaitable T>
struct await_result<T> {
using type = decltype(std::declval<T>().await_resume());
};


template <indirectly_awaitable T>
struct await_result<T> {
using type = decltype(std::declval<T>().operator co_await().await_resume());
};


template <class T>
using await_result_t = typename await_result<T>::type;


}; // namespace asyncpp
4 changes: 2 additions & 2 deletions include/async++/interleaving/sequence_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace impl_sp {


#define INTERLEAVED(EXPR) \
[&](std::string_view func) { \
[&](std::string_view func) -> decltype(auto) { \
static ::asyncpp::interleaving::sequence_point sp = { \
false, \
#EXPR, \
Expand All @@ -70,7 +70,7 @@ namespace impl_sp {


#define INTERLEAVED_ACQUIRE(EXPR) \
[&](std::string_view func) { \
[&](std::string_view func) -> decltype(auto) { \
static ::asyncpp::interleaving::sequence_point sp = { \
true, \
#EXPR, \
Expand Down
86 changes: 86 additions & 0 deletions include/async++/join.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#pragma once

#include "concepts.hpp"
#include "interleaving/sequence_point.hpp"
#include "promise.hpp"

#include <coroutine>
#include <exception>
#include <future>


namespace asyncpp {

namespace impl_join {

template <class T>
struct joiner;

template <class T>
struct basic_promise : impl::resumable_promise {
std::promise<T> m_promise;

joiner<T> get_return_object() {
return { m_promise.get_future() };
}

constexpr auto initial_suspend() const noexcept {
return std::suspend_never{};
}

void unhandled_exception() noexcept {
m_promise.set_exception(std::current_exception());
}

constexpr auto final_suspend() const noexcept {
return std::suspend_never{};
}
};

template <class T>
struct promise : basic_promise<T> {
void return_value(T value) noexcept {
this->m_promise.set_value(std::forward<T>(value));
}

void resume() noexcept override {
std::coroutine_handle<promise>::from_promise(*this).resume();
}
};

template <>
struct promise<void> : basic_promise<void> {
void return_void() noexcept {
m_promise.set_value();
}

void resume() noexcept override {
std::coroutine_handle<promise>::from_promise(*this).resume();
}
};

template <class T>
struct joiner {
using promise_type = promise<T>;

std::future<T> future;
};

} // namespace impl_join


template <awaitable Awaitable>
auto join(Awaitable&& object) -> await_result_t<std::remove_reference_t<Awaitable>> {
using T = await_result_t<std::remove_reference_t<Awaitable>>;
auto joiner_ = [&object]() -> impl_join::joiner<T> {
co_return co_await object;
}();
if constexpr (std::is_void_v<T> || std::is_reference_v<T>) {
return INTERLEAVED_ACQUIRE(joiner_.future.get());
}
else {
return std::forward<T>(INTERLEAVED_ACQUIRE(joiner_.future.get()));
}
}

} // namespace asyncpp
41 changes: 0 additions & 41 deletions include/async++/promise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,45 +55,4 @@ namespace impl {

} // namespace impl


struct resumer_coroutine_object;


struct resumer_coroutine_promise {
impl::resumable_promise& m_promise;

resumer_coroutine_promise(impl::resumable_promise& promise) : m_promise(promise) {}

auto get_return_object() noexcept {
return std::coroutine_handle<resumer_coroutine_promise>::from_promise(*this);
}

constexpr auto initial_suspend() const noexcept {
return std::suspend_always{};
}

constexpr void return_void() const noexcept {}

void unhandled_exception() const noexcept {
std::terminate();
}

constexpr auto final_suspend() const noexcept {
return std::suspend_never{};
}
};


struct resumer_coroutine_handle : std::coroutine_handle<resumer_coroutine_promise> {
resumer_coroutine_handle(std::coroutine_handle<resumer_coroutine_promise> handle)
: std::coroutine_handle<resumer_coroutine_promise>(handle) {}
using promise_type = resumer_coroutine_promise;
};


inline resumer_coroutine_handle resumer_coroutine(impl::resumable_promise& promise) {
promise.resume();
co_return;
}

} // namespace asyncpp
28 changes: 0 additions & 28 deletions include/async++/shared_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,28 +157,6 @@ namespace impl_shared_task {
}
};


template <class T>
struct sync_awaitable : chained_awaitable<T> {
promise<T>* m_awaited = nullptr;
std::promise<task_result<T>*> m_promise;
std::future<task_result<T>*> m_future = m_promise.get_future();

sync_awaitable(promise<T>* awaited) noexcept : m_awaited(awaited) {
m_awaited->acquire();
const bool ready = m_awaited->await(this);
if (ready) {
m_promise.set_value(&m_awaited->get_result());
}
}
~sync_awaitable() override {
m_awaited->release();
}
void on_ready() noexcept final {
return m_promise.set_value(&m_awaited->get_result());
}
};

} // namespace impl_shared_task


Expand Down Expand Up @@ -234,12 +212,6 @@ class shared_task {
return m_promise->ready();
}

auto get() const -> typename impl::task_result<T>::reference {
assert(valid());
impl_shared_task::sync_awaitable<T> awaitable(m_promise);
return INTERLEAVED_ACQUIRE(awaitable.m_future.get())->get_or_throw();
}

auto operator co_await() const {
assert(valid());
return impl_shared_task::awaitable<T>(m_promise);
Expand Down
27 changes: 0 additions & 27 deletions include/async++/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,6 @@ namespace impl_stream {
}
};


template <class T>
struct sync_awaitable : basic_awaitable<T> {
promise<T>* m_awaited;
sync_awaitable(promise<T>* awaited) noexcept : m_awaited(awaited) {
const bool ready = awaited->await(this);
if (ready) {
m_promise.set_value(awaited->get_result());
}
}
void on_ready() noexcept final {
m_promise.set_value(m_awaited->get_result());
}
std::promise<task_result<T>> m_promise;
std::future<task_result<T>> m_future = m_promise.get_future();
};

} // namespace impl_stream


Expand All @@ -177,16 +160,6 @@ class [[nodiscard]] stream {
release();
}

auto get() const -> std::optional<typename impl::task_result<T>::wrapper_type> {
assert(good() && "stream is finished");
impl_stream::sync_awaitable<T> awaitable(m_promise);
auto result = awaitable.m_future.get();
if (!result.has_value()) {
return std::nullopt;
}
return { std::forward<T>(result.get_or_throw()) };
}

auto operator co_await() const {
return impl_stream::awaitable<T>(m_promise);
}
Expand Down
31 changes: 0 additions & 31 deletions include/async++/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,26 +151,6 @@ namespace impl_task {
}
};


template <class T>
struct sync_awaitable : basic_awaitable<T> {
promise<T>* m_awaited = nullptr;
std::promise<task_result<T>> m_promise;
std::future<task_result<T>> m_future = m_promise.get_future();

sync_awaitable(promise<T>* awaited) noexcept : m_awaited(awaited) {
const bool ready = m_awaited->await(this);
if (ready) {
m_promise.set_value(m_awaited->get_result());
m_awaited->release();
}
}
void on_ready() noexcept final {
return m_promise.set_value(m_awaited->get_result());
}
};


} // namespace impl_task


Expand Down Expand Up @@ -205,17 +185,6 @@ class [[nodiscard]] task {
return m_promise->ready();
}

T get() {
assert(valid());
impl_task::sync_awaitable<T> awaitable(std::exchange(m_promise, nullptr));
if constexpr (!std::is_void_v<T>) {
return std::forward<T>(INTERLEAVED_ACQUIRE(awaitable.m_future.get()).get_or_throw());
}
else {
INTERLEAVED_ACQUIRE(awaitable.m_future.get()).get_or_throw();
}
}

auto operator co_await() {
assert(valid());
return impl_task::awaitable<T>(std::exchange(m_promise, nullptr));
Expand Down
15 changes: 8 additions & 7 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ add_executable(test)

target_sources(test
PRIVATE
container/test_atomic_collection.cpp
container/test_atomic_queue.cpp
container/test_atomic_stack.cpp
interleaving/test_runner.cpp
main.cpp
test_generator.cpp
test_stream.cpp
test_thread_pool.cpp
test_task.cpp
test_join.cpp
test_mutex.cpp
test_shared_mutex.cpp
interleaving/test_runner.cpp
test_shared_task.cpp
container/test_atomic_queue.cpp
container/test_atomic_stack.cpp
container/test_atomic_collection.cpp
test_stream.cpp
test_task.cpp
test_thread_pool.cpp
)


Expand Down
Loading

0 comments on commit e0d29ba

Please sign in to comment.