diff --git a/src/agrpc/detail/register_rpc_handler_asio_base.hpp b/src/agrpc/detail/register_rpc_handler_asio_base.hpp index 24232ca8..7275e793 100644 --- a/src/agrpc/detail/register_rpc_handler_asio_base.hpp +++ b/src/agrpc/detail/register_rpc_handler_asio_base.hpp @@ -15,6 +15,7 @@ #ifndef AGRPC_DETAIL_REGISTER_RPC_HANDLER_ASIO_BASE_HPP #define AGRPC_DETAIL_REGISTER_RPC_HANDLER_ASIO_BASE_HPP +#include #include #include #include @@ -28,14 +29,14 @@ AGRPC_NAMESPACE_BEGIN() namespace detail { -inline constexpr auto REGISTER_RPC_HANDLER_COMPLETE = - static_cast(detail::to_underlying(detail::OperationResult::OK_) + 1); +void register_rpc_handler_asio_completion_trampoline(agrpc::GrpcContext& grpc_context, + detail::RegisterRPCHandlerOperationComplete& operation); template class RegisterRPCHandlerOperationAsioBase : public detail::RegisterRPCHandlerOperationBase>, - public detail::QueueableOperationBase, + public detail::RegisterRPCHandlerOperationComplete, private detail::WorkTracker> { public: @@ -44,6 +45,7 @@ class RegisterRPCHandlerOperationAsioBase private: using Base = detail::RegisterRPCHandlerOperationBase; + using CompletionBase = detail::RegisterRPCHandlerOperationComplete; using WorkTracker = detail::WorkTracker>; struct Decrementer @@ -52,7 +54,7 @@ class RegisterRPCHandlerOperationAsioBase { if (self_.decrement_ref_count()) { - self_.complete(REGISTER_RPC_HANDLER_COMPLETE, self_.grpc_context()); + detail::register_rpc_handler_asio_completion_trampoline(self_.grpc_context(), self_); } } @@ -69,9 +71,9 @@ class RegisterRPCHandlerOperationAsioBase template RegisterRPCHandlerOperationAsioBase(const ServerRPCExecutor& executor, Service& service, RPCHandler&& rpc_handler, - Ch&& completion_handler, detail::OperationOnComplete on_complete) + Ch&& completion_handler, CompletionBase::Complete on_complete) : Base(executor, service, static_cast(rpc_handler)), - detail::QueueableOperationBase(on_complete), + CompletionBase(on_complete), WorkTracker(asio::get_associated_executor(completion_handler)), completion_handler_(static_cast(completion_handler)) { @@ -105,26 +107,28 @@ struct RegisterRPCHandlerInitiator detail::ServerRPCServiceT& service_; }; +inline void register_rpc_handler_asio_completion_trampoline(agrpc::GrpcContext& grpc_context, + detail::RegisterRPCHandlerOperationComplete& operation) +{ + detail::ScopeGuard guard{[&operation] + { + operation.complete(); + }}; + agrpc::Alarm{grpc_context}.wait(detail::GrpcContextImplementation::TIME_ZERO, + [g = std::move(guard)](auto&&...) mutable + { + g.execute(); + }); + grpc_context.work_finished(); +} + template -static void register_rpc_handler_asio_do_complete(detail::OperationBase* operation, detail::OperationResult result, - agrpc::GrpcContext&) +void register_rpc_handler_asio_do_complete(detail::RegisterRPCHandlerOperationComplete& operation) noexcept { - auto& self = *static_cast(operation); + auto& self = static_cast(operation); detail::AllocationGuard guard{self, self.get_allocator()}; - if (REGISTER_RPC_HANDLER_COMPLETE == result) - { - if AGRPC_LIKELY (!detail::GrpcContextImplementation::is_shutdown(self.grpc_context())) - { - detail::GrpcContextImplementation::add_operation(self.grpc_context(), &self); - guard.release(); - } - return; - } - if AGRPC_LIKELY (!detail::is_shutdown(result)) - { - auto eptr{static_cast(self.error())}; - detail::dispatch_complete(guard, static_cast(eptr)); - } + auto eptr{static_cast(self.error())}; + detail::dispatch_complete(guard, static_cast(eptr)); } } diff --git a/src/agrpc/detail/utility.hpp b/src/agrpc/detail/utility.hpp index 4c21ce9d..dd43e8e2 100644 --- a/src/agrpc/detail/utility.hpp +++ b/src/agrpc/detail/utility.hpp @@ -211,6 +211,15 @@ class ScopeGuard constexpr const OnExit& get() const noexcept { return on_exit_; } + constexpr void execute() + { + if (is_armed_) + { + is_armed_ = false; + on_exit_(); + } + } + private: OnExit on_exit_; bool is_armed_{true}; diff --git a/src/agrpc/grpc_context.hpp b/src/agrpc/grpc_context.hpp index 474938f6..cf7a7d7a 100644 --- a/src/agrpc/grpc_context.hpp +++ b/src/agrpc/grpc_context.hpp @@ -181,7 +181,7 @@ class GrpcContext * @brief Run the `grpc::CompletionQueue` * * Runs the main event loop logic until the GrpcContext runs out of work or is stopped. Only events from the - * `grpc::CompletionQueue` will be handled. That means that completion handler that were e.g. created using + * `grpc::CompletionQueue` will be handled. Completion handlers that were, for example, created using * `asio::post(grpc_context, ...)` will not be processed. The GrpcContext will be brought into the ready state when * this function is invoked. Upon return, the GrpcContext will be in the stopped state. * @@ -210,9 +210,9 @@ class GrpcContext /** * @brief Poll the `grpc::CompletionQueue` * - * Processes only ready events of the `grpc::CompletionQueue`. That means that completion handler that were e.g. - * created using `asio::post(grpc_context, ...)` will not be processed. The GrpcContext will be brought into the - * ready state when this function is invoked. + * Processes only ready events of the `grpc::CompletionQueue`. Completion handlers that were, for example, created + * using `asio::post(grpc_context, ...)` will not be processed. The GrpcContext will be brought into the ready state + * when this function is invoked. * * @attention Only one thread may call run_completion_queue() or poll_completion_queue() at a time [unless this * context has been constructed with a `concurrency_hint` greater than one. Even then it may not be called diff --git a/test/src/test_client_rpc_17.cpp b/test/src/test_client_rpc_17.cpp index c7c4b2ab..29f7a6d1 100644 --- a/test/src/test_client_rpc_17.cpp +++ b/test/src/test_client_rpc_17.cpp @@ -71,12 +71,12 @@ struct ClientRPCIoContextTest : ClientRPCRequestResponseTest, test::IoConte std::function&, const asio::yield_context&)> server_func, std::function client_func) { - test::typed_spawn(io_context, - [this, client_func, g = this->get_work_tracking_executor()](const asio::yield_context& yield) - { - client_func(yield); - this->server_shutdown.initiate(); - }); + test::spawn(io_context, + [this, client_func, g = this->get_work_tracking_executor()](const asio::yield_context& yield) + { + client_func(yield); + this->server_shutdown.initiate(); + }); agrpc::register_yield_rpc_handler(this->grpc_context, this->service, server_func, test::RethrowFirstArg{}); this->run_io_context_detached(false); @@ -204,7 +204,7 @@ TEST_CASE_FIXTURE(ClientRPCRequestResponseTest, using RPC = agrpc::UseSender::as_default_on_t>; bool ok{}; test::DeleteGuard guard{}; - register_perform_requests_no_shutdown( + register_and_perform_requests_no_shutdown( [&](auto& rpc, auto& request, const asio::yield_context& yield) { CHECK_EQ(42, request.integer()); diff --git a/test/src/test_server_rpc_17.cpp b/test/src/test_server_rpc_17.cpp index 734efc9d..baab7433 100644 --- a/test/src/test_server_rpc_17.cpp +++ b/test/src/test_server_rpc_17.cpp @@ -25,7 +25,6 @@ #include #include -#include #include #include @@ -459,6 +458,62 @@ TEST_CASE_TEMPLATE("ServerRPC/ClientRPC generic streaming success", RPC, test::G }); } +TEST_CASE("ServerRPC/ClientRPC bidi streaming on io_context success") +{ + using RPC = test::NotifyWhenDoneBidirectionalStreamingServerRPC; + ServerRPCTest test{true}; + asio::io_context io_context{1}; + const auto io_context_thread_id = std::this_thread::get_id(); + std::thread::id final_thread_id{}; + agrpc::register_yield_rpc_handler( + test.get_executor(), test.service, + [&](RPC& rpc, const asio::yield_context& yield) + { + CHECK_EQ(io_context_thread_id, std::this_thread::get_id()); + auto future = test.set_up_notify_when_done(rpc); + RPC::Request request; + CHECK(rpc.read(request, yield)); + CHECK_EQ(1, request.integer()); + CHECK_FALSE(rpc.read(request, yield)); + RPC::Response response; + response.set_integer(11); + CHECK(rpc.write(response, grpc::WriteOptions{}, yield)); + response.set_integer(12); + CHECK(rpc.write_and_finish(response, grpc::Status::OK, yield)); + CHECK_EQ(io_context_thread_id, std::this_thread::get_id()); + test.check_notify_when_done(future, rpc, yield); + }, + asio::bind_executor(io_context, + [&](auto&& ep) + { + final_thread_id = std::this_thread::get_id(); + test::RethrowFirstArg{}(ep); + })); + auto client_function = [&](RPC::Request& request, RPC::Response& response, const asio::yield_context& yield) + { + auto rpc = test.create_rpc(); + test.start_rpc(rpc, request, response, yield); + request.set_integer(1); + CHECK(rpc.write(request, yield)); + CHECK(rpc.writes_done(yield)); + CHECK(rpc.read(response, yield)); + CHECK_EQ(11, response.integer()); + CHECK(rpc.read(response, yield)); + CHECK_EQ(12, response.integer()); + CHECK_FALSE(rpc.read(response, yield)); + CHECK_EQ(12, response.integer()); + CHECK_EQ(grpc::StatusCode::OK, rpc.finish(yield).error_code()); + }; + test.spawn_client_functions(io_context, client_function, client_function, client_function); + std::thread t{[&] + { + test.grpc_context.run_completion_queue(); + }}; + io_context.run(); + t.join(); + CHECK_EQ(final_thread_id, std::this_thread::get_id()); +} + TEST_CASE("ServerRPC::service_name/method_name") { const auto check_eq_and_null_terminated = [](std::string_view expected, std::string_view actual) diff --git a/test/utils/utils/asio_utils.cpp b/test/utils/utils/asio_utils.cpp index 0c594dda..a1eda7b5 100644 --- a/test/utils/utils/asio_utils.cpp +++ b/test/utils/utils/asio_utils.cpp @@ -28,6 +28,11 @@ void spawn(agrpc::GrpcContext& grpc_context, const std::function& function) +{ + test::typed_spawn(io_context, function); +} + void post(agrpc::GrpcContext& grpc_context, const std::function& function) { asio::post(grpc_context, function); diff --git a/test/utils/utils/asio_utils.hpp b/test/utils/utils/asio_utils.hpp index 10045119..509e585c 100644 --- a/test/utils/utils/asio_utils.hpp +++ b/test/utils/utils/asio_utils.hpp @@ -104,8 +104,7 @@ struct FunctionAsReceiver auto get_allocator() const noexcept { return allocator; } #ifdef AGRPC_UNIFEX - friend auto tag_invoke(unifex::tag_t, - const FunctionAsReceiver& receiver) noexcept -> Allocator + friend Allocator tag_invoke(unifex::tag_t, const FunctionAsReceiver& receiver) noexcept { return receiver.allocator; } @@ -207,6 +206,8 @@ void wait(agrpc::Alarm& alarm, std::chrono::system_clock::time_point deadline, void spawn(agrpc::GrpcContext& grpc_context, const std::function& function); +void spawn(asio::io_context& io_context, const std::function& function); + template void typed_spawn(Executor&& executor, Function&& function) { @@ -217,6 +218,12 @@ void typed_spawn(Executor&& executor, Function&& function) #endif } +template +void spawn_many(Executor&& executor, Functions&&... functions) +{ + (test::spawn(executor, std::forward(functions)), ...); +} + template void spawn_and_run(agrpc::GrpcContext& grpc_context, Functions&&... functions) { diff --git a/test/utils/utils/client_rpc_test.hpp b/test/utils/utils/client_rpc_test.hpp index 76d3e4c7..a457fb34 100644 --- a/test/utils/utils/client_rpc_test.hpp +++ b/test/utils/utils/client_rpc_test.hpp @@ -51,73 +51,68 @@ struct ClientServerRPCTest : std::conditional_t<(agrpc::ClientRPCType::GENERIC_U using Response = typename ClientRPC::Response; #if defined(AGRPC_STANDALONE_ASIO) || defined(AGRPC_BOOST_ASIO) - template - void register_perform_requests_no_shutdown(RPCHandler&& handler, ClientFunctions&&... client_functions) + template + void spawn_client_functions(Executor&& executor, ClientFunctions&&... client_functions) { - agrpc::register_yield_rpc_handler(this->get_executor(), this->service, handler, - test::RethrowFirstArg{}); - test::spawn_and_run(this->grpc_context, - [&client_functions](const asio::yield_context& yield) - { - typename ClientRPC::Request request; - typename ClientRPC::Response response; - client_functions(request, response, yield); - }...); - } - - template - void register_callback_and_perform_requests(RPCHandler&& handler, ClientFunctions&&... client_functions) - { - int counter{}; - agrpc::register_callback_rpc_handler(this->get_executor(), this->service, handler, - test::RethrowFirstArg{}); - test::spawn_and_run( - this->grpc_context, - [&counter, &client_functions, &server_shutdown = this->server_shutdown](const asio::yield_context& yield) + auto counter = std::make_shared(); + test::spawn_many( + executor, + [counter, &client_functions, &server_shutdown = this->server_shutdown](const asio::yield_context& yield) { typename ClientRPC::Request request; typename ClientRPC::Response response; client_functions(request, response, yield); - ++counter; - if (counter == sizeof...(client_functions)) + ++(*counter); + if (*counter == sizeof...(client_functions)) { server_shutdown.initiate(); } }...); } + template + void register_callback_and_perform_requests(RPCHandler&& handler, ClientFunctions&&... client_functions) + { + agrpc::register_callback_rpc_handler(this->get_executor(), this->service, handler, + test::RethrowFirstArg{}); + spawn_client_functions(this->grpc_context, static_cast(client_functions)...); + this->grpc_context.run(); + } + template void register_callback_and_perform_three_requests(RPCHandler&& handler, ClientFunction&& client_function) { - register_callback_and_perform_requests(std::forward(handler), client_function, client_function, + register_callback_and_perform_requests(static_cast(handler), client_function, client_function, client_function); } + template + void register_and_perform_requests_no_shutdown(RPCHandler&& handler, ClientFunctions&&... client_functions) + { + agrpc::register_yield_rpc_handler(this->get_executor(), this->service, handler, + test::RethrowFirstArg{}); + test::spawn_and_run(this->grpc_context, + [&client_functions](const asio::yield_context& yield) + { + typename ClientRPC::Request request; + typename ClientRPC::Response response; + client_functions(request, response, yield); + }...); + } + template void register_and_perform_requests(RPCHandler&& handler, ClientFunctions&&... client_functions) { - int counter{}; agrpc::register_yield_rpc_handler(this->get_executor(), this->service, handler, test::RethrowFirstArg{}); - test::spawn_and_run( - this->grpc_context, - [&counter, &client_functions, &server_shutdown = this->server_shutdown](const asio::yield_context& yield) - { - typename ClientRPC::Request request; - typename ClientRPC::Response response; - client_functions(request, response, yield); - ++counter; - if (counter == sizeof...(client_functions)) - { - server_shutdown.initiate(); - } - }...); + spawn_client_functions(this->grpc_context, static_cast(client_functions)...); + this->grpc_context.run(); } template void register_and_perform_three_requests(RPCHandler&& handler, ClientFunction&& client_function) { - register_and_perform_requests(std::forward(handler), client_function, client_function, + register_and_perform_requests(static_cast(handler), client_function, client_function, client_function); }