From 413ec0a1cf1e9684c7d2e84092738e6d6fd5174a Mon Sep 17 00:00:00 2001 From: guyzilla Date: Wed, 1 Jan 2025 11:34:53 +0200 Subject: [PATCH] feat:Adding support for ZMPOP command (#4385) Signed-off-by: Guy Flysher --- src/facade/reply_builder.cc | 18 ++- src/facade/reply_builder.h | 5 +- src/facade/reply_builder_test.cc | 21 ++++ src/server/command_registry.cc | 2 + src/server/command_registry.h | 1 + src/server/zset_family.cc | 188 +++++++++++++++++++++++++++---- src/server/zset_family.h | 1 + src/server/zset_family_test.cc | 128 +++++++++++++++++++++ 8 files changed, 340 insertions(+), 24 deletions(-) diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 9a909b902d12..d398ecbf8f98 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -408,8 +408,7 @@ void RedisReplyBuilder::SendBulkStrArr(const facade::ArgRange& strs, CollectionT SendBulkString(str); } -void RedisReplyBuilder::SendScoredArray(absl::Span> arr, - bool with_scores) { +void RedisReplyBuilder::SendScoredArray(ScoredArray arr, bool with_scores) { ReplyScope scope(this); StartArray((with_scores && !IsResp3()) ? arr.size() * 2 : arr.size()); for (const auto& [str, score] : arr) { @@ -421,6 +420,21 @@ void RedisReplyBuilder::SendScoredArray(absl::Span>; RedisReplyBuilder(io::Sink* sink) : RedisReplyBuilderBase(sink) { } @@ -281,8 +282,8 @@ class RedisReplyBuilder : public RedisReplyBuilderBase { void SendSimpleStrArr(const facade::ArgRange& strs); void SendBulkStrArr(const facade::ArgRange& strs, CollectionType ct = ARRAY); - void SendScoredArray(absl::Span> arr, bool with_scores); - + void SendScoredArray(ScoredArray arr, bool with_scores); + void SendLabeledScoredArray(std::string_view arr_label, ScoredArray arr); void SendStored() final; void SendSetSkipped() final; diff --git a/src/facade/reply_builder_test.cc b/src/facade/reply_builder_test.cc index 028d594eef9f..8560bcde5ee0 100644 --- a/src/facade/reply_builder_test.cc +++ b/src/facade/reply_builder_test.cc @@ -775,6 +775,27 @@ TEST_F(RedisReplyBuilderTest, SendScoredArray) { << "Resp3 WITHSCORES failed."; } +TEST_F(RedisReplyBuilderTest, SendLabeledScoredArray) { + const std::vector> scored_array{ + {"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}}; + + builder_->SetResp3(false); + builder_->SendLabeledScoredArray("foobar", scored_array); + ASSERT_TRUE(NoErrors()); + ASSERT_EQ(TakePayload(), + "*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n$3\r\n1.1\r\n*2\r\n$2\r\ne2\r\n$3\r\n2." + "2\r\n*2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n") + << "Resp3 failed.\n"; + + builder_->SetResp3(true); + builder_->SendLabeledScoredArray("foobar", scored_array); + ASSERT_TRUE(NoErrors()); + ASSERT_EQ(TakePayload(), + "*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n,1.1\r\n*2\r\n$2\r\ne2\r\n,2.2\r\n*" + "2\r\n$2\r\ne3\r\n,3.3\r\n") + << "Resp3 failed."; +} + TEST_F(RedisReplyBuilderTest, BasicCapture) { GTEST_SKIP() << "Unmark when CaptuingReplyBuilder is updated"; diff --git a/src/server/command_registry.cc b/src/server/command_registry.cc index cce5354f8ffe..0777ad7d992e 100644 --- a/src/server/command_registry.cc +++ b/src/server/command_registry.cc @@ -283,6 +283,8 @@ const char* OptName(CO::CommandOpt fl) { return "no-key-tx-span-all"; case IDEMPOTENT: return "idempotent"; + case SLOW: + return "slow"; } return "unknown"; } diff --git a/src/server/command_registry.h b/src/server/command_registry.h index 3acc69c355bc..76f27117bef1 100644 --- a/src/server/command_registry.h +++ b/src/server/command_registry.h @@ -52,6 +52,7 @@ enum CommandOpt : uint32_t { // The same callback can be run multiple times without corrupting the result. Used for // opportunistic optimizations where inconsistencies can only be detected afterwards. IDEMPOTENT = 1U << 18, + SLOW = 1U << 19 // Unused? }; const char* OptName(CommandOpt fl); diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 56b8a1777f07..fc440ad92037 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -67,10 +67,10 @@ struct GeoPoint { double dist; double score; std::string member; - GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0){}; + GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0) {}; GeoPoint(double _longitude, double _latitude, double _dist, double _score, const std::string& _member) - : longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member){}; + : longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member) {}; }; using GeoArray = std::vector; @@ -179,8 +179,7 @@ struct ZParams { bool override = false; }; -void OutputScoredArrayResult(const OpResult& result, - const ZSetFamily::RangeParams& params, SinkReplyBuilder* builder) { +void OutputScoredArrayResult(const OpResult& result, SinkReplyBuilder* builder) { if (result.status() == OpStatus::WRONG_TYPE) { return builder->SendError(kWrongTypeErr); } @@ -188,7 +187,7 @@ void OutputScoredArrayResult(const OpResult& result, LOG_IF(WARNING, !result && result.status() != OpStatus::KEY_NOTFOUND) << "Unexpected status " << result.status(); auto* rb = static_cast(builder); - rb->SendScoredArray(result.value(), params.with_scores); + rb->SendScoredArray(result.value(), true /* with scores */); } OpResult FindZEntry(const ZParams& zparams, const OpArgs& op_args, @@ -1821,31 +1820,47 @@ void ZBooleanOperation(CmdArgList args, string_view cmd, bool is_union, bool sto } } -void ZPopMinMax(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) { - string_view key = ArgS(args, 0); +enum class FilterShards { NO = 0, YES = 1 }; +OpResult ZPopMinMaxInternal(std::string_view key, FilterShards should_filter_shards, + uint32 count, bool reverse, Transaction* tx) { ZSetFamily::RangeParams range_params; range_params.reverse = reverse; range_params.with_scores = true; ZSetFamily::ZRangeSpec range_spec; range_spec.params = range_params; - ZSetFamily::TopNScored sc = 1; - if (args.size() > 1) { - string_view count = ArgS(args, 1); - if (!SimpleAtoi(count, &sc)) { - return builder->SendError(kUintErr); - } - } + range_spec.interval = count; - range_spec.interval = sc; + OpResult result; + std::optional key_shard; + if (should_filter_shards == FilterShards::YES) { + key_shard = Shard(key, shard_set->size()); + } auto cb = [&](Transaction* t, EngineShard* shard) { - return OpPopCount(range_spec, t->GetOpArgs(shard), key); + if (!key_shard.has_value() || *key_shard == shard->shard_id()) { + result = std::move(OpPopCount(range_spec, t->GetOpArgs(shard), key)); + } + return OpStatus::OK; }; - OpResult result = tx->ScheduleSingleHopT(std::move(cb)); - OutputScoredArrayResult(result, range_params, builder); + tx->Execute(std::move(cb), true); + + return result; +} + +void ZPopMinMaxFromArgs(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) { + string_view key = ArgS(args, 0); + uint32 count = 1; + if (args.size() > 1) { + string_view count_str = ArgS(args, 1); + if (!SimpleAtoi(count_str, &count)) { + return builder->SendError(kUintErr); + } + } + + OutputScoredArrayResult(ZPopMinMaxInternal(key, FilterShards::NO, count, reverse, tx), builder); } OpResult ZGetMembers(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { @@ -2060,6 +2075,71 @@ void ZRemRangeGeneric(string_view key, const ZSetFamily::ZRangeSpec& range_spec, } } +// Returns the key of the first non empty set found in the list of shard arguments. +// Returns nullopt if none. +std::optional GetFirstNonEmptyKeyFound(EngineShard* shard, Transaction* t) { + ShardArgs keys = t->GetShardArgs(shard->shard_id()); + DCHECK(!keys.Empty()); + + auto& db_slice = t->GetDbSlice(shard->shard_id()); + + for (string_view key : keys) { + auto it = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_ZSET); + if (!it) { + continue; + } + return std::optional(key); + } + + return std::nullopt; +} + +// Validates the ZMPop command arguments and extracts the values to the output params. +// If the arguments are invalid sends the appropiate error to builder and returns false. +bool ValidateZMPopCommand(CmdArgList args, uint32* num_keys, bool* is_max, int* pop_count, + SinkReplyBuilder* builder) { + CmdArgParser parser{args}; + + if (!SimpleAtoi(parser.Next(), num_keys)) { + builder->SendError(kUintErr); + return false; + } + + if (*num_keys <= 0 || !parser.HasAtLeast(*num_keys + 1)) { + // We should have at least num_keys keys + a MIN/MAX arg. + builder->SendError(kSyntaxErr); + return false; + } + // Skip over the keys themselves. + parser.Skip(*num_keys); + + // We know we have at least one more arg (we checked above). + if (parser.Check("MAX")) { + *is_max = true; + } else if (parser.Check("MIN")) { + *is_max = false; + } else { + builder->SendError(kSyntaxErr); + return false; + } + + *pop_count = 1; + // Check if we have additional COUNT argument. + if (parser.HasNext()) { + if (!parser.Check("COUNT", pop_count)) { + builder->SendError(kSyntaxErr); + return false; + } + } + + if (!parser.Finalize()) { + builder->SendError(parser.Error()->MakeReply()); + return false; + } + + return true; +} + } // namespace void ZSetFamily::BZPopMin(CmdArgList args, const CommandContext& cmd_cntx) { @@ -2355,12 +2435,77 @@ void ZSetFamily::ZInterCard(CmdArgList args, const CommandContext& cmd_cntx) { builder->SendLong(result.value().size()); } +void ZSetFamily::ZMPop(CmdArgList args, const CommandContext& cmd_cntx) { + uint32 num_keys; + bool is_max; + int pop_count; + if (!ValidateZMPopCommand(args, &num_keys, &is_max, &pop_count, cmd_cntx.rb)) { + return; + } + auto* response_builder = static_cast(cmd_cntx.rb); + + // From the list of input keys, keep the first (in the order of keys in the command) key found in + // the current shard. + std::vector> first_found_key_per_shard_vec(shard_set->size(), + std::nullopt); + + auto cb = [&](Transaction* t, EngineShard* shard) { + std::optional result = GetFirstNonEmptyKeyFound(shard, t); + if (result.has_value()) { + first_found_key_per_shard_vec[shard->shard_id()] = result; + } + return OpStatus::OK; + }; + + cmd_cntx.tx->Execute(std::move(cb), false /* possibly another hop */); + + // Keep all the keys found (first only for each shard) in a set for fast lookups. + absl::flat_hash_set first_found_keys_for_shard; + // We can have at most one result from each shard. + first_found_keys_for_shard.reserve(std::min(shard_set->size(), num_keys)); + for (const auto& key : first_found_key_per_shard_vec) { + if (!key.has_value()) { + continue; + } + first_found_keys_for_shard.insert(*key); + } + + // Now that we have the first non empty key from each shard, find the first overall first key and + // pop elements from it. + std::optional key_to_pop = std::nullopt; + ArgRange arg_keys(args.subspan(1, num_keys)); + // Find the first arg_key which exists in any shard and is not empty. + for (std::string_view key : arg_keys) { + if (first_found_keys_for_shard.contains(key)) { + key_to_pop = key; + break; + } + } + + if (!key_to_pop.has_value()) { + cmd_cntx.tx->Conclude(); + response_builder->SendNull(); + return; + } + + // Pop elements from relevant set. + OpResult pop_result = + ZPopMinMaxInternal(*key_to_pop, FilterShards::YES, pop_count, is_max, cmd_cntx.tx); + + if (pop_result.status() == OpStatus::WRONG_TYPE) { + return response_builder->SendError(kWrongTypeErr); + } + + LOG_IF(WARNING, !pop_result) << "Unexpected status " << pop_result.status(); + response_builder->SendLabeledScoredArray(*key_to_pop, pop_result.value()); +} + void ZSetFamily::ZPopMax(CmdArgList args, const CommandContext& cmd_cntx) { - ZPopMinMax(std::move(args), true, cmd_cntx.tx, cmd_cntx.rb); + ZPopMinMaxFromArgs(std::move(args), true, cmd_cntx.tx, cmd_cntx.rb); } void ZSetFamily::ZPopMin(CmdArgList args, const CommandContext& cmd_cntx) { - ZPopMinMax(std::move(args), false, cmd_cntx.tx, cmd_cntx.rb); + ZPopMinMaxFromArgs(std::move(args), false, cmd_cntx.tx, cmd_cntx.rb); } void ZSetFamily::ZLexCount(CmdArgList args, const CommandContext& cmd_cntx) { @@ -3217,6 +3362,7 @@ constexpr uint32_t kZInterStore = WRITE | SORTEDSET | SLOW; constexpr uint32_t kZInter = READ | SORTEDSET | SLOW; constexpr uint32_t kZInterCard = WRITE | SORTEDSET | SLOW; constexpr uint32_t kZLexCount = READ | SORTEDSET | FAST; +constexpr uint32_t kZMPop = WRITE | SORTEDSET | SLOW; constexpr uint32_t kZPopMax = WRITE | SORTEDSET | FAST; constexpr uint32_t kZPopMin = WRITE | SORTEDSET | FAST; constexpr uint32_t kZRem = WRITE | SORTEDSET | FAST; @@ -3267,6 +3413,8 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZINTERCARD", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, acl::kZInterCard}.HFUNC( ZInterCard) << CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, acl::kZLexCount}.HFUNC(ZLexCount) + << CI{"ZMPOP", CO::SLOW | CO::WRITE | CO::VARIADIC_KEYS, -4, 2, 2, acl::kZMPop}.HFUNC(ZMPop) + << CI{"ZPOPMAX", CO::FAST | CO::WRITE, -2, 1, 1, acl::kZPopMax}.HFUNC(ZPopMax) << CI{"ZPOPMIN", CO::FAST | CO::WRITE, -2, 1, 1, acl::kZPopMin}.HFUNC(ZPopMin) << CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, acl::kZRem}.HFUNC(ZRem) diff --git a/src/server/zset_family.h b/src/server/zset_family.h index ec678597a791..17d4eceb24ad 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -72,6 +72,7 @@ class ZSetFamily { static void ZInter(CmdArgList args, const CommandContext& cmd_cntx); static void ZInterCard(CmdArgList args, const CommandContext& cmd_cntx); static void ZLexCount(CmdArgList args, const CommandContext& cmd_cntx); + static void ZMPop(CmdArgList args, const CommandContext& cmd_cntx); static void ZPopMax(CmdArgList args, const CommandContext& cmd_cntx); static void ZPopMin(CmdArgList args, const CommandContext& cmd_cntx); static void ZRange(CmdArgList args, const CommandContext& cmd_cntx); diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 140488741de7..e38a364fbabf 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -81,6 +81,32 @@ MATCHER_P(UnorderedScoredElementsAreMatcher, elements_list, "") { elements_list.end()); } +MATCHER_P2(ContainsLabeledScoredArrayMatcher, label, elements, "") { + auto label_vec = arg.GetVec(); + if (label_vec.size() != 2) { + *result_listener << "Labeled Scored Array does no contain two elements."; + return false; + } + + if (!ExplainMatchResult(Eq(label), label_vec[0].GetString(), result_listener)) { + return false; + } + + auto value_pairs_vec = label_vec[1].GetVec(); + std::set> actual_elements; + for (const auto& scored_element : value_pairs_vec) { + actual_elements.insert(std::make_pair(scored_element.GetVec()[0].GetString(), + scored_element.GetVec()[1].GetString())); + } + if (actual_elements != elements) { + *result_listener << "Scored elements do not match: "; + ExplainMatchResult(ElementsAreArray(elements), actual_elements, result_listener); + return false; + } + + return true; +} + auto ConsistsOf(std::initializer_list elements) { return ConsistsOfMatcher(std::unordered_set{elements}); } @@ -98,6 +124,12 @@ auto UnorderedScoredElementsAre( return UnorderedScoredElementsAreMatcher(elements); } +auto ContainsLabeledScoredArray( + std::string_view label, std::initializer_list> elements) { + return ContainsLabeledScoredArrayMatcher(label, + std::set>{elements}); +} + TEST_F(ZSetFamilyTest, Add) { auto resp = Run({"zadd", "x", "1.1", "a"}); EXPECT_THAT(resp, IntArg(1)); @@ -757,6 +789,102 @@ TEST_F(ZSetFamilyTest, ZAddBug148) { EXPECT_THAT(resp, IntArg(1)); } +TEST_F(ZSetFamilyTest, ZMPopInvalidSyntax) { + // Not enough arguments. + auto resp = Run({"zmpop", "1", "a"}); + EXPECT_THAT(resp, ErrArg("wrong number of arguments")); + + // Zero keys. + resp = Run({"zmpop", "0", "MIN", "COUNT", "1"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + // Number of keys not uint. + resp = Run({"zmpop", "aa", "a", "MIN"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); + + // Missing MIN/MAX. + resp = Run({"zmpop", "1", "a", "COUNT", "1"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + // Wrong number of keys. + resp = Run({"zmpop", "1", "a", "b", "MAX"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + // Count with no number. + resp = Run({"zmpop", "1", "a", "MAX", "COUNT"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + // Count number is not uint. + resp = Run({"zmpop", "1", "a", "MIN", "COUNT", "boo"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); + + // Too many arguments. + resp = Run({"zmpop", "1", "c", "MAX", "COUNT", "2", "foo"}); + EXPECT_THAT(resp, ErrArg("syntax error")); +} + +TEST_F(ZSetFamilyTest, ZMPop) { + // All sets are empty. + auto resp = Run({"zmpop", "1", "e", "MIN"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL)); + + // Min operation. + resp = Run({"zadd", "a", "1", "a1", "2", "a2"}); + EXPECT_THAT(resp, IntArg(2)); + + resp = Run({"zmpop", "1", "a", "MIN"}); + EXPECT_THAT(resp, ContainsLabeledScoredArray("a", {{"a1", "1"}})); + + resp = Run({"ZRANGE", "a", "0", "-1", "WITHSCORES"}); + EXPECT_THAT(resp, RespArray(ElementsAre("a2", "2"))); + + // Max operation. + resp = Run({"zadd", "b", "1", "b1", "2", "b2"}); + EXPECT_THAT(resp, IntArg(2)); + + resp = Run({"zmpop", "1", "b", "MAX"}); + EXPECT_THAT(resp, ContainsLabeledScoredArray("b", {{"b2", "2"}})); + + resp = Run({"ZRANGE", "b", "0", "-1", "WITHSCORES"}); + EXPECT_THAT(resp, RespArray(ElementsAre("b1", "1"))); + + // Count > 1. + resp = Run({"zadd", "c", "1", "c1", "2", "c2"}); + EXPECT_THAT(resp, IntArg(2)); + + resp = Run({"zmpop", "1", "c", "MAX", "COUNT", "2"}); + EXPECT_THAT(resp, ContainsLabeledScoredArray("c", {{"c1", "1"}, {"c2", "2"}})); + + resp = Run({"zcard", "c"}); + EXPECT_THAT(resp, IntArg(0)); + + // Count > #elements in set. + resp = Run({"zadd", "d", "1", "d1", "2", "d2"}); + EXPECT_THAT(resp, IntArg(2)); + + resp = Run({"zmpop", "1", "d", "MAX", "COUNT", "3"}); + EXPECT_THAT(resp, ContainsLabeledScoredArray("d", {{"d1", "1"}, {"d2", "2"}})); + + resp = Run({"zcard", "d"}); + EXPECT_THAT(resp, IntArg(0)); + + // First non empty set is not the first set. + resp = Run({"zadd", "x", "1", "x1"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"zadd", "y", "1", "y1"}); + EXPECT_THAT(resp, IntArg(1)); + + resp = Run({"zmpop", "3", "empty", "x", "y", "MAX"}); + EXPECT_THAT(resp, ContainsLabeledScoredArray("x", {{"x1", "1"}})); + + resp = Run({"zcard", "x"}); + EXPECT_THAT(resp, IntArg(0)); + + resp = Run({"ZRANGE", "y", "0", "-1", "WITHSCORES"}); + EXPECT_THAT(resp, RespArray(ElementsAre("y1", "1"))); +} + TEST_F(ZSetFamilyTest, ZPopMin) { auto resp = Run({"zadd", "key", "1", "a", "2", "b", "3", "c", "4", "d", "5", "e", "6", "f"}); EXPECT_THAT(resp, IntArg(6));