diff --git a/docs/examples/config/example-config.json b/docs/examples/config/example-config.json index 579696724..f9d66ce37 100644 --- a/docs/examples/config/example-config.json +++ b/docs/examples/config/example-config.json @@ -77,7 +77,8 @@ // send a reply for each request whenever it is ready. "parallel_requests_limit": 10, // Optional parameter, used only if "processing_strategy" is "parallel". It limits the number of requests for one client connection processed in parallel. Infinite if not specified. // Max number of responses to queue up before sent successfully. If a client's waiting queue is too long, the server will close the connection. - "ws_max_sending_queue_size": 1500 + "ws_max_sending_queue_size": 1500, + "__ng_web_server": false // Use ng web server. This is a temporary setting which will be deleted after switching to ng web server }, // Time in seconds for graceful shutdown. Defaults to 10 seconds. Not fully implemented yet. "graceful_period": 10.0, diff --git a/src/app/CliArgs.cpp b/src/app/CliArgs.cpp index 4780f2f69..b3a57a4d1 100644 --- a/src/app/CliArgs.cpp +++ b/src/app/CliArgs.cpp @@ -44,6 +44,7 @@ CliArgs::parse(int argc, char const* argv[]) ("help,h", "print help message and exit") ("version,v", "print version and exit") ("conf,c", po::value()->default_value(defaultConfigPath), "configuration file") + ("ng-web-server,w", "Use ng-web-server") ; // clang-format on po::positional_options_description positional; @@ -64,7 +65,8 @@ CliArgs::parse(int argc, char const* argv[]) } auto configPath = parsed["conf"].as(); - return Action{Action::Run{std::move(configPath)}}; + return Action{Action::Run{.configPath = std::move(configPath), .useNgWebServer = parsed.count("ng-web-server") != 0} + }; } } // namespace app diff --git a/src/app/CliArgs.hpp b/src/app/CliArgs.hpp index bc7dd738c..77dd8eb76 100644 --- a/src/app/CliArgs.hpp +++ b/src/app/CliArgs.hpp @@ -43,14 +43,13 @@ class CliArgs { public: /** @brief Run action. */ struct Run { - /** @brief Configuration file path. */ - std::string configPath; + std::string configPath; ///< Configuration file path. + bool useNgWebServer; ///< Whether to use a ng web server }; /** @brief Exit action. */ struct Exit { - /** @brief Exit code. */ - int exitCode; + int exitCode; ///< Exit code. }; /** diff --git a/src/app/ClioApplication.cpp b/src/app/ClioApplication.cpp index 95c4b33a0..a9573659b 100644 --- a/src/app/ClioApplication.cpp +++ b/src/app/ClioApplication.cpp @@ -26,25 +26,39 @@ #include "etl/NetworkValidatedLedgers.hpp" #include "feed/SubscriptionManager.hpp" #include "rpc/Counters.hpp" +#include "rpc/Errors.hpp" #include "rpc/RPCEngine.hpp" #include "rpc/WorkQueue.hpp" #include "rpc/common/impl/HandlerProvider.hpp" +#include "util/Assert.hpp" #include "util/build/Build.hpp" #include "util/config/Config.hpp" #include "util/log/Logger.hpp" +#include "util/prometheus/Http.hpp" #include "util/prometheus/Prometheus.hpp" +#include "web/AdminVerificationStrategy.hpp" #include "web/RPCServerHandler.hpp" #include "web/Server.hpp" +#include "web/SubscriptionContextInterface.hpp" #include "web/dosguard/DOSGuard.hpp" #include "web/dosguard/IntervalSweepHandler.hpp" #include "web/dosguard/WhitelistHandler.hpp" +#include "web/ng/Connection.hpp" +#include "web/ng/RPCServerHandler.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" +#include "web/ng/Server.hpp" #include +#include +#include #include #include +#include #include #include +#include #include namespace app { @@ -79,7 +93,7 @@ ClioApplication::ClioApplication(util::Config const& config) : config_(config), } int -ClioApplication::run() +ClioApplication::run(bool const useNgWebServer) { auto const threads = config_.valueOr("io_threads", 2); if (threads <= 0) { @@ -126,9 +140,91 @@ ClioApplication::run() auto const rpcEngine = RPCEngineType::make_RPCEngine(config_, backend, balancer, dosGuard, workQueue, counters, handlerProvider); + if (useNgWebServer or config_.valueOr("server.__ng_web_server", false)) { + web::ng::RPCServerHandler handler{config_, backend, rpcEngine, etl}; + + auto expectedAdminVerifier = web::make_AdminVerificationStrategy(config_); + if (not expectedAdminVerifier.has_value()) { + LOG(util::LogService::error()) << "Error creating admin verifier: " << expectedAdminVerifier.error(); + return EXIT_FAILURE; + } + auto const adminVerifier = std::move(expectedAdminVerifier).value(); + + auto httpServer = web::ng::make_Server(config_, ioc); + + if (not httpServer.has_value()) { + LOG(util::LogService::error()) << "Error creating web server: " << httpServer.error(); + return EXIT_FAILURE; + } + + httpServer->onGet( + "/metrics", + [adminVerifier]( + web::ng::Request const& request, + web::ng::ConnectionMetadata& connectionMetadata, + web::SubscriptionContextPtr, + boost::asio::yield_context + ) -> web::ng::Response { + auto const maybeHttpRequest = request.asHttpRequest(); + ASSERT(maybeHttpRequest.has_value(), "Got not a http request in Get"); + auto const& httpRequest = maybeHttpRequest->get(); + + // FIXME(#1702): Using veb server thread to handle prometheus request. Better to post on work queue. + auto maybeResponse = util::prometheus::handlePrometheusRequest( + httpRequest, adminVerifier->isAdmin(httpRequest, connectionMetadata.ip()) + ); + ASSERT(maybeResponse.has_value(), "Got unexpected request for Prometheus"); + return web::ng::Response{std::move(maybeResponse).value(), request}; + } + ); + + util::Logger webServerLog{"WebServer"}; + auto onRequest = [adminVerifier, &webServerLog, &handler]( + web::ng::Request const& request, + web::ng::ConnectionMetadata& connectionMetadata, + web::SubscriptionContextPtr subscriptionContext, + boost::asio::yield_context yield + ) -> web::ng::Response { + LOG(webServerLog.info()) << connectionMetadata.tag() + << "Received request from ip = " << connectionMetadata.ip() + << " - posting to WorkQueue"; + + connectionMetadata.setIsAdmin([&adminVerifier, &request, &connectionMetadata]() { + return adminVerifier->isAdmin(request.httpHeaders(), connectionMetadata.ip()); + }); + + try { + return handler(request, connectionMetadata, std::move(subscriptionContext), yield); + } catch (std::exception const&) { + return web::ng::Response{ + boost::beast::http::status::internal_server_error, + rpc::makeError(rpc::RippledError::rpcINTERNAL), + request + }; + } + }; + + httpServer->onPost("/", onRequest); + httpServer->onWs(onRequest); + + auto const maybeError = httpServer->run(); + if (maybeError.has_value()) { + LOG(util::LogService::error()) << "Error starting web server: " << *maybeError; + return EXIT_FAILURE; + } + + // Blocks until stopped. + // When stopped, shared_ptrs fall out of scope + // Calls destructors on all resources, and destructs in order + start(ioc, threads); + + return EXIT_SUCCESS; + } + // Init the web server auto handler = std::make_shared>(config_, backend, rpcEngine, etl); + auto const httpServer = web::make_HttpServer(config_, ioc, dosGuard, handler); // Blocks until stopped. diff --git a/src/app/ClioApplication.hpp b/src/app/ClioApplication.hpp index 6bb31ea7a..30fbaf8cc 100644 --- a/src/app/ClioApplication.hpp +++ b/src/app/ClioApplication.hpp @@ -42,10 +42,12 @@ class ClioApplication { /** * @brief Run the application * + * @param useNgWebServer Whether to use the new web server + * * @return exit code */ int - run(); + run(bool useNgWebServer); }; } // namespace app diff --git a/src/feed/Types.hpp b/src/feed/Types.hpp index c4399d09c..b00935559 100644 --- a/src/feed/Types.hpp +++ b/src/feed/Types.hpp @@ -19,12 +19,13 @@ #pragma once -#include "web/interface/ConnectionBase.hpp" +#include "web/SubscriptionContextInterface.hpp" #include namespace feed { -using Subscriber = web::ConnectionBase; + +using Subscriber = web::SubscriptionContextInterface; using SubscriberPtr = Subscriber*; using SubscriberSharedPtr = std::shared_ptr; diff --git a/src/feed/impl/ProposedTransactionFeed.cpp b/src/feed/impl/ProposedTransactionFeed.cpp index 8ceae9d2c..780d16e8b 100644 --- a/src/feed/impl/ProposedTransactionFeed.cpp +++ b/src/feed/impl/ProposedTransactionFeed.cpp @@ -48,7 +48,7 @@ ProposedTransactionFeed::sub(SubscriberSharedPtr const& subscriber) if (added) { LOG(logger_.info()) << subscriber->tag() << "Subscribed tx_proposed"; ++subAllCount_.get(); - subscriber->onDisconnect.connect([this](SubscriberPtr connection) { unsubInternal(connection); }); + subscriber->onDisconnect([this](SubscriberPtr connection) { unsubInternal(connection); }); } } @@ -73,9 +73,7 @@ ProposedTransactionFeed::sub(ripple::AccountID const& account, SubscriberSharedP if (added) { LOG(logger_.info()) << subscriber->tag() << "Subscribed accounts_proposed " << account; ++subAccountCount_.get(); - subscriber->onDisconnect.connect([this, account](SubscriberPtr connection) { - unsubInternal(account, connection); - }); + subscriber->onDisconnect([this, account](SubscriberPtr connection) { unsubInternal(account, connection); }); } } diff --git a/src/feed/impl/SingleFeedBase.cpp b/src/feed/impl/SingleFeedBase.cpp index f6dc085b5..4650cb267 100644 --- a/src/feed/impl/SingleFeedBase.cpp +++ b/src/feed/impl/SingleFeedBase.cpp @@ -49,7 +49,7 @@ SingleFeedBase::sub(SubscriberSharedPtr const& subscriber) if (added) { LOG(logger_.info()) << subscriber->tag() << "Subscribed " << name_; ++subCount_.get(); - subscriber->onDisconnect.connect([this](SubscriberPtr connectionDisconnecting) { + subscriber->onDisconnect([this](SubscriberPtr connectionDisconnecting) { unsubInternal(connectionDisconnecting); }); }; diff --git a/src/feed/impl/TrackableSignalMap.hpp b/src/feed/impl/TrackableSignalMap.hpp index 2e79ad305..38dd91be5 100644 --- a/src/feed/impl/TrackableSignalMap.hpp +++ b/src/feed/impl/TrackableSignalMap.hpp @@ -23,6 +23,7 @@ #include +#include #include #include #include diff --git a/src/feed/impl/TransactionFeed.cpp b/src/feed/impl/TransactionFeed.cpp index af3de6ae6..7ff0ba852 100644 --- a/src/feed/impl/TransactionFeed.cpp +++ b/src/feed/impl/TransactionFeed.cpp @@ -53,14 +53,14 @@ namespace feed::impl { void TransactionFeed::TransactionSlot::operator()(AllVersionTransactionsType const& allVersionMsgs) const { - if (auto connection = connectionWeakPtr.lock(); connection) { + if (auto connection = subscriptionContextWeakPtr.lock(); connection) { // Check if this connection already sent if (feed.get().notified_.contains(connection.get())) return; feed.get().notified_.insert(connection.get()); - if (connection->apiSubVersion < 2u) { + if (connection->apiSubversion() < 2u) { connection->send(allVersionMsgs[0]); return; } @@ -75,7 +75,7 @@ TransactionFeed::sub(SubscriberSharedPtr const& subscriber) if (added) { LOG(logger_.info()) << subscriber->tag() << "Subscribed transactions"; ++subAllCount_.get(); - subscriber->onDisconnect.connect([this](SubscriberPtr connection) { unsubInternal(connection); }); + subscriber->onDisconnect([this](SubscriberPtr connection) { unsubInternal(connection); }); } } @@ -86,18 +86,16 @@ TransactionFeed::sub(ripple::AccountID const& account, SubscriberSharedPtr const if (added) { LOG(logger_.info()) << subscriber->tag() << "Subscribed account " << account; ++subAccountCount_.get(); - subscriber->onDisconnect.connect([this, account](SubscriberPtr connection) { - unsubInternal(account, connection); - }); + subscriber->onDisconnect([this, account](SubscriberPtr connection) { unsubInternal(account, connection); }); } } void TransactionFeed::subProposed(SubscriberSharedPtr const& subscriber) { - auto const added = txProposedsignal_.connectTrackableSlot(subscriber, TransactionSlot(*this, subscriber)); + auto const added = txProposedSignal_.connectTrackableSlot(subscriber, TransactionSlot(*this, subscriber)); if (added) { - subscriber->onDisconnect.connect([this](SubscriberPtr connection) { unsubProposedInternal(connection); }); + subscriber->onDisconnect([this](SubscriberPtr connection) { unsubProposedInternal(connection); }); } } @@ -107,7 +105,7 @@ TransactionFeed::subProposed(ripple::AccountID const& account, SubscriberSharedP auto const added = accountProposedSignal_.connectTrackableSlot(subscriber, account, TransactionSlot(*this, subscriber)); if (added) { - subscriber->onDisconnect.connect([this, account](SubscriberPtr connection) { + subscriber->onDisconnect([this, account](SubscriberPtr connection) { unsubProposedInternal(account, connection); }); } @@ -120,7 +118,7 @@ TransactionFeed::sub(ripple::Book const& book, SubscriberSharedPtr const& subscr if (added) { LOG(logger_.info()) << subscriber->tag() << "Subscribed book " << book; ++subBookCount_.get(); - subscriber->onDisconnect.connect([this, book](SubscriberPtr connection) { unsubInternal(book, connection); }); + subscriber->onDisconnect([this, book](SubscriberPtr connection) { unsubInternal(book, connection); }); } } @@ -285,7 +283,7 @@ TransactionFeed::pub( // clear the notified set. If the same connection subscribes both transactions + proposed_transactions, // rippled SENDS the same message twice notified_.clear(); - txProposedsignal_.emit(allVersionsMsgs); + txProposedSignal_.emit(allVersionsMsgs); notified_.clear(); // check duplicate for account and proposed_account, this prevents sending the same message multiple times // if it affects multiple accounts watched by the same connection @@ -323,7 +321,7 @@ TransactionFeed::unsubInternal(ripple::AccountID const& account, SubscriberPtr s void TransactionFeed::unsubProposedInternal(SubscriberPtr subscriber) { - txProposedsignal_.disconnect(subscriber); + txProposedSignal_.disconnect(subscriber); } void diff --git a/src/feed/impl/TransactionFeed.hpp b/src/feed/impl/TransactionFeed.hpp index e57cc12bf..787fd7614 100644 --- a/src/feed/impl/TransactionFeed.hpp +++ b/src/feed/impl/TransactionFeed.hpp @@ -52,10 +52,10 @@ class TransactionFeed { struct TransactionSlot { std::reference_wrapper feed; - std::weak_ptr connectionWeakPtr; + std::weak_ptr subscriptionContextWeakPtr; TransactionSlot(TransactionFeed& feed, SubscriberSharedPtr const& connection) - : feed(feed), connectionWeakPtr(connection) + : feed(feed), subscriptionContextWeakPtr(connection) { } @@ -76,7 +76,7 @@ class TransactionFeed { // Signals for proposed tx subscribers TrackableSignalMap accountProposedSignal_; - TrackableSignal txProposedsignal_; + TrackableSignal txProposedSignal_; std::unordered_set notified_; // Used by slots to prevent double notifications if tx contains multiple subscribed accounts diff --git a/src/main/Main.cpp b/src/main/Main.cpp index f637402df..ffd4dbd96 100644 --- a/src/main/Main.cpp +++ b/src/main/Main.cpp @@ -44,7 +44,8 @@ try { } util::LogService::init(config); app::ClioApplication clio{config}; - return clio.run(); + + return clio.run(run.useNgWebServer); } ); } catch (std::exception const& e) { diff --git a/src/rpc/Factories.cpp b/src/rpc/Factories.cpp index adbb30e51..c90c0b48d 100644 --- a/src/rpc/Factories.cpp +++ b/src/rpc/Factories.cpp @@ -25,6 +25,7 @@ #include "rpc/common/Types.hpp" #include "util/Taggable.hpp" #include "web/Context.hpp" +#include "web/SubscriptionContextInterface.hpp" #include #include @@ -34,8 +35,8 @@ #include #include -#include #include +#include using namespace std; using namespace util; @@ -46,11 +47,12 @@ std::expected make_WsContext( boost::asio::yield_context yc, boost::json::object const& request, - std::shared_ptr const& session, + web::SubscriptionContextPtr session, util::TagDecoratorFactory const& tagFactory, data::LedgerRange const& range, std::string const& clientIp, - std::reference_wrapper apiVersionParser + std::reference_wrapper apiVersionParser, + bool isAdmin ) { boost::json::value commandValue = nullptr; @@ -68,7 +70,7 @@ make_WsContext( return Error{{ClioError::rpcINVALID_API_VERSION, apiVersion.error()}}; auto const command = boost::json::value_to(commandValue); - return web::Context(yc, command, *apiVersion, request, session, tagFactory, range, clientIp, session->isAdmin()); + return web::Context(yc, command, *apiVersion, request, std::move(session), tagFactory, range, clientIp, isAdmin); } std::expected @@ -94,7 +96,7 @@ make_HttpContext( auto const command = boost::json::value_to(request.at("method")); if (command == "subscribe" || command == "unsubscribe") - return Error{{RippledError::rpcBAD_SYNTAX, "Subscribe and unsubscribe are only allowed or websocket."}}; + return Error{{RippledError::rpcBAD_SYNTAX, "Subscribe and unsubscribe are only allowed for websocket."}}; if (!request.at("params").is_array()) return Error{{ClioError::rpcPARAMS_UNPARSEABLE, "Missing params array."}}; diff --git a/src/rpc/Factories.hpp b/src/rpc/Factories.hpp index 88e4c2966..08e1b7261 100644 --- a/src/rpc/Factories.hpp +++ b/src/rpc/Factories.hpp @@ -24,7 +24,7 @@ #include "rpc/common/APIVersion.hpp" #include "util/Taggable.hpp" #include "web/Context.hpp" -#include "web/interface/ConnectionBase.hpp" +#include "web/SubscriptionContextInterface.hpp" #include #include @@ -32,7 +32,6 @@ #include #include -#include #include /* @@ -49,22 +48,24 @@ namespace rpc { * * @param yc The coroutine context * @param request The request as JSON object - * @param session The connection + * @param session The subscription context * @param tagFactory A factory that provides tags to track requests * @param range The ledger range that is available at request time * @param clientIp The IP address of the connected client * @param apiVersionParser A parser that is used to parse out the "api_version" field + * @param isAdmin Whether the request has admin privileges * @return A Websocket context or error Status */ std::expected make_WsContext( boost::asio::yield_context yc, boost::json::object const& request, - std::shared_ptr const& session, + web::SubscriptionContextPtr session, util::TagDecoratorFactory const& tagFactory, data::LedgerRange const& range, std::string const& clientIp, - std::reference_wrapper apiVersionParser + std::reference_wrapper apiVersionParser, + bool isAdmin ); /** diff --git a/src/rpc/common/Types.hpp b/src/rpc/common/Types.hpp index 0ae5ffad1..6b68daddb 100644 --- a/src/rpc/common/Types.hpp +++ b/src/rpc/common/Types.hpp @@ -20,6 +20,7 @@ #pragma once #include "rpc/Errors.hpp" +#include "web/SubscriptionContextInterface.hpp" #include #include @@ -32,7 +33,6 @@ #include #include -#include #include #include #include @@ -117,7 +117,7 @@ struct VoidOutput {}; */ struct Context { boost::asio::yield_context yield; - std::shared_ptr session = {}; // NOLINT(readability-redundant-member-init) + web::SubscriptionContextPtr session = {}; // NOLINT(readability-redundant-member-init) bool isAdmin = false; std::string clientIp = {}; // NOLINT(readability-redundant-member-init) uint32_t apiVersion = 0u; // invalid by default diff --git a/src/rpc/handlers/Subscribe.cpp b/src/rpc/handlers/Subscribe.cpp index 118484ac1..367168eeb 100644 --- a/src/rpc/handlers/Subscribe.cpp +++ b/src/rpc/handlers/Subscribe.cpp @@ -22,6 +22,7 @@ #include "data/BackendInterface.hpp" #include "data/Types.hpp" #include "feed/SubscriptionManagerInterface.hpp" +#include "feed/Types.hpp" #include "rpc/Errors.hpp" #include "rpc/JS.hpp" #include "rpc/RPCHelpers.hpp" @@ -114,7 +115,7 @@ SubscribeHandler::process(Input input, Context const& ctx) const auto output = Output{}; // Mimic rippled. No matter what the request is, the api version changes for the whole session - ctx.session->apiSubVersion = ctx.apiVersion; + ctx.session->setApiSubversion(ctx.apiVersion); if (input.streams) { auto const ledger = subscribeToStreams(ctx.yield, *(input.streams), ctx.session); @@ -138,7 +139,7 @@ boost::json::object SubscribeHandler::subscribeToStreams( boost::asio::yield_context yield, std::vector const& streams, - std::shared_ptr const& session + feed::SubscriberSharedPtr const& session ) const { auto response = boost::json::object{}; @@ -165,7 +166,7 @@ SubscribeHandler::subscribeToStreams( void SubscribeHandler::subscribeToAccountsProposed( std::vector const& accounts, - std::shared_ptr const& session + feed::SubscriberSharedPtr const& session ) const { for (auto const& account : accounts) { @@ -177,7 +178,7 @@ SubscribeHandler::subscribeToAccountsProposed( void SubscribeHandler::subscribeToAccounts( std::vector const& accounts, - std::shared_ptr const& session + feed::SubscriberSharedPtr const& session ) const { for (auto const& account : accounts) { @@ -189,7 +190,7 @@ SubscribeHandler::subscribeToAccounts( void SubscribeHandler::subscribeToBooks( std::vector const& books, - std::shared_ptr const& session, + feed::SubscriberSharedPtr const& session, boost::asio::yield_context yield, Output& output ) const diff --git a/src/rpc/handlers/Subscribe.hpp b/src/rpc/handlers/Subscribe.hpp index 2f0824776..89c295c13 100644 --- a/src/rpc/handlers/Subscribe.hpp +++ b/src/rpc/handlers/Subscribe.hpp @@ -21,6 +21,7 @@ #include "data/BackendInterface.hpp" #include "feed/SubscriptionManagerInterface.hpp" +#include "feed/Types.hpp" #include "rpc/common/Specs.hpp" #include "rpc/common/Types.hpp" @@ -128,23 +129,20 @@ class SubscribeHandler { subscribeToStreams( boost::asio::yield_context yield, std::vector const& streams, - std::shared_ptr const& session + feed::SubscriberSharedPtr const& session ) const; void - subscribeToAccounts(std::vector const& accounts, std::shared_ptr const& session) - const; + subscribeToAccounts(std::vector const& accounts, feed::SubscriberSharedPtr const& session) const; void - subscribeToAccountsProposed( - std::vector const& accounts, - std::shared_ptr const& session - ) const; + subscribeToAccountsProposed(std::vector const& accounts, feed::SubscriberSharedPtr const& session) + const; void subscribeToBooks( std::vector const& books, - std::shared_ptr const& session, + feed::SubscriberSharedPtr const& session, boost::asio::yield_context yield, Output& output ) const; diff --git a/src/rpc/handlers/Unsubscribe.cpp b/src/rpc/handlers/Unsubscribe.cpp index ff8e316b6..177b4d2e6 100644 --- a/src/rpc/handlers/Unsubscribe.cpp +++ b/src/rpc/handlers/Unsubscribe.cpp @@ -21,6 +21,7 @@ #include "data/BackendInterface.hpp" #include "feed/SubscriptionManagerInterface.hpp" +#include "feed/Types.hpp" #include "rpc/Errors.hpp" #include "rpc/JS.hpp" #include "rpc/RPCHelpers.hpp" @@ -106,10 +107,11 @@ UnsubscribeHandler::process(Input input, Context const& ctx) const return Output{}; } + void UnsubscribeHandler::unsubscribeFromStreams( std::vector const& streams, - std::shared_ptr const& session + feed::SubscriberSharedPtr const& session ) const { for (auto const& stream : streams) { @@ -130,21 +132,21 @@ UnsubscribeHandler::unsubscribeFromStreams( } } } + void -UnsubscribeHandler::unsubscribeFromAccounts( - std::vector accounts, - std::shared_ptr const& session -) const +UnsubscribeHandler::unsubscribeFromAccounts(std::vector accounts, feed::SubscriberSharedPtr const& session) + const { for (auto const& account : accounts) { auto const accountID = accountFromStringStrict(account); subscriptions_->unsubAccount(*accountID, session); } } + void UnsubscribeHandler::unsubscribeFromProposedAccounts( std::vector accountsProposed, - std::shared_ptr const& session + feed::SubscriberSharedPtr const& session ) const { for (auto const& account : accountsProposed) { @@ -153,10 +155,8 @@ UnsubscribeHandler::unsubscribeFromProposedAccounts( } } void -UnsubscribeHandler::unsubscribeFromBooks( - std::vector const& books, - std::shared_ptr const& session -) const +UnsubscribeHandler::unsubscribeFromBooks(std::vector const& books, feed::SubscriberSharedPtr const& session) + const { for (auto const& orderBook : books) { subscriptions_->unsubBook(orderBook.book, session); diff --git a/src/rpc/handlers/Unsubscribe.hpp b/src/rpc/handlers/Unsubscribe.hpp index ef20b43e9..1f89022b0 100644 --- a/src/rpc/handlers/Unsubscribe.hpp +++ b/src/rpc/handlers/Unsubscribe.hpp @@ -21,6 +21,7 @@ #include "data/BackendInterface.hpp" #include "feed/SubscriptionManagerInterface.hpp" +#include "feed/Types.hpp" #include "rpc/common/Specs.hpp" #include "rpc/common/Types.hpp" @@ -105,22 +106,17 @@ class UnsubscribeHandler { private: void - unsubscribeFromStreams(std::vector const& streams, std::shared_ptr const& session) - const; + unsubscribeFromStreams(std::vector const& streams, feed::SubscriberSharedPtr const& session) const; void - unsubscribeFromAccounts(std::vector accounts, std::shared_ptr const& session) - const; + unsubscribeFromAccounts(std::vector accounts, feed::SubscriberSharedPtr const& session) const; void - unsubscribeFromProposedAccounts( - std::vector accountsProposed, - std::shared_ptr const& session - ) const; + unsubscribeFromProposedAccounts(std::vector accountsProposed, feed::SubscriberSharedPtr const& session) + const; void - unsubscribeFromBooks(std::vector const& books, std::shared_ptr const& session) - const; + unsubscribeFromBooks(std::vector const& books, feed::SubscriberSharedPtr const& session) const; /** * @brief Convert a JSON object to an Input diff --git a/src/util/CoroutineGroup.cpp b/src/util/CoroutineGroup.cpp index d8aafd0a3..195df2f36 100644 --- a/src/util/CoroutineGroup.cpp +++ b/src/util/CoroutineGroup.cpp @@ -31,7 +31,7 @@ namespace util { -CoroutineGroup::CoroutineGroup(boost::asio::yield_context yield, std::optional maxChildren) +CoroutineGroup::CoroutineGroup(boost::asio::yield_context yield, std::optional maxChildren) : timer_{yield.get_executor(), boost::asio::steady_timer::duration::max()}, maxChildren_{maxChildren} { } @@ -41,28 +41,30 @@ CoroutineGroup::~CoroutineGroup() ASSERT(childrenCounter_ == 0, "CoroutineGroup is destroyed without waiting for child coroutines to finish"); } -bool -CoroutineGroup::canSpawn() const -{ - return not maxChildren_.has_value() or childrenCounter_ < *maxChildren_; -} - bool CoroutineGroup::spawn(boost::asio::yield_context yield, std::function fn) { - if (not canSpawn()) + if (isFull()) return false; ++childrenCounter_; boost::asio::spawn(yield, [this, fn = std::move(fn)](boost::asio::yield_context yield) { fn(yield); - --childrenCounter_; - if (childrenCounter_ == 0) - timer_.cancel(); + onCoroutineCompleted(); }); return true; } +std::optional> +CoroutineGroup::registerForeign() +{ + if (isFull()) + return std::nullopt; + + ++childrenCounter_; + return [this]() { onCoroutineCompleted(); }; +} + void CoroutineGroup::asyncWait(boost::asio::yield_context yield) { @@ -79,4 +81,20 @@ CoroutineGroup::size() const return childrenCounter_; } +bool +CoroutineGroup::isFull() const +{ + return maxChildren_.has_value() && childrenCounter_ >= *maxChildren_; +} + +void +CoroutineGroup::onCoroutineCompleted() +{ + ASSERT(childrenCounter_ != 0, "onCoroutineCompleted() called more times than the number of child coroutines"); + + --childrenCounter_; + if (childrenCounter_ == 0) + timer_.cancel(); +} + } // namespace util diff --git a/src/util/CoroutineGroup.hpp b/src/util/CoroutineGroup.hpp index 0654d1991..9e5b70b18 100644 --- a/src/util/CoroutineGroup.hpp +++ b/src/util/CoroutineGroup.hpp @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -31,11 +32,12 @@ namespace util { /** * @brief CoroutineGroup is a helper class to manage a group of coroutines. It allows to spawn multiple coroutines and * wait for all of them to finish. + * @note This class is safe to use from multiple threads. */ class CoroutineGroup { boost::asio::steady_timer timer_; - std::optional maxChildren_; - int childrenCounter_{0}; + std::optional maxChildren_; + std::atomic_size_t childrenCounter_{0}; public: /** @@ -45,7 +47,7 @@ class CoroutineGroup { * @param maxChildren The maximum number of coroutines that can be spawned at the same time. If not provided, there * is no limit */ - CoroutineGroup(boost::asio::yield_context yield, std::optional maxChildren = std::nullopt); + CoroutineGroup(boost::asio::yield_context yield, std::optional maxChildren = std::nullopt); /** * @brief Destroy the Coroutine Group object @@ -54,14 +56,6 @@ class CoroutineGroup { */ ~CoroutineGroup(); - /** - * @brief Check if a new coroutine can be spawned (i.e. there is space for a new coroutine in the group) - * - * @return true If a new coroutine can be spawned. false if the maximum number of coroutines has been reached - */ - bool - canSpawn() const; - /** * @brief Spawn a new coroutine in the group * @@ -74,6 +68,16 @@ class CoroutineGroup { bool spawn(boost::asio::yield_context yield, std::function fn); + /** + * @brief Register a foreign coroutine this group should wait for. + * @note A foreign coroutine is still counted as a child one, i.e. calling this method increases the size of the + * group. + * + * @return A callback to call on foreign coroutine completes or std::nullopt if the group is already full. + */ + std::optional> + registerForeign(); + /** * @brief Wait for all the coroutines in the group to finish * @@ -91,6 +95,18 @@ class CoroutineGroup { */ size_t size() const; + + /** + * @brief Check if the group is full + * + * @return true If the group is full false otherwise + */ + bool + isFull() const; + +private: + void + onCoroutineCompleted(); }; } // namespace util diff --git a/src/web/impl/AdminVerificationStrategy.cpp b/src/web/AdminVerificationStrategy.cpp similarity index 75% rename from src/web/impl/AdminVerificationStrategy.cpp rename to src/web/AdminVerificationStrategy.cpp index 4295279ad..520d66401 100644 --- a/src/web/impl/AdminVerificationStrategy.cpp +++ b/src/web/AdminVerificationStrategy.cpp @@ -17,7 +17,7 @@ */ //============================================================================== -#include "web/impl/AdminVerificationStrategy.hpp" +#include "web/AdminVerificationStrategy.hpp" #include "util/JsonUtils.hpp" #include "util/config/Config.hpp" @@ -33,10 +33,10 @@ #include #include -namespace web::impl { +namespace web { bool -IPAdminVerificationStrategy::isAdmin(RequestType const&, std::string_view ip) const +IPAdminVerificationStrategy::isAdmin(RequestHeader const&, std::string_view ip) const { return ip == "127.0.0.1"; } @@ -54,7 +54,7 @@ PasswordAdminVerificationStrategy::PasswordAdminVerificationStrategy(std::string } bool -PasswordAdminVerificationStrategy::isAdmin(RequestType const& request, std::string_view) const +PasswordAdminVerificationStrategy::isAdmin(RequestHeader const& request, std::string_view) const { auto it = request.find(boost::beast::http::field::authorization); if (it == request.end()) { @@ -81,19 +81,21 @@ make_AdminVerificationStrategy(std::optional password) } std::expected, std::string> -make_AdminVerificationStrategy(util::Config const& serverConfig) +make_AdminVerificationStrategy(util::Config const& config) { - auto adminPassword = serverConfig.maybeValue("admin_password"); - auto const localAdmin = serverConfig.maybeValue("local_admin"); - bool const localAdminEnabled = localAdmin && localAdmin.value(); - - if (localAdminEnabled == adminPassword.has_value()) { - if (adminPassword.has_value()) - return std::unexpected{"Admin config error, local_admin and admin_password can not be set together."}; - return std::unexpected{"Admin config error, either local_admin and admin_password must be specified."}; + auto adminPassword = config.maybeValue("server.admin_password"); + auto const localAdmin = config.maybeValue("server.local_admin"); + + if (adminPassword.has_value() and localAdmin.has_value() and *localAdmin) + return std::unexpected{"Admin config error: 'local_admin' and admin_password can not be set together."}; + + if (localAdmin.has_value() and !*localAdmin and !adminPassword.has_value()) { + return std::unexpected{ + "Admin config error: either 'local_admin' should be enabled or 'admin_password' must be specified." + }; } return make_AdminVerificationStrategy(std::move(adminPassword)); } -} // namespace web::impl +} // namespace web diff --git a/src/web/impl/AdminVerificationStrategy.hpp b/src/web/AdminVerificationStrategy.hpp similarity index 67% rename from src/web/impl/AdminVerificationStrategy.hpp rename to src/web/AdminVerificationStrategy.hpp index fd7726123..4bc17ca41 100644 --- a/src/web/impl/AdminVerificationStrategy.hpp +++ b/src/web/AdminVerificationStrategy.hpp @@ -31,11 +31,14 @@ #include #include -namespace web::impl { +namespace web { +/** + * @brief Interface for admin verification strategies. + */ class AdminVerificationStrategy { public: - using RequestType = boost::beast::http::request; + using RequestHeader = boost::beast::http::request::header_type; virtual ~AdminVerificationStrategy() = default; /** @@ -46,9 +49,12 @@ class AdminVerificationStrategy { * @return true if authorized; false otherwise */ virtual bool - isAdmin(RequestType const& request, std::string_view ip) const = 0; + isAdmin(RequestHeader const& request, std::string_view ip) const = 0; }; +/** + * @brief Admin verification strategy that checks the ip address of the client. + */ class IPAdminVerificationStrategy : public AdminVerificationStrategy { public: /** @@ -59,16 +65,27 @@ class IPAdminVerificationStrategy : public AdminVerificationStrategy { * @return true if authorized; false otherwise */ bool - isAdmin(RequestType const&, std::string_view ip) const override; + isAdmin(RequestHeader const&, std::string_view ip) const override; }; +/** + * @brief Admin verification strategy that checks the password from the request header. + */ class PasswordAdminVerificationStrategy : public AdminVerificationStrategy { private: std::string passwordSha256_; public: + /** + * @brief The prefix for the password in the request header. + */ static constexpr std::string_view passwordPrefix = "Password "; + /** + * @brief Construct a new PasswordAdminVerificationStrategy object + * + * @param password The password to check + */ PasswordAdminVerificationStrategy(std::string const& password); /** @@ -79,13 +96,26 @@ class PasswordAdminVerificationStrategy : public AdminVerificationStrategy { * @return true if the password from request matches admin password from config */ bool - isAdmin(RequestType const& request, std::string_view) const override; + isAdmin(RequestHeader const& request, std::string_view) const override; }; +/** + * @brief Factory function for creating an admin verification strategy. + * + * @param password The optional password to check. + * @return Admin verification strategy. If password is provided, it will be PasswordAdminVerificationStrategy. + * Otherwise, it will be IPAdminVerificationStrategy. + */ std::shared_ptr make_AdminVerificationStrategy(std::optional password); +/** + * @brief Factory function for creating an admin verification strategy from server config. + * + * @param serverConfig The clio config. + * @return Admin verification strategy according to the config or an error message. + */ std::expected, std::string> make_AdminVerificationStrategy(util::Config const& serverConfig); -} // namespace web::impl +} // namespace web diff --git a/src/web/CMakeLists.txt b/src/web/CMakeLists.txt index b5b0ae9f0..083eae90c 100644 --- a/src/web/CMakeLists.txt +++ b/src/web/CMakeLists.txt @@ -2,18 +2,21 @@ add_library(clio_web) target_sources( clio_web - PRIVATE Resolver.cpp + PRIVATE AdminVerificationStrategy.cpp dosguard/DOSGuard.cpp dosguard/IntervalSweepHandler.cpp dosguard/WhitelistHandler.cpp - impl/AdminVerificationStrategy.cpp ng/Connection.cpp + ng/impl/ErrorHandling.cpp ng/impl/ConnectionHandler.cpp ng/impl/ServerSslContext.cpp ng/impl/WsConnection.cpp - ng/Server.cpp ng/Request.cpp ng/Response.cpp + ng/Server.cpp + ng/SubscriptionContext.cpp + Resolver.cpp + SubscriptionContext.cpp ) target_link_libraries(clio_web PUBLIC clio_util) diff --git a/src/web/Context.hpp b/src/web/Context.hpp index 43514537d..8ca16b7c6 100644 --- a/src/web/Context.hpp +++ b/src/web/Context.hpp @@ -22,14 +22,13 @@ #include "data/Types.hpp" #include "util/Taggable.hpp" #include "util/log/Logger.hpp" -#include "web/interface/ConnectionBase.hpp" +#include "web/SubscriptionContextInterface.hpp" #include #include #include #include -#include #include #include @@ -43,7 +42,7 @@ struct Context : util::Taggable { std::string method; std::uint32_t apiVersion; boost::json::object params; - std::shared_ptr session; + SubscriptionContextPtr session; data::LedgerRange range; std::string clientIp; bool isAdmin; @@ -55,7 +54,7 @@ struct Context : util::Taggable { * @param command The method/command requested * @param apiVersion The api_version parsed from the request * @param params Request's parameters/data as a JSON object - * @param session The connection to the peer + * @param subscriptionContext The subscription context of the connection * @param tagFactory A factory that is used to generate tags to track requests and connections * @param range The ledger range that is available at the time of the request * @param clientIp IP of the peer @@ -66,7 +65,7 @@ struct Context : util::Taggable { std::string command, std::uint32_t apiVersion, boost::json::object params, - std::shared_ptr const& session, + SubscriptionContextPtr subscriptionContext, util::TagDecoratorFactory const& tagFactory, data::LedgerRange const& range, std::string clientIp, @@ -77,7 +76,7 @@ struct Context : util::Taggable { , method(std::move(command)) , apiVersion(apiVersion) , params(std::move(params)) - , session(session) + , session(std::move(subscriptionContext)) , range(range) , clientIp(std::move(clientIp)) , isAdmin(isAdmin) diff --git a/src/web/HttpSession.hpp b/src/web/HttpSession.hpp index fa4ab0573..a60979b23 100644 --- a/src/web/HttpSession.hpp +++ b/src/web/HttpSession.hpp @@ -20,9 +20,11 @@ #pragma once #include "util/Taggable.hpp" +#include "web/AdminVerificationStrategy.hpp" #include "web/PlainWsSession.hpp" #include "web/dosguard/DOSGuardInterface.hpp" #include "web/impl/HttpBase.hpp" +#include "web/interface/Concepts.hpp" #include "web/interface/ConnectionBase.hpp" #include @@ -71,7 +73,7 @@ class HttpSession : public impl::HttpBase, explicit HttpSession( tcp::socket&& socket, std::string const& ip, - std::shared_ptr const& adminVerification, + std::shared_ptr const& adminVerification, std::reference_wrapper tagFactory, std::reference_wrapper dosGuard, std::shared_ptr const& handler, diff --git a/src/web/RPCServerHandler.hpp b/src/web/RPCServerHandler.hpp index ed61fe932..46ee718ba 100644 --- a/src/web/RPCServerHandler.hpp +++ b/src/web/RPCServerHandler.hpp @@ -161,11 +161,12 @@ class RPCServerHandler { return rpc::make_WsContext( yield, request, - connection, + connection->makeSubscriptionContext(tagFactory_), tagFactory_.with(connection->tag()), *range, connection->clientIp, - std::cref(apiVersionParser_) + std::cref(apiVersionParser_), + connection->isAdmin() ); } return rpc::make_HttpContext( diff --git a/src/web/Server.hpp b/src/web/Server.hpp index 7a6a3deac..c01d7f3e9 100644 --- a/src/web/Server.hpp +++ b/src/web/Server.hpp @@ -21,6 +21,7 @@ #include "util/Taggable.hpp" #include "util/log/Logger.hpp" +#include "web/AdminVerificationStrategy.hpp" #include "web/HttpSession.hpp" #include "web/SslHttpSession.hpp" #include "web/dosguard/DOSGuardInterface.hpp" @@ -84,7 +85,7 @@ class Detector : public std::enable_shared_from_this const dosGuard_; std::shared_ptr const handler_; boost::beast::flat_buffer buffer_; - std::shared_ptr const adminVerification_; + std::shared_ptr const adminVerification_; std::uint32_t maxWsSendingQueueSize_; public: @@ -105,7 +106,7 @@ class Detector : public std::enable_shared_from_this tagFactory, std::reference_wrapper dosGuard, std::shared_ptr handler, - std::shared_ptr adminVerification, + std::shared_ptr adminVerification, std::uint32_t maxWsSendingQueueSize ) : stream_(std::move(socket)) @@ -216,7 +217,7 @@ class Server : public std::enable_shared_from_this dosGuard_; std::shared_ptr handler_; tcp::acceptor acceptor_; - std::shared_ptr adminVerification_; + std::shared_ptr adminVerification_; std::uint32_t maxWsSendingQueueSize_; public: @@ -229,7 +230,7 @@ class Server : public std::enable_shared_from_this handler, - std::optional adminPassword, + std::shared_ptr adminVerification, std::uint32_t maxWsSendingQueueSize ) : ioc_(std::ref(ioc)) @@ -248,7 +249,7 @@ class Server : public std::enable_shared_from_this("ip")); auto const port = serverConfig.value("port"); - auto adminPassword = serverConfig.maybeValue("admin_password"); - auto const localAdmin = serverConfig.maybeValue("local_admin"); - - // Throw config error when localAdmin is true and admin_password is also set - if (localAdmin && localAdmin.value() && adminPassword) { - LOG(log.error()) << "local_admin is true but admin_password is also set, please specify only one method " - "to authorize admin"; - throw std::logic_error("Admin config error, local_admin and admin_password can not be set together."); - } - // Throw config error when localAdmin is false but admin_password is not set - if (localAdmin && !localAdmin.value() && !adminPassword) { - LOG(log.error()) << "local_admin is false but admin_password is not set, please specify one method " - "to authorize admin"; - throw std::logic_error("Admin config error, one method must be specified to authorize admin."); + + auto expectedAdminVerification = make_AdminVerificationStrategy(config); + if (not expectedAdminVerification.has_value()) { + LOG(log.error()) << expectedAdminVerification.error(); + throw std::logic_error{expectedAdminVerification.error()}; } // If the transactions number is 200 per ledger, A client which subscribes everything will send 400+ feeds for @@ -382,7 +374,7 @@ make_HttpServer( util::TagDecoratorFactory(config), dosGuard, handler, - std::move(adminPassword), + std::move(expectedAdminVerification).value(), maxWsSendingQueueSize ); diff --git a/src/web/SslHttpSession.hpp b/src/web/SslHttpSession.hpp index 82814e69d..43ee85e03 100644 --- a/src/web/SslHttpSession.hpp +++ b/src/web/SslHttpSession.hpp @@ -20,6 +20,7 @@ #pragma once #include "util/Taggable.hpp" +#include "web/AdminVerificationStrategy.hpp" #include "web/SslWsSession.hpp" #include "web/dosguard/DOSGuardInterface.hpp" #include "web/impl/HttpBase.hpp" @@ -79,7 +80,7 @@ class SslHttpSession : public impl::HttpBase, explicit SslHttpSession( tcp::socket&& socket, std::string const& ip, - std::shared_ptr const& adminVerification, + std::shared_ptr const& adminVerification, boost::asio::ssl::context& ctx, std::reference_wrapper tagFactory, std::reference_wrapper dosGuard, diff --git a/src/web/SubscriptionContext.cpp b/src/web/SubscriptionContext.cpp new file mode 100644 index 000000000..d831d1a5f --- /dev/null +++ b/src/web/SubscriptionContext.cpp @@ -0,0 +1,71 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "web/SubscriptionContext.hpp" + +#include "util/Taggable.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/interface/ConnectionBase.hpp" + +#include +#include +#include +#include + +namespace web { + +SubscriptionContext::SubscriptionContext( + util::TagDecoratorFactory const& factory, + std::shared_ptr connection +) + : SubscriptionContextInterface{factory}, connection_{connection} +{ +} + +SubscriptionContext::~SubscriptionContext() +{ + onDisconnect_(this); +} + +void +SubscriptionContext::send(std::shared_ptr message) +{ + if (auto connection = connection_.lock(); connection != nullptr) + connection->send(std::move(message)); +} + +void +SubscriptionContext::onDisconnect(OnDisconnectSlot const& slot) +{ + onDisconnect_.connect(slot); +} + +void +SubscriptionContext::setApiSubversion(uint32_t value) +{ + apiSubVersion_ = value; +} + +uint32_t +SubscriptionContext::apiSubversion() const +{ + return apiSubVersion_; +} + +} // namespace web diff --git a/src/web/SubscriptionContext.hpp b/src/web/SubscriptionContext.hpp new file mode 100644 index 000000000..2c64fa43d --- /dev/null +++ b/src/web/SubscriptionContext.hpp @@ -0,0 +1,96 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "util/Taggable.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/interface/Concepts.hpp" +#include "web/interface/ConnectionBase.hpp" + +#include + +#include +#include +#include +#include + +namespace web { + +/** + * @brief A context of a WsBase connection for subscriptions. + */ +class SubscriptionContext : public SubscriptionContextInterface { + std::weak_ptr connection_; + boost::signals2::signal onDisconnect_; + /** + * @brief The API version of the web stream client. + * This is used to track the api version of this connection, which mainly is used by subscription. It is different + * from the api version in Context, which is only used for the current request. + */ + std::atomic_uint32_t apiSubVersion_ = 0; + +public: + /** + * @brief Construct a new Subscription Context object + * + * @param factory The tag decorator factory to use to init taggable. + * @param connection The connection for which the context is created. + */ + SubscriptionContext(util::TagDecoratorFactory const& factory, std::shared_ptr connection); + + /** + * @brief Destroy the Subscription Context object + */ + ~SubscriptionContext() override; + + /** + * @brief Send message to the client + * @note This method will not do anything if the related connection got disconnected. + * + * @param message The message to send. + */ + void + send(std::shared_ptr message) override; + + /** + * @brief Connect a slot to onDisconnect connection signal. + * + * @param slot The slot to connect. + */ + void + onDisconnect(OnDisconnectSlot const& slot) override; + + /** + * @brief Set the API subversion. + * @param value The value to set. + */ + void + setApiSubversion(uint32_t value) override; + + /** + * @brief Get the API subversion. + * + * @return The API subversion. + */ + uint32_t + apiSubversion() const override; +}; + +} // namespace web diff --git a/src/web/SubscriptionContextInterface.hpp b/src/web/SubscriptionContextInterface.hpp new file mode 100644 index 000000000..5ad1452dd --- /dev/null +++ b/src/web/SubscriptionContextInterface.hpp @@ -0,0 +1,88 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "util/Taggable.hpp" + +#include +#include + +#include +#include +#include +#include + +namespace web { + +/** + * @brief An interface to provide connection functionality for subscriptions. + * @note Since subscription is only allowed for websocket connection, this interface is used only for websocket + * connections. + */ +class SubscriptionContextInterface : public util::Taggable { +public: + /** + * @brief Reusing Taggable constructor + */ + using util::Taggable::Taggable; + + /** + * @brief Send message to the client + * + * @param message The message to send. + */ + virtual void + send(std::shared_ptr message) = 0; + + /** + * @brief Alias for on disconnect slot. + */ + using OnDisconnectSlot = std::function; + + /** + * @brief Connect a slot to onDisconnect connection signal. + * + * @param slot The slot to connect. + */ + virtual void + onDisconnect(OnDisconnectSlot const& slot) = 0; + + /** + * @brief Set the API subversion. + * @param value The value to set. + */ + virtual void + setApiSubversion(uint32_t value) = 0; + + /** + * @brief Get the API subversion. + * + * @return The API subversion. + */ + virtual uint32_t + apiSubversion() const = 0; +}; + +/** + * @brief An alias for shared pointer to a SubscriptionContextInterface. + */ +using SubscriptionContextPtr = std::shared_ptr; + +} // namespace web diff --git a/src/web/impl/ErrorHandling.hpp b/src/web/impl/ErrorHandling.hpp index b78876335..bc3fdf0b1 100644 --- a/src/web/impl/ErrorHandling.hpp +++ b/src/web/impl/ErrorHandling.hpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include diff --git a/src/web/impl/HttpBase.hpp b/src/web/impl/HttpBase.hpp index f986800a3..609eb7e2d 100644 --- a/src/web/impl/HttpBase.hpp +++ b/src/web/impl/HttpBase.hpp @@ -20,12 +20,14 @@ #pragma once #include "rpc/Errors.hpp" +#include "util/Assert.hpp" #include "util/Taggable.hpp" #include "util/build/Build.hpp" #include "util/log/Logger.hpp" #include "util/prometheus/Http.hpp" +#include "web/AdminVerificationStrategy.hpp" +#include "web/SubscriptionContextInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp" -#include "web/impl/AdminVerificationStrategy.hpp" #include "web/interface/Concepts.hpp" #include "web/interface/ConnectionBase.hpp" @@ -275,6 +277,13 @@ class HttpBase : public ConnectionBase { sender_(httpResponse(status, "application/json", std::move(msg))); } + SubscriptionContextPtr + makeSubscriptionContext(util::TagDecoratorFactory const&) override + { + ASSERT(false, "SubscriptionContext can't be created for a HTTP connection"); + std::unreachable(); + } + void onWrite(bool close, boost::beast::error_code ec, std::size_t bytes_transferred) { diff --git a/src/web/impl/WsBase.hpp b/src/web/impl/WsBase.hpp index 2d941bd7a..003ec029f 100644 --- a/src/web/impl/WsBase.hpp +++ b/src/web/impl/WsBase.hpp @@ -23,6 +23,8 @@ #include "rpc/common/Types.hpp" #include "util/Taggable.hpp" #include "util/log/Logger.hpp" +#include "web/SubscriptionContext.hpp" +#include "web/SubscriptionContextInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp" #include "web/interface/Concepts.hpp" #include "web/interface/ConnectionBase.hpp" @@ -79,6 +81,7 @@ class WsBase : public ConnectionBase, public std::enable_shared_from_this> messages_; std::shared_ptr const handler_; + SubscriptionContextPtr subscriptionContext_; std::uint32_t maxSendingQueueSize_; protected: @@ -184,6 +187,21 @@ class WsBase : public ConnectionBase, public std::enable_shared_from_this(factory, shared_from_this()); + } + return subscriptionContext_; + } + /** * @brief Send a message to the client * @param msg The message to send diff --git a/src/web/interface/Concepts.hpp b/src/web/interface/Concepts.hpp index b8b90602b..b80285c2a 100644 --- a/src/web/interface/Concepts.hpp +++ b/src/web/interface/Concepts.hpp @@ -19,8 +19,6 @@ #pragma once -#include "web/interface/ConnectionBase.hpp" - #include #include @@ -29,6 +27,8 @@ namespace web { +struct ConnectionBase; + /** * @brief Specifies the requirements a Webserver handler must fulfill. */ diff --git a/src/web/interface/ConnectionBase.hpp b/src/web/interface/ConnectionBase.hpp index 439e0c660..839b1d9b0 100644 --- a/src/web/interface/ConnectionBase.hpp +++ b/src/web/interface/ConnectionBase.hpp @@ -20,13 +20,13 @@ #pragma once #include "util/Taggable.hpp" +#include "web/SubscriptionContextInterface.hpp" #include #include #include #include -#include #include #include #include @@ -49,13 +49,6 @@ struct ConnectionBase : public util::Taggable { public: std::string const clientIp; bool upgraded = false; - boost::signals2::signal onDisconnect; - /** - * @brief The API version of the web stream client. - * This is used to track the api version of this connection, which mainly is used by subscription. It is different - * from the api version in Context, which is only used for the current request. - */ - std::uint32_t apiSubVersion = 0; /** * @brief Create a new connection base. @@ -68,11 +61,6 @@ struct ConnectionBase : public util::Taggable { { } - ~ConnectionBase() override - { - onDisconnect(this); - }; - /** * @brief Send the response to the client. * @@ -94,6 +82,15 @@ struct ConnectionBase : public util::Taggable { throw std::logic_error("web server can not send the shared payload"); } + /** + * @brief Get the subscription context for this connection. + * + * @param factory Tag TagDecoratorFactory to use to create the context. + * @return The subscription context for this connection. + */ + virtual SubscriptionContextPtr + makeSubscriptionContext(util::TagDecoratorFactory const& factory) = 0; + /** * @brief Indicates whether the connection had an error and is considered dead. * diff --git a/src/web/ng/Connection.cpp b/src/web/ng/Connection.cpp index 4bfb4bc22..e02fe64b8 100644 --- a/src/web/ng/Connection.cpp +++ b/src/web/ng/Connection.cpp @@ -23,34 +23,34 @@ #include -#include #include #include namespace web::ng { -Connection::Connection( - std::string ip, - boost::beast::flat_buffer buffer, - util::TagDecoratorFactory const& tagDecoratorFactory -) - : util::Taggable(tagDecoratorFactory), ip_{std::move(ip)}, buffer_{std::move(buffer)} +ConnectionMetadata::ConnectionMetadata(std::string ip, util::TagDecoratorFactory const& tagDecoratorFactory) + : util::Taggable(tagDecoratorFactory), ip_{std::move(ip)} { } -ConnectionContext -Connection::context() const +std::string const& +ConnectionMetadata::ip() const { - return ConnectionContext{*this}; + return ip_; } -std::string const& -Connection::ip() const +bool +ConnectionMetadata::isAdmin() const { - return ip_; + return isAdmin_.value_or(false); } -ConnectionContext::ConnectionContext(Connection const& connection) : connection_{connection} +Connection::Connection( + std::string ip, + boost::beast::flat_buffer buffer, + util::TagDecoratorFactory const& tagDecoratorFactory +) + : ConnectionMetadata{std::move(ip), tagDecoratorFactory}, buffer_{std::move(buffer)} { } diff --git a/src/web/ng/Connection.hpp b/src/web/ng/Connection.hpp index 45b20052e..edd57b1d8 100644 --- a/src/web/ng/Connection.hpp +++ b/src/web/ng/Connection.hpp @@ -28,9 +28,9 @@ #include #include +#include #include #include -#include #include #include #include @@ -38,16 +38,67 @@ namespace web::ng { /** - * @brief A forward declaration of ConnectionContext. + * @brief An interface for a connection metadata class. */ -class ConnectionContext; +class ConnectionMetadata : public util::Taggable { +protected: + std::string ip_; // client ip + std::optional isAdmin_; + +public: + /** + * @brief Construct a new ConnectionMetadata object. + * + * @param ip The client ip. + * @param tagDecoratorFactory The factory for creating tag decorators. + */ + ConnectionMetadata(std::string ip, util::TagDecoratorFactory const& tagDecoratorFactory); + + /** + * @brief Whether the connection was upgraded. Upgraded connections are websocket connections. + * + * @return true if the connection was upgraded. + */ + virtual bool + wasUpgraded() const = 0; + + /** + * @brief Get the ip of the client. + * + * @return The ip of the client. + */ + std::string const& + ip() const; + + /** + * @brief Get whether the client is an admin. + * + * @return true if the client is an admin. + */ + bool + isAdmin() const; + + /** + * @brief Set the isAdmin field. + * @note This function is lazy, it will update isAdmin only if it is not set yet. + * + * @tparam T The invocable type of the function to call to set the isAdmin. + * @param setter The function to call to set the isAdmin. + */ + template + void + setIsAdmin(T&& setter) + { + if (not isAdmin_.has_value()) + isAdmin_ = setter(); + } +}; /** - *@brief A class representing a connection to a client. + * @brief A class representing a connection to a client. */ -class Connection : public util::Taggable { +class Connection : public ConnectionMetadata { protected: - std::string ip_; // client ip boost::beast::flat_buffer buffer_; public: @@ -65,14 +116,6 @@ class Connection : public util::Taggable { */ Connection(std::string ip, boost::beast::flat_buffer buffer, util::TagDecoratorFactory const& tagDecoratorFactory); - /** - * @brief Whether the connection was upgraded. Upgraded connections are websocket connections. - * - * @return true if the connection was upgraded. - */ - virtual bool - wasUpgraded() const = 0; - /** * @brief Send a response to the client. * @@ -81,7 +124,6 @@ class Connection : public util::Taggable { * @param timeout The timeout for the operation. * @return An error if the operation failed or nullopt if it succeeded. */ - virtual std::optional send( Response response, @@ -107,22 +149,6 @@ class Connection : public util::Taggable { */ virtual void close(boost::asio::yield_context yield, std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT) = 0; - - /** - * @brief Get the connection context. - * - * @return The connection context. - */ - ConnectionContext - context() const; - - /** - * @brief Get the ip of the client. - * - * @return The ip of the client. - */ - std::string const& - ip() const; }; /** @@ -130,19 +156,4 @@ class Connection : public util::Taggable { */ using ConnectionPtr = std::unique_ptr; -/** - * @brief A class representing the context of a connection. - */ -class ConnectionContext { - std::reference_wrapper connection_; - -public: - /** - * @brief Construct a new ConnectionContext object. - * - * @param connection The connection. - */ - explicit ConnectionContext(Connection const& connection); -}; - } // namespace web::ng diff --git a/src/web/ng/MessageHandler.hpp b/src/web/ng/MessageHandler.hpp index f518238f8..6b3224c31 100644 --- a/src/web/ng/MessageHandler.hpp +++ b/src/web/ng/MessageHandler.hpp @@ -19,6 +19,7 @@ #pragma once +#include "web/SubscriptionContextInterface.hpp" #include "web/ng/Connection.hpp" #include "web/ng/Request.hpp" #include "web/ng/Response.hpp" @@ -32,6 +33,7 @@ namespace web::ng { /** * @brief Handler for messages. */ -using MessageHandler = std::function; +using MessageHandler = + std::function; } // namespace web::ng diff --git a/src/web/ng/ProcessingPolicy.hpp b/src/web/ng/ProcessingPolicy.hpp new file mode 100644 index 000000000..ad69a7444 --- /dev/null +++ b/src/web/ng/ProcessingPolicy.hpp @@ -0,0 +1,29 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +namespace web::ng { + +/** + * @brief Requests processing policy. + */ +enum class ProcessingPolicy { Sequential, Parallel }; + +} // namespace web::ng diff --git a/src/web/ng/RPCServerHandler.hpp b/src/web/ng/RPCServerHandler.hpp new file mode 100644 index 000000000..2774fa931 --- /dev/null +++ b/src/web/ng/RPCServerHandler.hpp @@ -0,0 +1,336 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2023, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "data/BackendInterface.hpp" +#include "rpc/Errors.hpp" +#include "rpc/Factories.hpp" +#include "rpc/JS.hpp" +#include "rpc/RPCHelpers.hpp" +#include "rpc/common/impl/APIVersionParser.hpp" +#include "util/Assert.hpp" +#include "util/CoroutineGroup.hpp" +#include "util/JsonUtils.hpp" +#include "util/Profiler.hpp" +#include "util/Taggable.hpp" +#include "util/config/Config.hpp" +#include "util/log/Logger.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/ng/Connection.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" +#include "web/ng/impl/ErrorHandling.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace web::ng { + +/** + * @brief The server handler for RPC requests called by web server. + * + * Note: see @ref web::SomeServerHandler concept + */ +template +class RPCServerHandler { + std::shared_ptr const backend_; + std::shared_ptr const rpcEngine_; + std::shared_ptr const etl_; + util::TagDecoratorFactory const tagFactory_; + rpc::impl::ProductionAPIVersionParser apiVersionParser_; // can be injected if needed + + util::Logger log_{"RPC"}; + util::Logger perfLog_{"Performance"}; + +public: + /** + * @brief Create a new server handler. + * + * @param config Clio config to use + * @param backend The backend to use + * @param rpcEngine The RPC engine to use + * @param etl The ETL to use + */ + RPCServerHandler( + util::Config const& config, + std::shared_ptr const& backend, + std::shared_ptr const& rpcEngine, + std::shared_ptr const& etl + ) + : backend_(backend) + , rpcEngine_(rpcEngine) + , etl_(etl) + , tagFactory_(config) + , apiVersionParser_(config.sectionOr("api_version", {})) + { + } + + /** + * @brief The callback when server receives a request. + * + * @param request The request + * @param connectionMetadata The connection metadata + * @param subscriptionContext The subscription context + * @param yield The yield context + * @return The response + */ + [[nodiscard]] Response + operator()( + Request const& request, + ConnectionMetadata const& connectionMetadata, + SubscriptionContextPtr subscriptionContext, + boost::asio::yield_context yield + ) + { + std::optional response; + util::CoroutineGroup coroutineGroup{yield, 1}; + auto const onTaskComplete = coroutineGroup.registerForeign(); + ASSERT(onTaskComplete.has_value(), "Coroutine group can't be full"); + + bool const postSuccessful = rpcEngine_->post( + [this, + &request, + &response, + &onTaskComplete = onTaskComplete.value(), + &connectionMetadata, + subscriptionContext = std::move(subscriptionContext)](boost::asio::yield_context yield) mutable { + try { + auto parsedRequest = boost::json::parse(request.message()).as_object(); + LOG(perfLog_.debug()) << connectionMetadata.tag() << "Adding to work queue"; + + if (not connectionMetadata.wasUpgraded() and shouldReplaceParams(parsedRequest)) + parsedRequest[JS(params)] = boost::json::array({boost::json::object{}}); + + response = handleRequest( + yield, request, std::move(parsedRequest), connectionMetadata, std::move(subscriptionContext) + ); + } catch (boost::system::system_error const& ex) { + // system_error thrown when json parsing failed + rpcEngine_->notifyBadSyntax(); + response = impl::ErrorHelper{request}.makeJsonParsingError(); + LOG(log_.warn()) << "Error parsing JSON: " << ex.what() << ". For request: " << request.message(); + } catch (std::invalid_argument const& ex) { + // thrown when json parses something that is not an object at top level + rpcEngine_->notifyBadSyntax(); + LOG(log_.warn()) << "Invalid argument error: " << ex.what() + << ". For request: " << request.message(); + response = impl::ErrorHelper{request}.makeJsonParsingError(); + } catch (std::exception const& ex) { + LOG(perfLog_.error()) << connectionMetadata.tag() << "Caught exception: " << ex.what(); + rpcEngine_->notifyInternalError(); + response = impl::ErrorHelper{request}.makeInternalError(); + } + + // notify the coroutine group that the foreign task is done + onTaskComplete(); + }, + connectionMetadata.ip() + ); + + if (not postSuccessful) { + // onTaskComplete must be called to notify coroutineGroup that the foreign task is done + onTaskComplete->operator()(); + rpcEngine_->notifyTooBusy(); + return impl::ErrorHelper{request}.makeTooBusyError(); + } + + // Put the coroutine to sleep until the foreign task is done + coroutineGroup.asyncWait(yield); + ASSERT(response.has_value(), "Woke up coroutine without setting response"); + return std::move(response).value(); + } + +private: + Response + handleRequest( + boost::asio::yield_context yield, + Request const& rawRequest, + boost::json::object&& request, + ConnectionMetadata const& connectionMetadata, + SubscriptionContextPtr subscriptionContext + ) + { + LOG(log_.info()) << connectionMetadata.tag() << (connectionMetadata.wasUpgraded() ? "ws" : "http") + << " received request from work queue: " << util::removeSecret(request) + << " ip = " << connectionMetadata.ip(); + + try { + auto const range = backend_->fetchLedgerRange(); + if (!range) { + // for error that happened before the handler, we don't attach any warnings + rpcEngine_->notifyNotReady(); + return impl::ErrorHelper{rawRequest, std::move(request)}.makeNotReadyError(); + } + + auto const context = [&] { + if (connectionMetadata.wasUpgraded()) { + ASSERT(subscriptionContext != nullptr, "Subscription context must exist for a WS connecton"); + return rpc::make_WsContext( + yield, + request, + std::move(subscriptionContext), + tagFactory_.with(connectionMetadata.tag()), + *range, + connectionMetadata.ip(), + std::cref(apiVersionParser_), + connectionMetadata.isAdmin() + ); + } + return rpc::make_HttpContext( + yield, + request, + tagFactory_.with(connectionMetadata.tag()), + *range, + connectionMetadata.ip(), + std::cref(apiVersionParser_), + connectionMetadata.isAdmin() + ); + }(); + + if (!context) { + auto const err = context.error(); + LOG(perfLog_.warn()) << connectionMetadata.tag() << "Could not create Web context: " << err; + LOG(log_.warn()) << connectionMetadata.tag() << "Could not create Web context: " << err; + + // we count all those as BadSyntax - as the WS path would. + // Although over HTTP these will yield a 400 status with a plain text response (for most). + rpcEngine_->notifyBadSyntax(); + return impl::ErrorHelper(rawRequest, std::move(request)).makeError(err); + } + + auto [result, timeDiff] = util::timed([&]() { return rpcEngine_->buildResponse(*context); }); + + auto us = std::chrono::duration(timeDiff); + rpc::logDuration(*context, us); + + boost::json::object response; + + if (auto const status = std::get_if(&result.response)) { + // note: error statuses are counted/notified in buildResponse itself + response = impl::ErrorHelper(rawRequest, request).composeError(*status); + auto const responseStr = boost::json::serialize(response); + + LOG(perfLog_.debug()) << context->tag() << "Encountered error: " << responseStr; + LOG(log_.debug()) << context->tag() << "Encountered error: " << responseStr; + } else { + // This can still technically be an error. Clio counts forwarded requests as successful. + rpcEngine_->notifyComplete(context->method, us); + + auto& json = std::get(result.response); + auto const isForwarded = + json.contains("forwarded") && json.at("forwarded").is_bool() && json.at("forwarded").as_bool(); + + if (isForwarded) + json.erase("forwarded"); + + // if the result is forwarded - just use it as is + // if forwarded request has error, for http, error should be in "result"; for ws, error should + // be at top + if (isForwarded && (json.contains(JS(result)) || connectionMetadata.wasUpgraded())) { + for (auto const& [k, v] : json) + response.insert_or_assign(k, v); + } else { + response[JS(result)] = json; + } + + if (isForwarded) + response["forwarded"] = true; + + // for ws there is an additional field "status" in the response, + // otherwise the "status" is in the "result" field + if (connectionMetadata.wasUpgraded()) { + auto const appendFieldIfExist = [&](auto const& field) { + if (request.contains(field) and not request.at(field).is_null()) + response[field] = request.at(field); + }; + + appendFieldIfExist(JS(id)); + appendFieldIfExist(JS(api_version)); + + if (!response.contains(JS(error))) + response[JS(status)] = JS(success); + + response[JS(type)] = JS(response); + } else { + if (response.contains(JS(result)) && !response[JS(result)].as_object().contains(JS(error))) + response[JS(result)].as_object()[JS(status)] = JS(success); + } + } + + boost::json::array warnings = std::move(result.warnings); + warnings.emplace_back(rpc::makeWarning(rpc::warnRPC_CLIO)); + + if (etl_->lastCloseAgeSeconds() >= 60) + warnings.emplace_back(rpc::makeWarning(rpc::warnRPC_OUTDATED)); + + response["warnings"] = warnings; + return Response{boost::beast::http::status::ok, response, rawRequest}; + } catch (std::exception const& ex) { + // note: while we are catching this in buildResponse too, this is here to make sure + // that any other code that may throw is outside of buildResponse is also worked around. + LOG(perfLog_.error()) << connectionMetadata.tag() << "Caught exception: " << ex.what(); + LOG(log_.error()) << connectionMetadata.tag() << "Caught exception: " << ex.what(); + + rpcEngine_->notifyInternalError(); + return impl::ErrorHelper(rawRequest, std::move(request)).makeInternalError(); + } + } + + bool + shouldReplaceParams(boost::json::object const& req) const + { + auto const hasParams = req.contains(JS(params)); + auto const paramsIsArray = hasParams and req.at(JS(params)).is_array(); + auto const paramsIsEmptyString = + hasParams and req.at(JS(params)).is_string() and req.at(JS(params)).as_string().empty(); + auto const paramsIsEmptyObject = + hasParams and req.at(JS(params)).is_object() and req.at(JS(params)).as_object().empty(); + auto const paramsIsNull = hasParams and req.at(JS(params)).is_null(); + auto const arrayIsEmpty = paramsIsArray and req.at(JS(params)).as_array().empty(); + auto const arrayIsNotEmpty = paramsIsArray and not req.at(JS(params)).as_array().empty(); + auto const firstArgIsNull = arrayIsNotEmpty and req.at(JS(params)).as_array().at(0).is_null(); + auto const firstArgIsEmptyString = arrayIsNotEmpty and req.at(JS(params)).as_array().at(0).is_string() and + req.at(JS(params)).as_array().at(0).as_string().empty(); + + // Note: all this compatibility dance is to match `rippled` as close as possible + return not hasParams or paramsIsEmptyString or paramsIsNull or paramsIsEmptyObject or arrayIsEmpty or + firstArgIsEmptyString or firstArgIsNull; + } +}; + +} // namespace web::ng diff --git a/src/web/ng/Request.cpp b/src/web/ng/Request.cpp index 1a60736ba..9bf1c9a41 100644 --- a/src/web/ng/Request.cpp +++ b/src/web/ng/Request.cpp @@ -19,6 +19,8 @@ #include "web/ng/Request.hpp" +#include "util/OverloadSet.hpp" + #include #include #include @@ -104,6 +106,18 @@ Request::target() const return httpRequest().target(); } +Request::HttpHeaders const& +Request::httpHeaders() const +{ + return std::visit( + util::OverloadSet{ + [](HttpRequest const& httpRequest) -> HttpHeaders const& { return httpRequest; }, + [](WsData const& wsData) -> HttpHeaders const& { return wsData.headers.get(); } + }, + data_ + ); +} + std::optional Request::headerValue(boost::beast::http::field headerName) const { diff --git a/src/web/ng/Request.hpp b/src/web/ng/Request.hpp index 060181566..55f4ff560 100644 --- a/src/web/ng/Request.hpp +++ b/src/web/ng/Request.hpp @@ -112,6 +112,14 @@ class Request { std::optional target() const; + /** + * @brief Get the headers of the request. + * + * @return The headers of the request. + */ + HttpHeaders const& + httpHeaders() const; + /** * @brief Get the value of a header. * diff --git a/src/web/ng/Response.cpp b/src/web/ng/Response.cpp index 70c915065..ac76f1ac8 100644 --- a/src/web/ng/Response.cpp +++ b/src/web/ng/Response.cpp @@ -20,6 +20,7 @@ #include "web/ng/Response.hpp" #include "util/Assert.hpp" +#include "util/OverloadSet.hpp" #include "util/build/Build.hpp" #include "web/ng/Request.hpp" @@ -34,83 +35,98 @@ #include #include -#include #include #include +#include namespace http = boost::beast::http; + namespace web::ng { namespace { -std::string_view -asString(Response::HttpData::ContentType type) +template +consteval bool +isString() { - switch (type) { - case Response::HttpData::ContentType::TextHtml: - return "text/html"; - case Response::HttpData::ContentType::ApplicationJson: - return "application/json"; - } - ASSERT(false, "Unknown content type"); - std::unreachable(); + return std::is_same_v; +} + +http::response +prepareResponse(http::response response, http::request const& request) +{ + response.set(http::field::server, fmt::format("clio-server-{}", util::build::getClioVersionString())); + response.keep_alive(request.keep_alive()); + response.prepare_payload(); + return response; } template -std::optional -makeHttpData(http::status status, Request const& request) +std::variant, std::string> +makeData(http::status status, MessageType message, Request const& request) { - if (request.isHttp()) { - auto const& httpRequest = request.asHttpRequest()->get(); - auto constexpr contentType = std::is_same_v, std::string> - ? Response::HttpData::ContentType::TextHtml - : Response::HttpData::ContentType::ApplicationJson; - return Response::HttpData{ - .status = status, - .contentType = contentType, - .keepAlive = httpRequest.keep_alive(), - .version = httpRequest.version() - }; + std::string body; + if constexpr (isString()) { + body = std::move(message); + } else { + body = boost::json::serialize(message); } - return std::nullopt; + + if (not request.isHttp()) + return body; + + auto const& httpRequest = request.asHttpRequest()->get(); + std::string const contentType = isString() ? "text/html" : "application/json"; + + http::response result{status, httpRequest.version(), std::move(body)}; + result.set(http::field::content_type, contentType); + return prepareResponse(std::move(result), httpRequest); } + } // namespace Response::Response(boost::beast::http::status status, std::string message, Request const& request) - : message_(std::move(message)), httpData_{makeHttpData(status, request)} + : data_{makeData(status, std::move(message), request)} { } Response::Response(boost::beast::http::status status, boost::json::object const& message, Request const& request) - : message_(boost::json::serialize(message)), httpData_{makeHttpData(status, request)} + : data_{makeData(status, message, request)} { } +Response::Response(boost::beast::http::response response, Request const& request) +{ + ASSERT(request.isHttp(), "Request must be HTTP to construct response from HTTP response"); + data_ = prepareResponse(std::move(response), request.asHttpRequest()->get()); +} + std::string const& Response::message() const { - return message_; + return std::visit( + util::OverloadSet{ + [](http::response const& response) -> std::string const& { return response.body(); }, + [](std::string const& message) -> std::string const& { return message; }, + }, + data_ + ); } http::response Response::intoHttpResponse() && { - ASSERT(httpData_.has_value(), "Response must have http data to be converted into http response"); - - http::response result{httpData_->status, httpData_->version}; - result.set(http::field::server, fmt::format("clio-server-{}", util::build::getClioVersionString())); - result.set(http::field::content_type, asString(httpData_->contentType)); - result.keep_alive(httpData_->keepAlive); - result.body() = std::move(message_); - result.prepare_payload(); - return result; + ASSERT(std::holds_alternative>(data_), "Response must contain HTTP data"); + + return std::move(std::get>(data_)); } boost::asio::const_buffer -Response::asConstBuffer() const& +Response::asWsResponse() const& { - ASSERT(not httpData_.has_value(), "Losing existing http data"); - return boost::asio::buffer(message_.data(), message_.size()); + ASSERT(std::holds_alternative(data_), "Response must contain WebSocket data"); + auto const& message = std::get(data_); + return boost::asio::buffer(message.data(), message.size()); } } // namespace web::ng diff --git a/src/web/ng/Response.hpp b/src/web/ng/Response.hpp index 33348e5a6..b9d197283 100644 --- a/src/web/ng/Response.hpp +++ b/src/web/ng/Response.hpp @@ -27,8 +27,9 @@ #include #include -#include #include +#include + namespace web::ng { /** @@ -36,30 +37,13 @@ namespace web::ng { */ class Response { public: - /** - * @brief The data for an HTTP response. - */ - struct HttpData { - /** - * @brief The content type of the response. - */ - enum class ContentType { ApplicationJson, TextHtml }; - - boost::beast::http::status status; ///< The HTTP status. - ContentType contentType; ///< The content type. - bool keepAlive; ///< Whether the connection should be kept alive. - unsigned int version; ///< The HTTP version. - }; - -private: - std::string message_; - std::optional httpData_; + std::variant, std::string> data_; public: /** * @brief Construct a Response from string. Content type will be text/html. * - * @param status The HTTP status. + * @param status The HTTP status. It will be ignored if request is WebSocket. * @param message The message to send. * @param request The request that triggered this response. Used to determine whether the response should contain * HTTP or WebSocket data. @@ -69,13 +53,21 @@ class Response { /** * @brief Construct a Response from JSON object. Content type will be application/json. * - * @param status The HTTP status. + * @param status The HTTP status. It will be ignored if request is WebSocket. * @param message The message to send. * @param request The request that triggered this response. Used to determine whether the response should contain * HTTP or WebSocket */ Response(boost::beast::http::status status, boost::json::object const& message, Request const& request); + /** + * @brief Construct a Response from HTTP response. + * + * @param response The HTTP response. + * @param request The request that triggered this response. It must be an HTTP request. + */ + Response(boost::beast::http::response response, Request const& request); + /** * @brief Get the message of the response. * @@ -100,7 +92,7 @@ class Response { * @return The message of the response as a const buffer. */ boost::asio::const_buffer - asConstBuffer() const&; + asWsResponse() const&; }; } // namespace web::ng diff --git a/src/web/ng/Server.cpp b/src/web/ng/Server.cpp index 02da3efb8..083e3b7b2 100644 --- a/src/web/ng/Server.cpp +++ b/src/web/ng/Server.cpp @@ -25,6 +25,7 @@ #include "util/log/Logger.hpp" #include "web/ng/Connection.hpp" #include "web/ng/MessageHandler.hpp" +#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/impl/HttpConnection.hpp" #include "web/ng/impl/ServerSslContext.hpp" @@ -41,6 +42,7 @@ #include #include #include +#include #include #include @@ -178,14 +180,16 @@ Server::Server( boost::asio::io_context& ctx, boost::asio::ip::tcp::endpoint endpoint, std::optional sslContext, - impl::ConnectionHandler connectionHandler, - util::TagDecoratorFactory tagDecoratorFactory + ProcessingPolicy processingPolicy, + std::optional parallelRequestLimit, + util::TagDecoratorFactory tagDecoratorFactory, + std::optional maxSubscriptionSendQueueSize ) : ctx_{ctx} , sslContext_{std::move(sslContext)} - , connectionHandler_{std::move(connectionHandler)} - , endpoint_{std::move(endpoint)} , tagDecoratorFactory_{tagDecoratorFactory} + , connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize} + , endpoint_{std::move(endpoint)} { } @@ -297,24 +301,28 @@ make_Server(util::Config const& config, boost::asio::io_context& context) if (not expectedSslContext) return std::unexpected{std::move(expectedSslContext).error()}; - impl::ConnectionHandler::ProcessingPolicy processingPolicy{impl::ConnectionHandler::ProcessingPolicy::Parallel}; + ProcessingPolicy processingPolicy{ProcessingPolicy::Parallel}; std::optional parallelRequestLimit; auto const processingStrategyStr = serverConfig.valueOr("processing_policy", "parallel"); if (processingStrategyStr == "sequent") { - processingPolicy = impl::ConnectionHandler::ProcessingPolicy::Sequential; + processingPolicy = ProcessingPolicy::Sequential; } else if (processingStrategyStr == "parallel") { parallelRequestLimit = serverConfig.maybeValue("parallel_requests_limit"); } else { return std::unexpected{fmt::format("Invalid 'server.processing_strategy': {}", processingStrategyStr)}; } + auto const maxSubscriptionSendQueueSize = serverConfig.maybeValue("ws_max_sending_queue_size"); + return Server{ context, std::move(endpoint).value(), std::move(expectedSslContext).value(), - impl::ConnectionHandler{processingPolicy, parallelRequestLimit}, - util::TagDecoratorFactory(config) + processingPolicy, + parallelRequestLimit, + util::TagDecoratorFactory(config), + maxSubscriptionSendQueueSize }; } diff --git a/src/web/ng/Server.hpp b/src/web/ng/Server.hpp index dd68ad0b9..674d9167c 100644 --- a/src/web/ng/Server.hpp +++ b/src/web/ng/Server.hpp @@ -22,8 +22,8 @@ #include "util/Taggable.hpp" #include "util/config/Config.hpp" #include "util/log/Logger.hpp" -#include "web/impl/AdminVerificationStrategy.hpp" #include "web/ng/MessageHandler.hpp" +#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/impl/ConnectionHandler.hpp" #include @@ -44,16 +44,15 @@ namespace web::ng { class Server { util::Logger log_{"WebServer"}; util::Logger perfLog_{"Performance"}; - std::reference_wrapper ctx_; + std::reference_wrapper ctx_; std::optional sslContext_; - impl::ConnectionHandler connectionHandler_; + util::TagDecoratorFactory tagDecoratorFactory_; + impl::ConnectionHandler connectionHandler_; boost::asio::ip::tcp::endpoint endpoint_; - util::TagDecoratorFactory tagDecoratorFactory_; - bool running_{false}; public: @@ -63,15 +62,20 @@ class Server { * @param ctx The boost::asio::io_context to use. * @param endpoint The endpoint to listen on. * @param sslContext The SSL context to use (optional). - * @param connectionHandler The connection handler. + * @param processingPolicy The requests processing policy (parallel or sequential). + * @param parallelRequestLimit The limit of requests for one connection that can be processed in parallel. Only used + * if processingPolicy is parallel. * @param tagDecoratorFactory The tag decorator factory. + * @param maxSubscriptionSendQueueSize The maximum size of the subscription send queue. */ Server( boost::asio::io_context& ctx, boost::asio::ip::tcp::endpoint endpoint, std::optional sslContext, - impl::ConnectionHandler connectionHandler, - util::TagDecoratorFactory tagDecoratorFactory + ProcessingPolicy processingPolicy, + std::optional parallelRequestLimit, + util::TagDecoratorFactory tagDecoratorFactory, + std::optional maxSubscriptionSendQueueSize ); /** diff --git a/src/web/ng/SubscriptionContext.cpp b/src/web/ng/SubscriptionContext.cpp new file mode 100644 index 000000000..8820334a4 --- /dev/null +++ b/src/web/ng/SubscriptionContext.cpp @@ -0,0 +1,101 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "web/ng/SubscriptionContext.hpp" + +#include "util/Taggable.hpp" +#include "web/SubscriptionContextInterface.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace web::ng { + +SubscriptionContext::SubscriptionContext( + util::TagDecoratorFactory const& factory, + impl::WsConnectionBase& connection, + std::optional maxSendQueueSize, + boost::asio::yield_context yield, + ErrorHandler errorHandler +) + : web::SubscriptionContextInterface(factory) + , connection_(connection) + , maxSendQueueSize_(maxSendQueueSize) + , tasksGroup_(yield) + , yield_(yield) + , errorHandler_(std::move(errorHandler)) +{ +} + +void +SubscriptionContext::send(std::shared_ptr message) +{ + if (disconnected_) + return; + + if (maxSendQueueSize_.has_value() and tasksGroup_.size() >= *maxSendQueueSize_) { + tasksGroup_.spawn(yield_, [this](boost::asio::yield_context innerYield) { + connection_.get().close(innerYield); + }); + disconnected_ = true; + return; + } + + tasksGroup_.spawn(yield_, [this, message = std::move(message)](boost::asio::yield_context innerYield) { + auto const maybeError = connection_.get().sendBuffer(boost::asio::buffer(*message), innerYield); + if (maybeError.has_value() and errorHandler_(*maybeError, connection_)) + connection_.get().close(innerYield); + }); +} + +void +SubscriptionContext::onDisconnect(OnDisconnectSlot const& slot) +{ + onDisconnect_.connect(slot); +} + +void +SubscriptionContext::setApiSubversion(uint32_t value) +{ + apiSubversion_ = value; +} + +uint32_t +SubscriptionContext::apiSubversion() const +{ + return apiSubversion_; +} + +void +SubscriptionContext::disconnect(boost::asio::yield_context yield) +{ + onDisconnect_(this); + disconnected_ = true; + tasksGroup_.asyncWait(yield); +} + +} // namespace web::ng diff --git a/src/web/ng/SubscriptionContext.hpp b/src/web/ng/SubscriptionContext.hpp new file mode 100644 index 000000000..9da22e7b1 --- /dev/null +++ b/src/web/ng/SubscriptionContext.hpp @@ -0,0 +1,132 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "util/CoroutineGroup.hpp" +#include "util/Taggable.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/ng/Connection.hpp" +#include "web/ng/Error.hpp" +#include "web/ng/impl/WsConnection.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace web::ng { + +/** + * @brief Implementation of SubscriptionContextInterface. + * @note This class is designed to be used with SubscriptionManager. The class is safe to use from multiple threads. + * The method disconnect() must be called before the object is destroyed. + */ +class SubscriptionContext : public web::SubscriptionContextInterface { +public: + /** + * @brief Error handler definition. Error handler returns true if connection should be closed false otherwise. + */ + using ErrorHandler = std::function; + +private: + std::reference_wrapper connection_; + std::optional maxSendQueueSize_; + util::CoroutineGroup tasksGroup_; + boost::asio::yield_context yield_; + ErrorHandler errorHandler_; + + boost::signals2::signal onDisconnect_; + std::atomic_bool disconnected_{false}; + + /** + * @brief The API version of the web stream client. + * This is used to track the api version of this connection, which mainly is used by subscription. It is different + * from the api version in Context, which is only used for the current request. + */ + std::atomic_uint32_t apiSubversion_ = 0u; + +public: + /** + * @brief Construct a new Subscription Context object + * + * @param factory The tag decorator factory to use to init taggable. + * @param connection The connection for which the context is created. + * @param maxSendQueueSize The maximum size of the send queue. If the queue is full, the connection will be closed. + * @param yield The yield context to spawn sending coroutines. + * @param errorHandler The error handler. + */ + SubscriptionContext( + util::TagDecoratorFactory const& factory, + impl::WsConnectionBase& connection, + std::optional maxSendQueueSize, + boost::asio::yield_context yield, + ErrorHandler errorHandler + ); + + /** + * @brief Send message to the client + * @note This method does nothing after disconnected() was called. + * + * @param message The message to send. + */ + void + send(std::shared_ptr message) override; + + /** + * @brief Connect a slot to onDisconnect connection signal. + * + * @param slot The slot to connect. + */ + void + onDisconnect(OnDisconnectSlot const& slot) override; + + /** + * @brief Set the API subversion. + * @param value The value to set. + */ + void + setApiSubversion(uint32_t value) override; + + /** + * @brief Get the API subversion. + * + * @return The API subversion. + */ + uint32_t + apiSubversion() const override; + + /** + * @brief Notify the context that related connection is disconnected and wait for all the task to complete. + * @note This method must be called before the object is destroyed. + * + * @param yield The yield context to wait for all the tasks to complete. + */ + void + disconnect(boost::asio::yield_context yield); +}; + +} // namespace web::ng diff --git a/src/web/ng/impl/ConnectionHandler.cpp b/src/web/ng/impl/ConnectionHandler.cpp index 4a92e6a4a..02e79b1a7 100644 --- a/src/web/ng/impl/ConnectionHandler.cpp +++ b/src/web/ng/impl/ConnectionHandler.cpp @@ -21,21 +21,30 @@ #include "util/Assert.hpp" #include "util/CoroutineGroup.hpp" +#include "util/Taggable.hpp" #include "util/log/Logger.hpp" +#include "web/SubscriptionContextInterface.hpp" #include "web/ng/Connection.hpp" #include "web/ng/Error.hpp" #include "web/ng/MessageHandler.hpp" +#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/Request.hpp" #include "web/ng/Response.hpp" +#include "web/ng/SubscriptionContext.hpp" +#include +#include #include #include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -47,7 +56,8 @@ namespace { Response handleHttpRequest( - ConnectionContext const& connectionContext, + ConnectionMetadata& connectionMetadata, + SubscriptionContextPtr& subscriptionContext, ConnectionHandler::TargetToHandlerMap const& handlers, Request const& request, boost::asio::yield_context yield @@ -58,12 +68,13 @@ handleHttpRequest( if (it == handlers.end()) { return Response{boost::beast::http::status::bad_request, "Bad target", request}; } - return it->second(request, connectionContext, yield); + return it->second(request, connectionMetadata, subscriptionContext, yield); } Response handleWsRequest( - ConnectionContext connectionContext, + ConnectionMetadata& connectionMetadata, + SubscriptionContextPtr& subscriptionContext, std::optional const& handler, Request const& request, boost::asio::yield_context yield @@ -72,7 +83,7 @@ handleWsRequest( if (not handler.has_value()) { return Response{boost::beast::http::status::bad_request, "WebSocket is not supported by this server", request}; } - return handler->operator()(request, connectionContext, yield); + return handler->operator()(request, connectionMetadata, subscriptionContext, yield); } } // namespace @@ -95,8 +106,16 @@ ConnectionHandler::StringHash::operator()(std::string const& str) const return hash_type{}(str); } -ConnectionHandler::ConnectionHandler(ProcessingPolicy processingPolicy, std::optional maxParallelRequests) - : processingPolicy_{processingPolicy}, maxParallelRequests_{maxParallelRequests} +ConnectionHandler::ConnectionHandler( + ProcessingPolicy processingPolicy, + std::optional maxParallelRequests, + util::TagDecoratorFactory& tagFactory, + std::optional maxSubscriptionSendQueueSize +) + : processingPolicy_{processingPolicy} + , maxParallelRequests_{maxParallelRequests} + , tagFactory_{tagFactory} + , maxSubscriptionSendQueueSize_{maxSubscriptionSendQueueSize} { } @@ -126,14 +145,32 @@ ConnectionHandler::processConnection(ConnectionPtr connectionPtr, boost::asio::y bool shouldCloseGracefully = false; + std::shared_ptr subscriptionContext; + if (connectionRef.wasUpgraded()) { + auto* ptr = dynamic_cast(connectionPtr.get()); + ASSERT(ptr != nullptr, "Casted not websocket connection"); + subscriptionContext = std::make_shared( + tagFactory_, + *ptr, + maxSubscriptionSendQueueSize_, + yield, + [this](Error const& e, Connection const& c) { return handleError(e, c); } + ); + } + SubscriptionContextPtr subscriptionContextInterfacePtr = subscriptionContext; + switch (processingPolicy_) { case ProcessingPolicy::Sequential: - shouldCloseGracefully = sequentRequestResponseLoop(connectionRef, yield); + shouldCloseGracefully = sequentRequestResponseLoop(connectionRef, subscriptionContextInterfacePtr, yield); break; case ProcessingPolicy::Parallel: - shouldCloseGracefully = parallelRequestResponseLoop(connectionRef, yield); + shouldCloseGracefully = parallelRequestResponseLoop(connectionRef, subscriptionContextInterfacePtr, yield); break; } + + if (subscriptionContext != nullptr) + subscriptionContext->disconnect(yield); + if (shouldCloseGracefully) connectionRef.close(yield); @@ -179,7 +216,11 @@ ConnectionHandler::handleError(Error const& error, Connection const& connection) } bool -ConnectionHandler::sequentRequestResponseLoop(Connection& connection, boost::asio::yield_context yield) +ConnectionHandler::sequentRequestResponseLoop( + Connection& connection, + SubscriptionContextPtr& subscriptionContext, + boost::asio::yield_context yield +) { // The loop here is infinite because: // - For websocket connection is persistent so Clio will try to read and respond infinite unless client @@ -196,14 +237,19 @@ ConnectionHandler::sequentRequestResponseLoop(Connection& connection, boost::asi LOG(log_.info()) << connection.tag() << "Received request from ip = " << connection.ip(); - auto maybeReturnValue = processRequest(connection, std::move(expectedRequest).value(), yield); + auto maybeReturnValue = + processRequest(connection, subscriptionContext, std::move(expectedRequest).value(), yield); if (maybeReturnValue.has_value()) return maybeReturnValue.value(); } } bool -ConnectionHandler::parallelRequestResponseLoop(Connection& connection, boost::asio::yield_context yield) +ConnectionHandler::parallelRequestResponseLoop( + Connection& connection, + SubscriptionContextPtr& subscriptionContext, + boost::asio::yield_context yield +) { // atomic_bool is not needed here because everything happening on coroutine's strand bool stop = false; @@ -218,13 +264,18 @@ ConnectionHandler::parallelRequestResponseLoop(Connection& connection, boost::as closeConnectionGracefully &= closeGracefully; break; } - if (tasksGroup.canSpawn()) { + + if (not tasksGroup.isFull()) { bool const spawnSuccess = tasksGroup.spawn( yield, // spawn on the same strand - [this, &stop, &closeConnectionGracefully, &connection, request = std::move(expectedRequest).value()]( - boost::asio::yield_context innerYield - ) mutable { - auto maybeCloseConnectionGracefully = processRequest(connection, request, innerYield); + [this, + &stop, + &closeConnectionGracefully, + &connection, + &subscriptionContext, + request = std::move(expectedRequest).value()](boost::asio::yield_context innerYield) mutable { + auto maybeCloseConnectionGracefully = + processRequest(connection, subscriptionContext, request, innerYield); if (maybeCloseConnectionGracefully.has_value()) { stop = true; closeConnectionGracefully &= maybeCloseConnectionGracefully.value(); @@ -248,9 +299,14 @@ ConnectionHandler::parallelRequestResponseLoop(Connection& connection, boost::as } std::optional -ConnectionHandler::processRequest(Connection& connection, Request const& request, boost::asio::yield_context yield) +ConnectionHandler::processRequest( + Connection& connection, + SubscriptionContextPtr& subscriptionContext, + Request const& request, + boost::asio::yield_context yield +) { - auto response = handleRequest(connection.context(), request, yield); + auto response = handleRequest(connection, subscriptionContext, request, yield); auto const maybeError = connection.send(std::move(response), yield); if (maybeError.has_value()) { @@ -261,18 +317,19 @@ ConnectionHandler::processRequest(Connection& connection, Request const& request Response ConnectionHandler::handleRequest( - ConnectionContext const& connectionContext, + ConnectionMetadata& connectionMetadata, + SubscriptionContextPtr& subscriptionContext, Request const& request, boost::asio::yield_context yield ) { switch (request.method()) { case Request::Method::Get: - return handleHttpRequest(connectionContext, getHandlers_, request, yield); + return handleHttpRequest(connectionMetadata, subscriptionContext, getHandlers_, request, yield); case Request::Method::Post: - return handleHttpRequest(connectionContext, postHandlers_, request, yield); + return handleHttpRequest(connectionMetadata, subscriptionContext, postHandlers_, request, yield); case Request::Method::Websocket: - return handleWsRequest(connectionContext, wsHandler_, request, yield); + return handleWsRequest(connectionMetadata, subscriptionContext, wsHandler_, request, yield); default: return Response{boost::beast::http::status::bad_request, "Unsupported http method", request}; } diff --git a/src/web/ng/impl/ConnectionHandler.hpp b/src/web/ng/impl/ConnectionHandler.hpp index b5f8a9a06..874f0fbc9 100644 --- a/src/web/ng/impl/ConnectionHandler.hpp +++ b/src/web/ng/impl/ConnectionHandler.hpp @@ -19,10 +19,13 @@ #pragma once +#include "util/Taggable.hpp" #include "util/log/Logger.hpp" +#include "web/SubscriptionContextInterface.hpp" #include "web/ng/Connection.hpp" #include "web/ng/Error.hpp" #include "web/ng/MessageHandler.hpp" +#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/Request.hpp" #include "web/ng/Response.hpp" @@ -41,8 +44,6 @@ namespace web::ng::impl { class ConnectionHandler { public: - enum class ProcessingPolicy { Sequential, Parallel }; - struct StringHash { using hash_type = std::hash; using is_transparent = void; @@ -64,6 +65,9 @@ class ConnectionHandler { ProcessingPolicy processingPolicy_; std::optional maxParallelRequests_; + std::reference_wrapper tagFactory_; + std::optional maxSubscriptionSendQueueSize_; + TargetToHandlerMap getHandlers_; TargetToHandlerMap postHandlers_; std::optional wsHandler_; @@ -71,7 +75,12 @@ class ConnectionHandler { boost::signals2::signal onStop_; public: - ConnectionHandler(ProcessingPolicy processingPolicy, std::optional maxParallelRequests); + ConnectionHandler( + ProcessingPolicy processingPolicy, + std::optional maxParallelRequests, + util::TagDecoratorFactory& tagFactory, + std::optional maxSubscriptionSendQueueSize + ); void onGet(std::string const& target, MessageHandler handler); @@ -107,24 +116,34 @@ class ConnectionHandler { * @return True if the connection should be gracefully closed, false otherwise. */ bool - sequentRequestResponseLoop(Connection& connection, boost::asio::yield_context yield); + sequentRequestResponseLoop( + Connection& connection, + SubscriptionContextPtr& subscriptionContext, + boost::asio::yield_context yield + ); bool - parallelRequestResponseLoop(Connection& connection, boost::asio::yield_context yield); + parallelRequestResponseLoop( + Connection& connection, + SubscriptionContextPtr& subscriptionContext, + boost::asio::yield_context yield + ); std::optional - processRequest(Connection& connection, Request const& request, boost::asio::yield_context yield); + processRequest( + Connection& connection, + SubscriptionContextPtr& subscriptionContext, + Request const& request, + boost::asio::yield_context yield + ); - /** - * @brief Handle a request. - * - * @param connectionContext The connection context. - * @param request The request to handle. - * @param yield The yield context. - * @return The response to send. - */ Response - handleRequest(ConnectionContext const& connectionContext, Request const& request, boost::asio::yield_context yield); + handleRequest( + ConnectionMetadata& connectionMetadata, + SubscriptionContextPtr& subscriptionContext, + Request const& request, + boost::asio::yield_context yield + ); }; } // namespace web::ng::impl diff --git a/src/web/ng/impl/ErrorHandling.cpp b/src/web/ng/impl/ErrorHandling.cpp new file mode 100644 index 000000000..49416d9bd --- /dev/null +++ b/src/web/ng/impl/ErrorHandling.cpp @@ -0,0 +1,165 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "web/ng/impl/ErrorHandling.hpp" + +#include "rpc/Errors.hpp" +#include "rpc/JS.hpp" +#include "util/Assert.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace http = boost::beast::http; + +namespace web::ng::impl { + +namespace { + +boost::json::object +composeErrorImpl(auto const& error, Request const& rawRequest, std::optional const& request) +{ + auto e = rpc::makeError(error); + + if (request) { + auto const appendFieldIfExist = [&](auto const& field) { + if (request->contains(field) and not request->at(field).is_null()) + e[field] = request->at(field); + }; + + appendFieldIfExist(JS(id)); + + if (not rawRequest.isHttp()) + appendFieldIfExist(JS(api_version)); + + e[JS(request)] = request.value(); + } + + if (not rawRequest.isHttp()) { + return e; + } + return {{JS(result), e}}; +} + +} // namespace + +ErrorHelper::ErrorHelper(Request const& rawRequest, std::optional request) + : rawRequest_{rawRequest}, request_{std::move(request)} +{ +} + +Response +ErrorHelper::makeError(rpc::Status const& err) const +{ + if (not rawRequest_.get().isHttp()) { + return Response{http::status::bad_request, composeError(err), rawRequest_}; + } + + // Note: a collection of crutches to match rippled output follows + if (auto const clioCode = std::get_if(&err.code)) { + switch (*clioCode) { + case rpc::ClioError::rpcINVALID_API_VERSION: + return Response{ + http::status::bad_request, std::string{rpc::getErrorInfo(*clioCode).error}, rawRequest_ + }; + case rpc::ClioError::rpcCOMMAND_IS_MISSING: + return Response{http::status::bad_request, "Null method", rawRequest_}; + case rpc::ClioError::rpcCOMMAND_IS_EMPTY: + return Response{http::status::bad_request, "method is empty", rawRequest_}; + case rpc::ClioError::rpcCOMMAND_NOT_STRING: + return Response{http::status::bad_request, "method is not string", rawRequest_}; + case rpc::ClioError::rpcPARAMS_UNPARSEABLE: + return Response{http::status::bad_request, "params unparseable", rawRequest_}; + + // others are not applicable but we want a compilation error next time we add one + case rpc::ClioError::rpcUNKNOWN_OPTION: + case rpc::ClioError::rpcMALFORMED_CURRENCY: + case rpc::ClioError::rpcMALFORMED_REQUEST: + case rpc::ClioError::rpcMALFORMED_OWNER: + case rpc::ClioError::rpcMALFORMED_ADDRESS: + case rpc::ClioError::rpcINVALID_HOT_WALLET: + case rpc::ClioError::rpcFIELD_NOT_FOUND_TRANSACTION: + case rpc::ClioError::rpcMALFORMED_ORACLE_DOCUMENT_ID: + case rpc::ClioError::rpcMALFORMED_AUTHORIZED_CREDENTIALS: + case rpc::ClioError::etlCONNECTION_ERROR: + case rpc::ClioError::etlREQUEST_ERROR: + case rpc::ClioError::etlREQUEST_TIMEOUT: + case rpc::ClioError::etlINVALID_RESPONSE: + ASSERT(false, "Unknown rpc error code {}", static_cast(*clioCode)); // this should never happen + break; + } + } + + return Response{http::status::bad_request, composeError(err), rawRequest_}; +} + +Response +ErrorHelper::makeInternalError() const +{ + return Response{http::status::internal_server_error, composeError(rpc::RippledError::rpcINTERNAL), rawRequest_}; +} + +Response +ErrorHelper::makeNotReadyError() const +{ + return Response{http::status::ok, composeError(rpc::RippledError::rpcNOT_READY), rawRequest_}; +} + +Response +ErrorHelper::makeTooBusyError() const +{ + if (not rawRequest_.get().isHttp()) { + return Response{http::status::too_many_requests, rpc::makeError(rpc::RippledError::rpcTOO_BUSY), rawRequest_}; + } + + return Response{http::status::service_unavailable, rpc::makeError(rpc::RippledError::rpcTOO_BUSY), rawRequest_}; +} + +Response +ErrorHelper::makeJsonParsingError() const +{ + if (not rawRequest_.get().isHttp()) { + return Response{http::status::bad_request, rpc::makeError(rpc::RippledError::rpcBAD_SYNTAX), rawRequest_}; + } + + return Response{http::status::bad_request, fmt::format("Unable to parse JSON from the request"), rawRequest_}; +} + +boost::json::object +ErrorHelper::composeError(rpc::Status const& error) const +{ + return composeErrorImpl(error, rawRequest_, request_); +} + +boost::json::object +ErrorHelper::composeError(rpc::RippledError error) const +{ + return composeErrorImpl(error, rawRequest_, request_); +} + +} // namespace web::ng::impl diff --git a/src/web/ng/impl/ErrorHandling.hpp b/src/web/ng/impl/ErrorHandling.hpp new file mode 100644 index 000000000..88b87639a --- /dev/null +++ b/src/web/ng/impl/ErrorHandling.hpp @@ -0,0 +1,114 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "rpc/Errors.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace web::ng::impl { + +/** + * @brief A helper that attempts to match rippled reporting mode HTTP errors as close as possible. + */ +class ErrorHelper { + std::reference_wrapper rawRequest_; + std::optional request_; + +public: + /** + * @brief Construct a new Error Helper object + * + * @param rawRequest The request that caused the error. + * @param request The parsed request that caused the error. + */ + ErrorHelper(Request const& rawRequest, std::optional request = std::nullopt); + + /** + * @brief Make an error response from a status. + * + * @param err The status to make an error response from. + * @return + */ + [[nodiscard]] Response + makeError(rpc::Status const& err) const; + + /** + * @brief Make an internal error response. + * + * @return A response with an internal error. + */ + [[nodiscard]] Response + makeInternalError() const; + + /** + * @brief Make a response for when the server is not ready. + * + * @return A response with a not ready error. + */ + [[nodiscard]] Response + makeNotReadyError() const; + + /** + * @brief Make a response for when the server is too busy. + * + * @return A response with a too busy error. + */ + [[nodiscard]] Response + makeTooBusyError() const; + + /** + * @brief Make a response when json parsing fails. + * + * @return A response with a json parsing error. + */ + [[nodiscard]] Response + makeJsonParsingError() const; + + /** + * @beirf Compose an error into json object from a status. + * + * @param error The status to compose into a json object. + * @return The composed json object. + */ + [[nodiscard]] boost::json::object + composeError(rpc::Status const& error) const; + + /** + * @brief Compose an error into json object from a rippled error. + * + * @param error The rippled error to compose into a json object. + * @return The composed json object. + */ + [[nodiscard]] boost::json::object + composeError(rpc::RippledError error) const; +}; + +} // namespace web::ng::impl diff --git a/src/web/ng/impl/WsConnection.hpp b/src/web/ng/impl/WsConnection.hpp index 956f15388..26d1c4bf6 100644 --- a/src/web/ng/impl/WsConnection.hpp +++ b/src/web/ng/impl/WsConnection.hpp @@ -28,6 +28,7 @@ #include "web/ng/Response.hpp" #include "web/ng/impl/Concepts.hpp" +#include #include #include #include @@ -52,8 +53,20 @@ namespace web::ng::impl { +class WsConnectionBase : public Connection { +public: + using Connection::Connection; + + virtual std::optional + sendBuffer( + boost::asio::const_buffer buffer, + boost::asio::yield_context yield, + std::chrono::steady_clock::duration timeout = Connection::DEFAULT_TIMEOUT + ) = 0; +}; + template -class WsConnection : public Connection { +class WsConnection : public WsConnectionBase { boost::beast::websocket::stream stream_; boost::beast::http::request initialRequest_; @@ -66,7 +79,7 @@ class WsConnection : public Connection { util::TagDecoratorFactory const& tagDecoratorFactory ) requires IsTcpStream - : Connection(std::move(ip), std::move(buffer), tagDecoratorFactory) + : WsConnectionBase(std::move(ip), std::move(buffer), tagDecoratorFactory) , stream_(std::move(socket)) , initialRequest_(std::move(initialRequest)) { @@ -81,7 +94,7 @@ class WsConnection : public Connection { util::TagDecoratorFactory const& tagDecoratorFactory ) requires IsSslTcpStream - : Connection(std::move(ip), std::move(buffer), tagDecoratorFactory) + : WsConnectionBase(std::move(ip), std::move(buffer), tagDecoratorFactory) , stream_(std::move(socket), sslContext) , initialRequest_(std::move(initialRequest)) { @@ -112,20 +125,29 @@ class WsConnection : public Connection { } std::optional - send( - Response response, + sendBuffer( + boost::asio::const_buffer buffer, boost::asio::yield_context yield, - std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT + std::chrono::steady_clock::duration timeout = Connection::DEFAULT_TIMEOUT ) override { - auto error = util::withTimeout( - [this, &response](auto&& yield) { stream_.async_write(response.asConstBuffer(), yield); }, yield, timeout - ); + auto error = + util::withTimeout([this, buffer](auto&& yield) { stream_.async_write(buffer, yield); }, yield, timeout); if (error) return error; return std::nullopt; } + std::optional + send( + Response response, + boost::asio::yield_context yield, + std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT + ) override + { + return sendBuffer(response.asWsResponse(), yield, timeout); + } + std::expected receive(boost::asio::yield_context yield, std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT) override { diff --git a/tests/common/feed/FeedTestUtil.hpp b/tests/common/feed/FeedTestUtil.hpp index dfe872fb5..4de08b884 100644 --- a/tests/common/feed/FeedTestUtil.hpp +++ b/tests/common/feed/FeedTestUtil.hpp @@ -23,7 +23,7 @@ #include "util/MockPrometheus.hpp" #include "util/MockWsBase.hpp" #include "util/SyncExecutionCtxFixture.hpp" -#include "web/interface/ConnectionBase.hpp" +#include "web/SubscriptionContextInterface.hpp" #include #include @@ -37,25 +37,9 @@ template struct FeedBaseTest : util::prometheus::WithPrometheus, MockBackendTest, SyncExecutionCtxFixture { protected: - std::shared_ptr sessionPtr; - std::shared_ptr testFeedPtr; - MockSession* mockSessionPtr = nullptr; - - void - SetUp() override - { - testFeedPtr = std::make_shared(ctx); - sessionPtr = std::make_shared(); - sessionPtr->apiSubVersion = 1; - mockSessionPtr = dynamic_cast(sessionPtr.get()); - } - - void - TearDown() override - { - sessionPtr.reset(); - testFeedPtr.reset(); - } + web::SubscriptionContextPtr sessionPtr = std::make_shared(); + std::shared_ptr testFeedPtr = std::make_shared(ctx); + MockSession* mockSessionPtr = dynamic_cast(sessionPtr.get()); }; namespace impl { diff --git a/tests/common/util/HandlerBaseTestFixture.hpp b/tests/common/util/HandlerBaseTestFixture.hpp index 8823d1b2c..e5d88eb53 100644 --- a/tests/common/util/HandlerBaseTestFixture.hpp +++ b/tests/common/util/HandlerBaseTestFixture.hpp @@ -34,22 +34,7 @@ template