Skip to content

Commit

Permalink
feat: Make Asio based register_x functions compatible with GrpcCont…
Browse files Browse the repository at this point in the history
…ext::run_completion_queue/poll_completion_queue()
  • Loading branch information
Tradias committed Dec 7, 2024
1 parent 04d2223 commit 88cd6ec
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 77 deletions.
50 changes: 27 additions & 23 deletions src/agrpc/detail/register_rpc_handler_asio_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef AGRPC_DETAIL_REGISTER_RPC_HANDLER_ASIO_BASE_HPP
#define AGRPC_DETAIL_REGISTER_RPC_HANDLER_ASIO_BASE_HPP

#include <agrpc/alarm.hpp>
#include <agrpc/detail/association.hpp>
#include <agrpc/detail/register_rpc_handler_base.hpp>
#include <agrpc/detail/rethrow_first_arg.hpp>
Expand All @@ -28,14 +29,14 @@ AGRPC_NAMESPACE_BEGIN()

namespace detail
{
inline constexpr auto REGISTER_RPC_HANDLER_COMPLETE =
static_cast<detail::OperationResult>(detail::to_underlying(detail::OperationResult::OK_) + 1);
void register_rpc_handler_asio_completion_trampoline(agrpc::GrpcContext& grpc_context,
detail::RegisterRPCHandlerOperationComplete& operation);

template <class ServerRPC, class RPCHandler, class CompletionHandlerT>
class RegisterRPCHandlerOperationAsioBase
: public detail::RegisterRPCHandlerOperationBase<ServerRPC, RPCHandler,
detail::CancellationSlotT<CompletionHandlerT&>>,
public detail::QueueableOperationBase,
public detail::RegisterRPCHandlerOperationComplete,
private detail::WorkTracker<detail::AssociatedExecutorT<CompletionHandlerT>>
{
public:
Expand All @@ -44,6 +45,7 @@ class RegisterRPCHandlerOperationAsioBase

private:
using Base = detail::RegisterRPCHandlerOperationBase<ServerRPC, RPCHandler, StopToken>;
using CompletionBase = detail::RegisterRPCHandlerOperationComplete;
using WorkTracker = detail::WorkTracker<detail::AssociatedExecutorT<CompletionHandlerT>>;

struct Decrementer
Expand All @@ -52,7 +54,7 @@ class RegisterRPCHandlerOperationAsioBase
{
if (self_.decrement_ref_count())
{
self_.complete(REGISTER_RPC_HANDLER_COMPLETE, self_.grpc_context());
detail::register_rpc_handler_asio_completion_trampoline(self_.grpc_context(), self_);
}
}

Expand All @@ -69,9 +71,9 @@ class RegisterRPCHandlerOperationAsioBase

template <class Ch>
RegisterRPCHandlerOperationAsioBase(const ServerRPCExecutor& executor, Service& service, RPCHandler&& rpc_handler,
Ch&& completion_handler, detail::OperationOnComplete on_complete)
Ch&& completion_handler, CompletionBase::Complete on_complete)
: Base(executor, service, static_cast<RPCHandler&&>(rpc_handler)),
detail::QueueableOperationBase(on_complete),
CompletionBase(on_complete),
WorkTracker(asio::get_associated_executor(completion_handler)),
completion_handler_(static_cast<Ch&&>(completion_handler))
{
Expand Down Expand Up @@ -105,26 +107,28 @@ struct RegisterRPCHandlerInitiator
detail::ServerRPCServiceT<ServerRPC>& service_;
};

inline void register_rpc_handler_asio_completion_trampoline(agrpc::GrpcContext& grpc_context,
detail::RegisterRPCHandlerOperationComplete& operation)
{
detail::ScopeGuard guard{[&operation]
{
operation.complete();
}};
agrpc::Alarm{grpc_context}.wait(detail::GrpcContextImplementation::TIME_ZERO,
[g = std::move(guard)](auto&&...) mutable
{
g.execute();
});
grpc_context.work_finished();
}

template <class Operation>
static void register_rpc_handler_asio_do_complete(detail::OperationBase* operation, detail::OperationResult result,
agrpc::GrpcContext&)
void register_rpc_handler_asio_do_complete(detail::RegisterRPCHandlerOperationComplete& operation) noexcept
{
auto& self = *static_cast<Operation*>(operation);
auto& self = static_cast<Operation&>(operation);
detail::AllocationGuard guard{self, self.get_allocator()};
if (REGISTER_RPC_HANDLER_COMPLETE == result)
{
if AGRPC_LIKELY (!detail::GrpcContextImplementation::is_shutdown(self.grpc_context()))
{
detail::GrpcContextImplementation::add_operation(self.grpc_context(), &self);
guard.release();
}
return;
}
if AGRPC_LIKELY (!detail::is_shutdown(result))
{
auto eptr{static_cast<std::exception_ptr&&>(self.error())};
detail::dispatch_complete(guard, static_cast<std::exception_ptr&&>(eptr));
}
auto eptr{static_cast<std::exception_ptr&&>(self.error())};
detail::dispatch_complete(guard, static_cast<std::exception_ptr&&>(eptr));
}
}

