diff --git a/helio b/helio index 8985263c3acc..ceaa6f844b2a 160000 --- a/helio +++ b/helio @@ -1 +1 @@ -Subproject commit 8985263c3acca038752e8f9fdd8e9f61d2ec2b6f +Subproject commit ceaa6f844b2a72e03c1535939d21aa3fbd3c4e98 diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index db89b05f5c25..d82749c72685 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -89,7 +89,8 @@ void SendProtocolError(RedisParser::Result pres, SinkReplyBuilder* builder) { // https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html // One place to find a good implementation would be https://github.com/h2o/picohttpparser bool MatchHttp11Line(string_view line) { - return absl::StartsWith(line, "GET ") && absl::EndsWith(line, "HTTP/1.1"); + return (absl::StartsWith(line, "GET ") || absl::StartsWith(line, "POST ")) && + absl::EndsWith(line, "HTTP/1.1"); } void UpdateIoBufCapacity(const base::IoBuf& io_buf, ConnectionStats* stats, @@ -651,11 +652,13 @@ void Connection::HandleRequests() { http_res = CheckForHttpProto(peer); if (http_res) { + cc_.reset(service_->CreateContext(peer, this)); if (*http_res) { VLOG(1) << "HTTP1.1 identified"; is_http_ = true; HttpConnection http_conn{http_listener_}; http_conn.SetSocket(peer); + http_conn.set_user_data(cc_.get()); auto ec = http_conn.ParseFromBuffer(io_buf_.InputBuffer()); io_buf_.ConsumeInput(io_buf_.InputLen()); if (!ec) { @@ -666,7 +669,6 @@ void Connection::HandleRequests() { // this connection. http_conn.ReleaseSocket(); } else { - cc_.reset(service_->CreateContext(peer, this)); if (breaker_cb_) { socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); }); } @@ -674,9 +676,8 @@ void Connection::HandleRequests() { ConnectionFlow(peer); socket_->CancelOnErrorCb(); // noop if nothing is registered. - - cc_.reset(); } + cc_.reset(); } VLOG(1) << "Closed connection for peer " << remote_ep; diff --git a/src/facade/reply_capture.h b/src/facade/reply_capture.h index 7004faf5c9a1..7fe2843d23a7 100644 --- a/src/facade/reply_capture.h +++ b/src/facade/reply_capture.h @@ -47,11 +47,9 @@ class CapturingReplyBuilder : public RedisReplyBuilder { void StartCollection(unsigned len, CollectionType type) override; - private: + public: using Error = std::pair; // SendError (msg, type) using Null = std::nullptr_t; // SendNull or SendNullArray - struct SimpleString : public std::string {}; // SendSimpleString - struct BulkString : public std::string {}; // SendBulkString struct StrArrPayload { bool simple; @@ -66,7 +64,9 @@ class CapturingReplyBuilder : public RedisReplyBuilder { bool with_scores; }; - public: + struct SimpleString : public std::string {}; // SendSimpleString + struct BulkString : public std::string {}; // SendBulkString + CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL) : RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} { } @@ -89,7 +89,6 @@ class CapturingReplyBuilder : public RedisReplyBuilder { // If an error is stored inside payload, get a reference to it. static std::optional GetError(const Payload& pl); - private: struct CollectionPayload { CollectionPayload(unsigned len, CollectionType type); diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index df40cf6cee3b..bc261a7a267e 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -36,7 +36,7 @@ SET(SEARCH_FILES search/search_family.cc search/doc_index.cc search/doc_accessor add_library(dragonfly_lib engine_shard_set.cc channel_store.cc config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc - generic_family.cc hset_family.cc json_family.cc + generic_family.cc hset_family.cc http_api.cc json_family.cc ${SEARCH_FILES} list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc protocol_client.cc diff --git a/src/server/http_api.cc b/src/server/http_api.cc new file mode 100644 index 000000000000..9e6b5c9b55b5 --- /dev/null +++ b/src/server/http_api.cc @@ -0,0 +1,228 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/http_api.h" + +#include "base/logging.h" +#include "core/flatbuffers.h" +#include "facade/conn_context.h" +#include "facade/reply_builder.h" +#include "server/main_service.h" +#include "util/http/http_common.h" + +namespace dfly { +using namespace util; +using namespace std; +namespace h2 = boost::beast::http; +using facade::CapturingReplyBuilder; + +namespace { + +bool IsVectorOfStrings(flexbuffers::Reference req) { + if (!req.IsVector()) { + return false; + } + + auto vec = req.AsVector(); + if (vec.size() == 0) { + return false; + } + + for (size_t i = 0; i < vec.size(); ++i) { + if (!vec[i].IsString()) { + return false; + } + } + return true; +} + +// Escape a string so that it is legal to print it in JSON text. +std::string JsonEscape(string_view input) { + auto hex_digit = [](unsigned c) -> char { + DCHECK_LT(c, 0xFu); + return c < 10 ? c + '0' : c - 10 + 'a'; + }; + + string out; + out.reserve(input.size() + 2); + out.push_back('\"'); + + auto p = input.begin(); + auto e = input.end(); + + while (p < e) { + uint8_t c = *p; + if (c == '\\' || c == '\"') { + out.push_back('\\'); + out.push_back(*p++); + } else if (c <= 0x1f) { + switch (c) { + case '\b': + out.append("\\b"); + p++; + break; + case '\f': + out.append("\\f"); + p++; + break; + case '\n': + out.append("\\n"); + p++; + break; + case '\r': + out.append("\\r"); + p++; + break; + case '\t': + out.append("\\t"); + p++; + break; + default: + // this condition captures non readable chars with value < 32, + // so size = 1 byte (e.g control chars). + out.append("\\u00"); + out.push_back(hex_digit((c & 0xf0) >> 4)); + out.push_back(hex_digit(c & 0xf)); + p++; + } + } else { + out.push_back(*p++); + } + } + + out.push_back('\"'); + return out; +} + +struct CaptureVisitor { + CaptureVisitor() { + str = R"({"result":)"; + } + + void operator()(monostate) { + } + + void operator()(long v) { + absl::StrAppend(&str, v); + } + + void operator()(double v) { + absl::StrAppend(&str, v); + } + + void operator()(const CapturingReplyBuilder::SimpleString& ss) { + absl::StrAppend(&str, "\"", ss, "\""); + } + + void operator()(const CapturingReplyBuilder::BulkString& bs) { + absl::StrAppend(&str, JsonEscape(bs)); + } + + void operator()(CapturingReplyBuilder::Null) { + absl::StrAppend(&str, "null"); + } + + void operator()(CapturingReplyBuilder::Error err) { + str = absl::StrCat(R"({"error": ")", err.first); + } + + void operator()(facade::OpStatus status) { + absl::StrAppend(&str, "\"", facade::StatusToMsg(status), "\""); + } + + void operator()(const CapturingReplyBuilder::StrArrPayload& sa) { + absl::StrAppend(&str, "not_implemented"); + } + + void operator()(unique_ptr cp) { + if (!cp) { + absl::StrAppend(&str, "null"); + return; + } + if (cp->len == 0 && cp->type == facade::RedisReplyBuilder::ARRAY) { + absl::StrAppend(&str, "[]"); + return; + } + + absl::StrAppend(&str, "["); + for (auto& pl : cp->arr) { + visit(*this, std::move(pl)); + } + } + + void operator()(facade::SinkReplyBuilder::MGetResponse resp) { + absl::StrAppend(&str, "not_implemented"); + } + + void operator()(const CapturingReplyBuilder::ScoredArray& sarr) { + absl::StrAppend(&str, "["); + for (const auto& [key, score] : sarr.arr) { + absl::StrAppend(&str, "{", JsonEscape(key), ":", score, "},"); + } + if (sarr.arr.size() > 0) { + str.pop_back(); + } + absl::StrAppend(&str, "]"); + } + + string str; +}; + +} // namespace + +void HttpAPI(const http::QueryArgs& args, HttpRequest&& req, Service* service, + HttpContext* http_cntx) { + auto& body = req.body(); + + flexbuffers::Builder fbb; + flatbuffers::Parser parser; + flexbuffers::Reference doc; + bool success = parser.ParseFlexBuffer(body.c_str(), nullptr, &fbb); + if (success) { + fbb.Finish(); + doc = flexbuffers::GetRoot(fbb.GetBuffer()); + if (!IsVectorOfStrings(doc)) { + success = false; + } + } + + if (!success) { + auto response = http::MakeStringResponse(h2::status::bad_request); + http::SetMime(http::kTextMime, &response); + response.body() = "Failed to parse json\r\n"; + http_cntx->Invoke(std::move(response)); + return; + } + + vector cmd_args; + flexbuffers::Vector vec = doc.AsVector(); + for (size_t i = 0; i < vec.size(); ++i) { + cmd_args.push_back(vec[i].AsString().c_str()); + } + vector cmd_slices(cmd_args.size()); + for (size_t i = 0; i < cmd_args.size(); ++i) { + cmd_slices[i] = absl::MakeSpan(cmd_args[i]); + } + + facade::ConnectionContext* context = (facade::ConnectionContext*)http_cntx->user_data(); + DCHECK(context); + + facade::CapturingReplyBuilder reply_builder; + auto* prev = context->Inject(&reply_builder); + // TODO: to finish this. + service->DispatchCommand(absl::MakeSpan(cmd_slices), context); + facade::CapturingReplyBuilder::Payload payload = reply_builder.Take(); + + context->Inject(prev); + auto response = http::MakeStringResponse(); + http::SetMime(http::kJsonMime, &response); + + CaptureVisitor visitor; + std::visit(visitor, std::move(payload)); + visitor.str.append("}\r\n"); + response.body() = visitor.str; + http_cntx->Invoke(std::move(response)); +} + +} // namespace dfly diff --git a/src/server/http_api.h b/src/server/http_api.h new file mode 100644 index 000000000000..1c5cd61e9bd8 --- /dev/null +++ b/src/server/http_api.h @@ -0,0 +1,26 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "util/http/http_handler.h" + +namespace dfly { +class Service; +using HttpRequest = util::HttpListenerBase::RequestType; + +/** + * @brief The main handler function for dispatching commands via HTTP. + * + * @param args - query arguments. currently not used. + * @param req - full http request including the body that should consist of a json array + * representing a Dragonfly command. aka `["set", "foo", "bar"]` + * @param service - a pointer to dfly::Service* object. + * @param http_cntxt - a pointer to the http context object which provide dragonfly context + * information via user_data() and allows to reply with HTTP responses. + */ +void HttpAPI(const util::http::QueryArgs& args, HttpRequest&& req, Service* service, + util::HttpContext* http_cntxt); + +} // namespace dfly diff --git a/src/server/main_service.cc b/src/server/main_service.cc index cfc5e6259db2..67b8c4d6d831 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1,4 +1,4 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. +// Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // @@ -40,6 +40,7 @@ extern "C" { #include "server/generic_family.h" #include "server/hll_family.h" #include "server/hset_family.h" +#include "server/http_api.h" #include "server/json_family.h" #include "server/list_family.h" #include "server/multi_command_squasher.h" @@ -83,6 +84,9 @@ ABSL_FLAG(bool, admin_nopass, false, "If set, would enable open admin access to console on the assigned port, without " "authorization needed."); +ABSL_FLAG(bool, expose_http_api, false, + "If set, will expose a POST /api handler for sending redis commands as json array."); + ABSL_FLAG(dfly::MemoryBytesFlag, maxmemory, dfly::MemoryBytesFlag{}, "Limit on maximum-memory that is used by the database. " "0 - means the program will automatically determine its maximum memory usage. " @@ -2441,6 +2445,13 @@ void Service::ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privil base->RegisterCb("/clusterz", [this](const http::QueryArgs& args, HttpContext* send) { return ClusterHtmlPage(args, send, &cluster_family_); }); + + if (absl::GetFlag(FLAGS_expose_http_api)) { + base->RegisterCb("/api", + [this](const http::QueryArgs& args, HttpRequest&& req, HttpContext* send) { + HttpAPI(args, std::move(req), this, send); + }); + } } void Service::OnClose(facade::ConnectionContext* cntx) { diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py index cfa1aab1505d..3c15f5479517 100755 --- a/tests/dragonfly/connection_test.py +++ b/tests/dragonfly/connection_test.py @@ -7,6 +7,7 @@ from redis.exceptions import ConnectionError as redis_conn_error, ResponseError import async_timeout from dataclasses import dataclass +from aiohttp import ClientSession from . import dfly_args from .instance import DflyInstance, DflyInstanceFactory @@ -67,7 +68,6 @@ def should_exclude(cmd: str): """ -@pytest.mark.asyncio @dfly_args({"proactor_threads": 4}) async def test_monitor_command(async_pool): monitor = CollectingMonitor(aioredis.Redis(connection_pool=async_pool)) @@ -90,7 +90,6 @@ async def test_monitor_command(async_pool): """ -@pytest.mark.asyncio @dfly_args({"proactor_threads": 4, "multi_exec_squash": "true"}) async def test_monitor_command_multi(async_pool): monitor = CollectingMonitor(aioredis.Redis(connection_pool=async_pool)) @@ -127,7 +126,6 @@ async def test_monitor_command_multi(async_pool): """ -@pytest.mark.asyncio @dfly_args({"proactor_threads": 4, "lua_auto_async": "false"}) async def test_monitor_command_lua(async_pool): monitor = CollectingMonitor(aioredis.Redis(connection_pool=async_pool)) @@ -151,7 +149,6 @@ async def test_monitor_command_lua(async_pool): """ -@pytest.mark.asyncio async def test_pipeline_support(async_client): def generate(max): for i in range(max): @@ -200,7 +197,6 @@ async def run_pipeline_mode(async_client: aioredis.Redis, messages): """ -@pytest.mark.asyncio async def test_pubsub_command(async_client): def generate(max): for i in range(max): @@ -276,7 +272,6 @@ async def run_multi_pubsub(async_client, messages, channel_name): """ -@pytest.mark.asyncio async def test_multi_pubsub(async_client): def generate(max): for i in range(max): @@ -293,7 +288,6 @@ def generate(max): """ -@pytest.mark.asyncio async def test_pubsub_subcommand_for_numsub(async_client): subs1 = [async_client.pubsub() for i in range(5)] for s in subs1: @@ -343,7 +337,6 @@ async def test_pubsub_subcommand_for_numsub(async_client): """ -@pytest.mark.asyncio @pytest.mark.slow @dfly_args({"proactor_threads": "1", "subscriber_thread_limit": "100"}) async def test_publish_stuck(df_server: DflyInstance, async_client: aioredis.Redis): @@ -381,7 +374,6 @@ async def pub_task(): await pub -@pytest.mark.asyncio async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_connections=100): # TODO: I am not how to customize the max connections for the pool. async_pool = aioredis.ConnectionPool( @@ -562,7 +554,6 @@ async def test_large_cmd(async_client: aioredis.Redis): assert len(res) == MAX_ARR_SIZE -@pytest.mark.asyncio async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory): server: DflyInstance = df_local_factory.create( no_tls_on_admin_port="true", @@ -583,7 +574,6 @@ async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_ await client.close() -@pytest.mark.asyncio async def test_tls_insecure(with_ca_tls_server_args, with_tls_client_args, df_local_factory): server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args) server.start() @@ -593,7 +583,6 @@ async def test_tls_insecure(with_ca_tls_server_args, with_tls_client_args, df_lo await client.close() -@pytest.mark.asyncio async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, df_local_factory): server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args) server.start() @@ -603,7 +592,6 @@ async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, d await client.close() -@pytest.mark.asyncio async def test_tls_reject( with_ca_tls_server_args, with_tls_client_args, df_local_factory: DflyInstanceFactory ): @@ -620,7 +608,6 @@ async def test_tls_reject( await client.close() -@pytest.mark.asyncio @dfly_args({"proactor_threads": "4", "pipeline_squash": 10}) async def test_squashed_pipeline(async_client: aioredis.Redis): p = async_client.pipeline(transaction=False) @@ -638,7 +625,6 @@ async def test_squashed_pipeline(async_client: aioredis.Redis): res = res[11:] -@pytest.mark.asyncio @dfly_args({"proactor_threads": "4", "pipeline_squash": 10}) async def test_squashed_pipeline_seeder(df_server, df_seeder_factory): seeder = df_seeder_factory.create(port=df_server.port, keys=10_000) @@ -650,7 +636,6 @@ async def test_squashed_pipeline_seeder(df_server, df_seeder_factory): """ -@pytest.mark.asyncio @dfly_args({"proactor_threads": "4", "pipeline_squash": 1}) async def test_squashed_pipeline_multi(async_client: aioredis.Redis): p = async_client.pipeline(transaction=False) @@ -670,7 +655,6 @@ async def test_squashed_pipeline_multi(async_client: aioredis.Redis): await p.execute() -@pytest.mark.asyncio async def test_unix_domain_socket(df_local_factory, tmp_dir): server = df_local_factory.create(proactor_threads=1, port=BASE_PORT, unixsocket="./df.sock") server.start() @@ -688,7 +672,6 @@ async def test_unix_domain_socket(df_local_factory, tmp_dir): @pytest.mark.slow -@pytest.mark.asyncio async def test_nested_client_pause(async_client: aioredis.Redis): async def do_pause(): await async_client.execute_command("CLIENT", "PAUSE", "1000", "WRITE") @@ -715,7 +698,6 @@ async def do_write(): await p3 -@pytest.mark.asyncio async def test_blocking_command_client_pause(async_client: aioredis.Redis): """ 1. Check client pause success when blocking transaction is running @@ -743,7 +725,6 @@ async def lpush_command(): await blocking -@pytest.mark.asyncio async def test_multiple_blocking_commands_client_pause(async_client: aioredis.Redis): """ Check running client pause command simultaneously with running multiple blocking command @@ -765,3 +746,24 @@ async def client_pause(): assert not all.done() await all + + +@dfly_args({"proactor_threads": "1", "expose_http_api": "true"}) +async def test_http(df_server: DflyInstance): + client = df_server.client() + async with ClientSession() as session: + async with session.get(f"http://localhost:{df_server.port}") as resp: + assert resp.status == 200 + + body = '["set", "foo", "МайяХилли", "ex", "100"]' + async with session.post(f"http://localhost:{df_server.port}/api", data=body) as resp: + assert resp.status == 200 + text = await resp.text() + assert text.strip() == '{"result":"OK"}' + + body = '["get", "foo"]' + async with session.post(f"http://localhost:{df_server.port}/api", data=body) as resp: + assert resp.status == 200 + text = await resp.text() + assert text.strip() == '{"result":"МайяХилли"}' + assert await client.ttl("foo") > 0 diff --git a/tests/dragonfly/instance.py b/tests/dragonfly/instance.py index eca9ef772154..450441ab6c35 100644 --- a/tests/dragonfly/instance.py +++ b/tests/dragonfly/instance.py @@ -317,8 +317,7 @@ def __init__(self, params: DflyParams, args): def create(self, existing_port=None, **kwargs) -> DflyInstance: args = {**self.args, **kwargs} args.setdefault("dbfilename", "") - args.setdefault("jsonpathv2", None) - + args.setdefault("enable_direct_fd", None) # Testing iouring with direct_fd enabled. # MacOs does not set it automatically, so we need to set it manually args.setdefault("maxmemory", "8G") vmod = "dragonfly_connection=1,accept_server=1,listener_interface=1,main_service=1,rdb_save=1,replica=1,cluster_family=1"