Skip to content

Commit

Permalink
feat: add coro::CallbackTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
OEOTYAN committed Feb 12, 2025
1 parent 0d165b0 commit a6f6e64
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 11 deletions.
78 changes: 69 additions & 9 deletions src-test/common/ExecTest.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#include "ll/core/LeviLamina.h"

#include "ll/api/chrono/GameChrono.h"
#include "ll/api/coro/CallbackTransformer.h"
#include "ll/api/coro/CoroTask.h"
#include "ll/api/coro/Generator.h"
#include "ll/api/coro/InterruptableSleep.h"
#include "ll/api/thread/InplaceExecutor.h"
#include "ll/api/thread/ServerThreadExecutor.h"
#include "ll/api/thread/ThreadName.h"
#include "ll/api/thread/ThreadPoolExecutor.h"

using namespace ll;
Expand Down Expand Up @@ -34,14 +36,14 @@ CoroTask<Expected<int>> coroutine() {
std::thread{[] {
std::this_thread::sleep_for(1s);
sleeper.interrupt(true);
getLogger().info("thread: {}", std::this_thread::get_id());
getLogger().info("thread: {}", *thread::getThreadName());
}}.detach();
getLogger().info("coroutine: {}, thread: {}", std::chrono::system_clock::now(), std::this_thread::get_id());
getLogger().info("coroutine: {}, thread: {}", std::chrono::system_clock::now(), *thread::getThreadName());
co_await sleeper.sleep();
getLogger().info("coroutine: {}, thread: {}", std::chrono::system_clock::now(), std::this_thread::get_id());
getLogger().info("coroutine: {}, thread: {}", std::chrono::system_clock::now(), *thread::getThreadName());
for (size_t i = 0;; i++) {
co_await 2_tick;
getLogger().info("coroutine: {}, thread: {}", std::chrono::system_clock::now(), std::this_thread::get_id());
getLogger().info("coroutine: {}, thread: {}", std::chrono::system_clock::now(), *thread::getThreadName());
if (i > 10) {
break;
}
Expand All @@ -53,7 +55,7 @@ CoroTask<Expected<int>> coroutine() {
getLogger().info(
"coroutine: collectAll use {}",
std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::steady_clock::now() - parbegin),
std::this_thread::get_id()
*thread::getThreadName()
);
std::vector<ll::coro::CoroTask<int>> tasks{};
parbegin = std::chrono::steady_clock::now();
Expand All @@ -64,22 +66,80 @@ CoroTask<Expected<int>> coroutine() {
getLogger().info(
"coroutine: generator use {}",
std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::steady_clock::now() - parbegin),
std::this_thread::get_id()
*thread::getThreadName()
);
auto vec = co_await collectAll(std::move(tasks));
getLogger().info(
"coroutine: collectAll use {}",
std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::steady_clock::now() - parbegin),
std::this_thread::get_id()
*thread::getThreadName()
);
co_return v1.value() + v2.value() + vec[0].value() + vec[1].value();
}
static bool run = [] {
using namespace ll;

thread::ServerThreadExecutor::getDefault().execute([&] {
auto val = coroutine().syncLaunch(thread::ThreadPoolExecutor::getDefault());
getLogger().info("coroutine done, {}", val.value());
coroutine().launch(thread::ThreadPoolExecutor::getDefault(), [](Expected<int>&& val) {
getLogger().info("coroutine done, {}", val.value());
});

auto f = []() -> CoroTask<> {
CallbackTransformer<int> t;
std::thread{[&] {
auto setter = t.getValueSetter();
auto callback = [&](int i) {
setter.emplace(i);
return !setter.finished();
};
getLogger().info("start callback");
for (int i = 0; i < 20; ++i) {
getLogger().info("callback {}", i);
if (!callback(i)) {
break;
}
}
}}.detach();
int a = 0;
for (auto iter = co_await t.begin(); iter != t.end(); co_await ++iter) {
auto&& val = *iter;
if (val > 10) {
break;
}
a++;
getLogger().info("val {}, thread {}", val, *thread::getThreadName());
}
getLogger().info("breaked {}", a);
co_await 0.1s;
};
keepThis(std::move(f)).syncLaunch(thread::ThreadPoolExecutor::getDefault());

CallbackTransformer<int> t;

auto f2 = [&t]() -> CoroTask<> {
for (auto iter = co_await t.begin(); iter != t.end(); co_await ++iter) {
auto&& val = *iter;

if (val > 10) {
break;
}
getLogger().info("val {}, thread {}", val, *thread::getThreadName());
}
getLogger().info("breaked");
};
f2().launch(thread::InplaceExecutor::getDefault());

auto callback = [v = t.getValueSetter()](int i) {
v.emplace(i);
return !v.finished();
};
getLogger().info("start callback");
for (int i = 0; i < 7; ++i) {
getLogger().info("callback {}", i);
if (!callback(i)) {
break;
}
}
});

return true;
Expand Down
168 changes: 168 additions & 0 deletions src/ll/api/coro/CallbackTransformer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#pragma once

#include <coroutine>
#include <optional>
#include <semaphore>

#include "ll/api/coro/Executor.h"

namespace ll::coro {
template <class T>
class CallbackTransformer {
using ValueStorage = std::conditional_t<std::is_lvalue_reference_v<T>, std::add_pointer_t<T>, std::optional<T>>;
struct Data {
ExecutorRef exec;
std::coroutine_handle<> handle;
ValueStorage value{};
std::binary_semaphore sem{0};
std::atomic_bool finished{false};

void reset() {
if constexpr (std::is_lvalue_reference_v<T>) {
value = nullptr;
} else {
value.reset();
}
}
void execute() {
if (auto h = exchange(handle, nullptr)) {
exec->execute(h);
}
}
};

std::shared_ptr<Data> data;

public:
class ValueSetter {
std::shared_ptr<Data> data;

public:
explicit ValueSetter(CallbackTransformer const& t) : data(t.data) {}

ValueSetter(ValueSetter const&) = delete;
ValueSetter& operator=(ValueSetter const&) = delete;
ValueSetter(ValueSetter&&) = default;
ValueSetter& operator=(ValueSetter&&) = default;

bool finished() const { return data->finished; }
void finish() const { data->finished.store(true); }

template <class... Args>
requires(!std::is_lvalue_reference_v<T>)
constexpr void emplace(Args&&... args) const {
data->sem.acquire();
data->value.emplace(std::forward<Args>(args)...);
data->execute();
}

constexpr void emplace(T& t) const
requires(std::is_lvalue_reference_v<T>)
{
data->sem.acquire();
data->value = std::addressof(t);
data->execute();
}

~ValueSetter() {
if (data) {
finish();
data->execute();
}
}
};

class Iterator {
std::shared_ptr<Data> data;

class IncAwaiter {
Iterator& iter;

public:
IncAwaiter(Iterator& iter) : iter(iter) {}

constexpr bool await_ready() const noexcept { return iter.data == nullptr || iter.data->finished; }
void await_suspend(std::coroutine_handle<> handle) {
iter.data->handle = handle;
iter.data->reset();
iter.data->sem.release();
}
constexpr Iterator& await_resume() const noexcept {
if (iter.data && iter.data->finished && !iter.data->value) {
iter.data.reset();
}
return iter;
}
};

public:
Iterator() = default;
explicit Iterator(CallbackTransformer const& t) : data(t.data) {}

Iterator(Iterator const&) = delete;
Iterator& operator=(Iterator const&) = delete;
Iterator(Iterator&&) = default;
Iterator& operator=(Iterator&&) = default;

[[nodiscard]] IncAwaiter operator++() { return IncAwaiter{*this}; }

[[nodiscard]] bool operator==(Iterator const& other) const noexcept { return data == other.data; }

[[nodiscard]] bool operator==(std::nullptr_t) const noexcept { return data == nullptr; }

[[nodiscard]] operator bool() const noexcept { return data != nullptr; }

[[nodiscard]] T&& operator*() const noexcept { return static_cast<T&&>(*data->value); }

[[nodiscard]] std::add_pointer_t<T> operator->() const noexcept { return std::addressof(*data->value); }

~Iterator() {
if (data) {
data->finished.store(true);
data->sem.release();
}
}
};

CallbackTransformer() : data(std::make_shared<Data>()) {}

CallbackTransformer(CallbackTransformer const&) = delete;
CallbackTransformer& operator=(CallbackTransformer const&) = delete;
CallbackTransformer(CallbackTransformer&&) = default;
CallbackTransformer& operator=(CallbackTransformer&&) = default;

ValueSetter getValueSetter() const { return ValueSetter{*this}; }

[[nodiscard]] Iterator end() noexcept { return {}; }

[[nodiscard]] auto begin() noexcept {
class BeginAwaiter {
CallbackTransformer& ct;

public:
BeginAwaiter(CallbackTransformer& ct) : ct(ct) {}

constexpr bool await_ready() const noexcept { return ct.data->finished; }
void await_suspend(std::coroutine_handle<> handle) {
ct.data->handle = handle;
ct.data->sem.release();
}
constexpr Iterator await_resume() const noexcept {
if (ct.data->finished && !ct.data->value) {
return {};
}
return Iterator{ct};
}
constexpr void setExecutor(ExecutorRef ex) { ct.data->exec = ex; }
};
return BeginAwaiter{*this};
}

~CallbackTransformer() {
if (data) {
data->finished.store(true);
data->execute();
}
}
};
} // namespace ll::coro
2 changes: 0 additions & 2 deletions src/ll/api/coro/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ struct Generator {

[[nodiscard]] bool operator==(iterator const& other) const noexcept { return handle == other.handle; }

[[nodiscard]] bool operator!=(iterator const& other) const noexcept { return !(*this == other); }

[[nodiscard]] reference operator*() const noexcept { return static_cast<reference>(*handle.promise().ptr); }

[[nodiscard]] pointer operator->() const noexcept { return handle.promise().ptr; }
Expand Down

0 comments on commit a6f6e64

Please sign in to comment.