Expand Down
9 changes: 9 additions & 0 deletions src/agrpc/detail/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ class ScopeGuard

constexpr const OnExit& get() const noexcept { return on_exit_; }

constexpr void execute()
{
if (is_armed_)
{
is_armed_ = false;
on_exit_();
}
}

private:
OnExit on_exit_;
bool is_armed_{true};
Expand Down
8 changes: 4 additions & 4 deletions src/agrpc/grpc_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class GrpcContext
* @brief Run the `grpc::CompletionQueue`
*
* Runs the main event loop logic until the GrpcContext runs out of work or is stopped. Only events from the
* `grpc::CompletionQueue` will be handled. That means that completion handler that were e.g. created using
* `grpc::CompletionQueue` will be handled. Completion handlers that were, for example, created using
* `asio::post(grpc_context, ...)` will not be processed. The GrpcContext will be brought into the ready state when
* this function is invoked. Upon return, the GrpcContext will be in the stopped state.
*
Expand Down Expand Up @@ -210,9 +210,9 @@ class GrpcContext
/**
* @brief Poll the `grpc::CompletionQueue`
*
* Processes only ready events of the `grpc::CompletionQueue`. That means that completion handler that were e.g.
* created using `asio::post(grpc_context, ...)` will not be processed. The GrpcContext will be brought into the
* ready state when this function is invoked.
* Processes only ready events of the `grpc::CompletionQueue`. Completion handlers that were, for example, created
* using `asio::post(grpc_context, ...)` will not be processed. The GrpcContext will be brought into the ready state
* when this function is invoked.
*
* @attention Only one thread may call run_completion_queue() or poll_completion_queue() at a time [unless this
* context has been constructed with a `concurrency_hint` greater than one. Even then it may not be called
Expand Down
14 changes: 7 additions & 7 deletions test/src/test_client_rpc_17.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ struct ClientRPCIoContextTest : ClientRPCRequestResponseTest<RPC>, test::IoConte
std::function<void(test::TypeIdentityT<SRPC>&, const asio::yield_context&)> server_func,
std::function<void(const asio::yield_context&)> client_func)
{
test::typed_spawn(io_context,
[this, client_func, g = this->get_work_tracking_executor()](const asio::yield_context& yield)
{
client_func(yield);
this->server_shutdown.initiate();
});
test::spawn(io_context,
[this, client_func, g = this->get_work_tracking_executor()](const asio::yield_context& yield)
{
client_func(yield);
this->server_shutdown.initiate();
});
agrpc::register_yield_rpc_handler<SRPC>(this->grpc_context, this->service, server_func,
test::RethrowFirstArg{});
this->run_io_context_detached(false);
Expand Down Expand Up @@ -204,7 +204,7 @@ TEST_CASE_FIXTURE(ClientRPCRequestResponseTest<test::UnaryClientRPC>,
using RPC = agrpc::UseSender::as_default_on_t<agrpc::ClientRPC<&test::v1::Test::Stub::PrepareAsyncUnary>>;
bool ok{};
test::DeleteGuard guard{};
register_perform_requests_no_shutdown(
register_and_perform_requests_no_shutdown(
[&](auto& rpc, auto& request, const asio::yield_context& yield)
{
CHECK_EQ(42, request.integer());
Expand Down
57 changes: 56 additions & 1 deletion test/src/test_server_rpc_17.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

#include <agrpc/client_rpc.hpp>
#include <agrpc/read.hpp>
#include <agrpc/register_yield_rpc_handler.hpp>
#include <agrpc/server_rpc.hpp>
#include <agrpc/waiter.hpp>

Expand Down Expand Up @@ -459,6 +458,62 @@ TEST_CASE_TEMPLATE("ServerRPC/ClientRPC generic streaming success", RPC, test::G
});
}

TEST_CASE("ServerRPC/ClientRPC bidi streaming on io_context success")
{
using RPC = test::NotifyWhenDoneBidirectionalStreamingServerRPC;
ServerRPCTest<RPC> test{true};
asio::io_context io_context{1};
const auto io_context_thread_id = std::this_thread::get_id();
std::thread::id final_thread_id{};
agrpc::register_yield_rpc_handler<RPC>(
test.get_executor(), test.service,
[&](RPC& rpc, const asio::yield_context& yield)
{
CHECK_EQ(io_context_thread_id, std::this_thread::get_id());
auto future = test.set_up_notify_when_done(rpc);
RPC::Request request;
CHECK(rpc.read(request, yield));
CHECK_EQ(1, request.integer());
CHECK_FALSE(rpc.read(request, yield));
RPC::Response response;
response.set_integer(11);
CHECK(rpc.write(response, grpc::WriteOptions{}, yield));
response.set_integer(12);
CHECK(rpc.write_and_finish(response, grpc::Status::OK, yield));
CHECK_EQ(io_context_thread_id, std::this_thread::get_id());
test.check_notify_when_done(future, rpc, yield);
},
asio::bind_executor(io_context,
[&](auto&& ep)
{
final_thread_id = std::this_thread::get_id();
test::RethrowFirstArg{}(ep);
}));
auto client_function = [&](RPC::Request& request, RPC::Response& response, const asio::yield_context& yield)
{
auto rpc = test.create_rpc();
test.start_rpc(rpc, request, response, yield);
request.set_integer(1);
CHECK(rpc.write(request, yield));
CHECK(rpc.writes_done(yield));
CHECK(rpc.read(response, yield));
CHECK_EQ(11, response.integer());
CHECK(rpc.read(response, yield));
CHECK_EQ(12, response.integer());
CHECK_FALSE(rpc.read(response, yield));
CHECK_EQ(12, response.integer());
CHECK_EQ(grpc::StatusCode::OK, rpc.finish(yield).error_code());
};
test.spawn_client_functions(io_context, client_function, client_function, client_function);
std::thread t{[&]
{
test.grpc_context.run_completion_queue();
}};
io_context.run();
t.join();
CHECK_EQ(final_thread_id, std::this_thread::get_id());
}

TEST_CASE("ServerRPC::service_name/method_name")
{
const auto check_eq_and_null_terminated = [](std::string_view expected, std::string_view actual)
Expand Down
5 changes: 5 additions & 0 deletions test/utils/utils/asio_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ void spawn(agrpc::GrpcContext& grpc_context, const std::function<void(const asio
test::typed_spawn(grpc_context, function);
}

void spawn(asio::io_context& io_context, const std::function<void(const asio::yield_context&)>& function)
{
test::typed_spawn(io_context, function);
}

void post(agrpc::GrpcContext& grpc_context, const std::function<void()>& function)
{
asio::post(grpc_context, function);
Expand Down
11 changes: 9 additions & 2 deletions test/utils/utils/asio_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ struct FunctionAsReceiver
auto get_allocator() const noexcept { return allocator; }

#ifdef AGRPC_UNIFEX
friend auto tag_invoke(unifex::tag_t<unifex::get_allocator>,
const FunctionAsReceiver& receiver) noexcept -> Allocator
friend Allocator tag_invoke(unifex::tag_t<unifex::get_allocator>, const FunctionAsReceiver& receiver) noexcept
{
return receiver.allocator;
}
Expand Down Expand Up @@ -207,6 +206,8 @@ void wait(agrpc::Alarm& alarm, std::chrono::system_clock::time_point deadline,

void spawn(agrpc::GrpcContext& grpc_context, const std::function<void(const asio::yield_context&)>& function);

void spawn(asio::io_context& io_context, const std::function<void(const asio::yield_context&)>& function);

template <class Executor, class Function>
void typed_spawn(Executor&& executor, Function&& function)
{
Expand All @@ -217,6 +218,12 @@ void typed_spawn(Executor&& executor, Function&& function)
#endif
}

template <class Executor, class... Functions>
void spawn_many(Executor&& executor, Functions&&... functions)
{
(test::spawn(executor, std::forward<Functions>(functions)), ...);
}

template <class... Functions>
void spawn_and_run(agrpc::GrpcContext& grpc_context, Functions&&... functions)
{
Expand Down
75 changes: 35 additions & 40 deletions test/utils/utils/client_rpc_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,73 +51,68 @@ struct ClientServerRPCTest : std::conditional_t<(agrpc::ClientRPCType::GENERIC_U
using Response = typename ClientRPC::Response;

#if defined(AGRPC_STANDALONE_ASIO) || defined(AGRPC_BOOST_ASIO)
template <class RPCHandler, class... ClientFunctions>
void register_perform_requests_no_shutdown(RPCHandler&& handler, ClientFunctions&&... client_functions)
template <class Executor, class... ClientFunctions>
void spawn_client_functions(Executor&& executor, ClientFunctions&&... client_functions)
{
agrpc::register_yield_rpc_handler<ServerRPC>(this->get_executor(), this->service, handler,
test::RethrowFirstArg{});
test::spawn_and_run(this->grpc_context,
[&client_functions](const asio::yield_context& yield)
{
typename ClientRPC::Request request;
typename ClientRPC::Response response;
client_functions(request, response, yield);
}...);
}

template <class RPCHandler, class... ClientFunctions>
void register_callback_and_perform_requests(RPCHandler&& handler, ClientFunctions&&... client_functions)
{
int counter{};
agrpc::register_callback_rpc_handler<ServerRPC>(this->get_executor(), this->service, handler,
test::RethrowFirstArg{});
test::spawn_and_run(
this->grpc_context,
[&counter, &client_functions, &server_shutdown = this->server_shutdown](const asio::yield_context& yield)
auto counter = std::make_shared<int>();
test::spawn_many(
executor,
[counter, &client_functions, &server_shutdown = this->server_shutdown](const asio::yield_context& yield)
{
typename ClientRPC::Request request;
typename ClientRPC::Response response;
client_functions(request, response, yield);
++counter;
if (counter == sizeof...(client_functions))
++(*counter);
if (*counter == sizeof...(client_functions))
{
server_shutdown.initiate();
}
}...);
}

template <class RPCHandler, class... ClientFunctions>
void register_callback_and_perform_requests(RPCHandler&& handler, ClientFunctions&&... client_functions)
{
agrpc::register_callback_rpc_handler<ServerRPC>(this->get_executor(), this->service, handler,
test::RethrowFirstArg{});
spawn_client_functions(this->grpc_context, static_cast<ClientFunctions&&>(client_functions)...);
this->grpc_context.run();
}

template <class RPCHandler, class ClientFunction>
void register_callback_and_perform_three_requests(RPCHandler&& handler, ClientFunction&& client_function)
{
register_callback_and_perform_requests(std::forward<RPCHandler>(handler), client_function, client_function,
register_callback_and_perform_requests(static_cast<RPCHandler&&>(handler), client_function, client_function,
client_function);
}

template <class RPCHandler, class... ClientFunctions>
void register_and_perform_requests_no_shutdown(RPCHandler&& handler, ClientFunctions&&... client_functions)
{
agrpc::register_yield_rpc_handler<ServerRPC>(this->get_executor(), this->service, handler,
test::RethrowFirstArg{});
test::spawn_and_run(this->grpc_context,
[&client_functions](const asio::yield_context& yield)
{
typename ClientRPC::Request request;
typename ClientRPC::Response response;
client_functions(request, response, yield);
}...);
}

template <class RPCHandler, class... ClientFunctions>
void register_and_perform_requests(RPCHandler&& handler, ClientFunctions&&... client_functions)
{
int counter{};
agrpc::register_yield_rpc_handler<ServerRPC>(this->get_executor(), this->service, handler,
test::RethrowFirstArg{});
test::spawn_and_run(
this->grpc_context,
[&counter, &client_functions, &server_shutdown = this->server_shutdown](const asio::yield_context& yield)
{
typename ClientRPC::Request request;
typename ClientRPC::Response response;
client_functions(request, response, yield);
++counter;
if (counter == sizeof...(client_functions))
{
server_shutdown.initiate();
}
}...);
spawn_client_functions(this->grpc_context, static_cast<ClientFunctions&&>(client_functions)...);
this->grpc_context.run();
}

template <class RPCHandler, class ClientFunction>
void register_and_perform_three_requests(RPCHandler&& handler, ClientFunction&& client_function)
{
register_and_perform_requests(std::forward<RPCHandler>(handler), client_function, client_function,
register_and_perform_requests(static_cast<RPCHandler&&>(handler), client_function, client_function,
client_function);
}

Expand Down

0 comments on commit 88cd6ec

Please sign in to comment.