Skip to content

Commit

Permalink
feat:Adding support for ZMPOP command (#4385)
Browse files Browse the repository at this point in the history
Signed-off-by: Guy Flysher <[email protected]>
  • Loading branch information
guyzilla authored Jan 1, 2025
1 parent c5b3584 commit 413ec0a
Show file tree
Hide file tree
Showing 8 changed files with 340 additions and 24 deletions.
18 changes: 16 additions & 2 deletions src/facade/reply_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,7 @@ void RedisReplyBuilder::SendBulkStrArr(const facade::ArgRange& strs, CollectionT
SendBulkString(str);
}

void RedisReplyBuilder::SendScoredArray(absl::Span<const std::pair<std::string, double>> arr,
bool with_scores) {
void RedisReplyBuilder::SendScoredArray(ScoredArray arr, bool with_scores) {
ReplyScope scope(this);
StartArray((with_scores && !IsResp3()) ? arr.size() * 2 : arr.size());
for (const auto& [str, score] : arr) {
Expand All @@ -421,6 +420,21 @@ void RedisReplyBuilder::SendScoredArray(absl::Span<const std::pair<std::string,
}
}

void RedisReplyBuilder::SendLabeledScoredArray(std::string_view arr_label, ScoredArray arr) {
ReplyScope scope(this);

StartArray(2);

SendBulkString(arr_label);
StartArray(arr.size());
for (const auto& [str, score] : arr) {
StartArray(2);
SendBulkString(str);
SendDouble(score);
}

}

void RedisReplyBuilder::SendStored() {
SendSimpleString("OK");
}
Expand Down
5 changes: 3 additions & 2 deletions src/facade/reply_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ class RedisReplyBuilderBase : public SinkReplyBuilder {
class RedisReplyBuilder : public RedisReplyBuilderBase {
public:
using RedisReplyBuilderBase::CollectionType;
using ScoredArray = absl::Span<const std::pair<std::string, double>>;

RedisReplyBuilder(io::Sink* sink) : RedisReplyBuilderBase(sink) {
}
Expand All @@ -281,8 +282,8 @@ class RedisReplyBuilder : public RedisReplyBuilderBase {

void SendSimpleStrArr(const facade::ArgRange& strs);
void SendBulkStrArr(const facade::ArgRange& strs, CollectionType ct = ARRAY);
void SendScoredArray(absl::Span<const std::pair<std::string, double>> arr, bool with_scores);

void SendScoredArray(ScoredArray arr, bool with_scores);
void SendLabeledScoredArray(std::string_view arr_label, ScoredArray arr);
void SendStored() final;
void SendSetSkipped() final;

Expand Down
21 changes: 21 additions & 0 deletions src/facade/reply_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,27 @@ TEST_F(RedisReplyBuilderTest, SendScoredArray) {
<< "Resp3 WITHSCORES failed.";
}

TEST_F(RedisReplyBuilderTest, SendLabeledScoredArray) {
const std::vector<std::pair<std::string, double>> scored_array{
{"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}};

builder_->SetResp3(false);
builder_->SendLabeledScoredArray("foobar", scored_array);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(),
"*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n$3\r\n1.1\r\n*2\r\n$2\r\ne2\r\n$3\r\n2."
"2\r\n*2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n")
<< "Resp3 failed.\n";

builder_->SetResp3(true);
builder_->SendLabeledScoredArray("foobar", scored_array);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(),
"*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n,1.1\r\n*2\r\n$2\r\ne2\r\n,2.2\r\n*"
"2\r\n$2\r\ne3\r\n,3.3\r\n")
<< "Resp3 failed.";
}

TEST_F(RedisReplyBuilderTest, BasicCapture) {
GTEST_SKIP() << "Unmark when CaptuingReplyBuilder is updated";

Expand Down
2 changes: 2 additions & 0 deletions src/server/command_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ const char* OptName(CO::CommandOpt fl) {
return "no-key-tx-span-all";
case IDEMPOTENT:
return "idempotent";
case SLOW:
return "slow";
}
return "unknown";
}
Expand Down
1 change: 1 addition & 0 deletions src/server/command_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ enum CommandOpt : uint32_t {
// The same callback can be run multiple times without corrupting the result. Used for
// opportunistic optimizations where inconsistencies can only be detected afterwards.
IDEMPOTENT = 1U << 18,
SLOW = 1U << 19 // Unused?
};

const char* OptName(CommandOpt fl);
Expand Down
188 changes: 168 additions & 20 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ struct GeoPoint {
double dist;
double score;
std::string member;
GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0){};
GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0) {};
GeoPoint(double _longitude, double _latitude, double _dist, double _score,
const std::string& _member)
: longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member){};
: longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member) {};
};
using GeoArray = std::vector<GeoPoint>;

Expand Down Expand Up @@ -179,16 +179,15 @@ struct ZParams {
bool override = false;
};

void OutputScoredArrayResult(const OpResult<ScoredArray>& result,
const ZSetFamily::RangeParams& params, SinkReplyBuilder* builder) {
void OutputScoredArrayResult(const OpResult<ScoredArray>& result, SinkReplyBuilder* builder) {
if (result.status() == OpStatus::WRONG_TYPE) {
return builder->SendError(kWrongTypeErr);
}

LOG_IF(WARNING, !result && result.status() != OpStatus::KEY_NOTFOUND)
<< "Unexpected status " << result.status();
auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->SendScoredArray(result.value(), params.with_scores);
rb->SendScoredArray(result.value(), true /* with scores */);
}

OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs& op_args,
Expand Down Expand Up @@ -1821,31 +1820,47 @@ void ZBooleanOperation(CmdArgList args, string_view cmd, bool is_union, bool sto
}
}

void ZPopMinMax(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) {
string_view key = ArgS(args, 0);
enum class FilterShards { NO = 0, YES = 1 };

OpResult<ScoredArray> ZPopMinMaxInternal(std::string_view key, FilterShards should_filter_shards,
uint32 count, bool reverse, Transaction* tx) {
ZSetFamily::RangeParams range_params;
range_params.reverse = reverse;
range_params.with_scores = true;
ZSetFamily::ZRangeSpec range_spec;
range_spec.params = range_params;

ZSetFamily::TopNScored sc = 1;
if (args.size() > 1) {
string_view count = ArgS(args, 1);
if (!SimpleAtoi(count, &sc)) {
return builder->SendError(kUintErr);
}
}
range_spec.interval = count;

range_spec.interval = sc;
OpResult<ScoredArray> result;

std::optional<ShardId> key_shard;
if (should_filter_shards == FilterShards::YES) {
key_shard = Shard(key, shard_set->size());
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpPopCount(range_spec, t->GetOpArgs(shard), key);
if (!key_shard.has_value() || *key_shard == shard->shard_id()) {
result = std::move(OpPopCount(range_spec, t->GetOpArgs(shard), key));
}
return OpStatus::OK;
};

OpResult<ScoredArray> result = tx->ScheduleSingleHopT(std::move(cb));
OutputScoredArrayResult(result, range_params, builder);
tx->Execute(std::move(cb), true);

return result;
}

void ZPopMinMaxFromArgs(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) {
string_view key = ArgS(args, 0);
uint32 count = 1;
if (args.size() > 1) {
string_view count_str = ArgS(args, 1);
if (!SimpleAtoi(count_str, &count)) {
return builder->SendError(kUintErr);
}
}

OutputScoredArrayResult(ZPopMinMaxInternal(key, FilterShards::NO, count, reverse, tx), builder);
}

OpResult<MScoreResponse> ZGetMembers(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
Expand Down Expand Up @@ -2060,6 +2075,71 @@ void ZRemRangeGeneric(string_view key, const ZSetFamily::ZRangeSpec& range_spec,
}
}

// Returns the key of the first non empty set found in the list of shard arguments.
// Returns nullopt if none.
std::optional<std::string_view> GetFirstNonEmptyKeyFound(EngineShard* shard, Transaction* t) {
ShardArgs keys = t->GetShardArgs(shard->shard_id());
DCHECK(!keys.Empty());

auto& db_slice = t->GetDbSlice(shard->shard_id());

for (string_view key : keys) {
auto it = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_ZSET);
if (!it) {
continue;
}
return std::optional<std::string_view>(key);
}

return std::nullopt;
}

// Validates the ZMPop command arguments and extracts the values to the output params.
// If the arguments are invalid sends the appropiate error to builder and returns false.
bool ValidateZMPopCommand(CmdArgList args, uint32* num_keys, bool* is_max, int* pop_count,
SinkReplyBuilder* builder) {
CmdArgParser parser{args};

if (!SimpleAtoi(parser.Next(), num_keys)) {
builder->SendError(kUintErr);
return false;
}

if (*num_keys <= 0 || !parser.HasAtLeast(*num_keys + 1)) {
// We should have at least num_keys keys + a MIN/MAX arg.
builder->SendError(kSyntaxErr);
return false;
}
// Skip over the keys themselves.
parser.Skip(*num_keys);

// We know we have at least one more arg (we checked above).
if (parser.Check("MAX")) {
*is_max = true;
} else if (parser.Check("MIN")) {
*is_max = false;
} else {
builder->SendError(kSyntaxErr);
return false;
}

*pop_count = 1;
// Check if we have additional COUNT argument.
if (parser.HasNext()) {
if (!parser.Check("COUNT", pop_count)) {
builder->SendError(kSyntaxErr);
return false;
}
}

if (!parser.Finalize()) {
builder->SendError(parser.Error()->MakeReply());
return false;
}

return true;
}

} // namespace

void ZSetFamily::BZPopMin(CmdArgList args, const CommandContext& cmd_cntx) {
Expand Down Expand Up @@ -2355,12 +2435,77 @@ void ZSetFamily::ZInterCard(CmdArgList args, const CommandContext& cmd_cntx) {
builder->SendLong(result.value().size());
}

void ZSetFamily::ZMPop(CmdArgList args, const CommandContext& cmd_cntx) {
uint32 num_keys;
bool is_max;
int pop_count;
if (!ValidateZMPopCommand(args, &num_keys, &is_max, &pop_count, cmd_cntx.rb)) {
return;
}
auto* response_builder = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);

// From the list of input keys, keep the first (in the order of keys in the command) key found in
// the current shard.
std::vector<std::optional<std::string_view>> first_found_key_per_shard_vec(shard_set->size(),
std::nullopt);

auto cb = [&](Transaction* t, EngineShard* shard) {
std::optional<std::string_view> result = GetFirstNonEmptyKeyFound(shard, t);
if (result.has_value()) {
first_found_key_per_shard_vec[shard->shard_id()] = result;
}
return OpStatus::OK;
};

cmd_cntx.tx->Execute(std::move(cb), false /* possibly another hop */);

// Keep all the keys found (first only for each shard) in a set for fast lookups.
absl::flat_hash_set<std::string_view> first_found_keys_for_shard;
// We can have at most one result from each shard.
first_found_keys_for_shard.reserve(std::min(shard_set->size(), num_keys));
for (const auto& key : first_found_key_per_shard_vec) {
if (!key.has_value()) {
continue;
}
first_found_keys_for_shard.insert(*key);
}

// Now that we have the first non empty key from each shard, find the first overall first key and
// pop elements from it.
std::optional<std::string_view> key_to_pop = std::nullopt;
ArgRange arg_keys(args.subspan(1, num_keys));
// Find the first arg_key which exists in any shard and is not empty.
for (std::string_view key : arg_keys) {
if (first_found_keys_for_shard.contains(key)) {
key_to_pop = key;
break;
}
}

if (!key_to_pop.has_value()) {
cmd_cntx.tx->Conclude();
response_builder->SendNull();
return;
}

// Pop elements from relevant set.
OpResult<ScoredArray> pop_result =
ZPopMinMaxInternal(*key_to_pop, FilterShards::YES, pop_count, is_max, cmd_cntx.tx);

if (pop_result.status() == OpStatus::WRONG_TYPE) {
return response_builder->SendError(kWrongTypeErr);
}

LOG_IF(WARNING, !pop_result) << "Unexpected status " << pop_result.status();
response_builder->SendLabeledScoredArray(*key_to_pop, pop_result.value());
}

void ZSetFamily::ZPopMax(CmdArgList args, const CommandContext& cmd_cntx) {
ZPopMinMax(std::move(args), true, cmd_cntx.tx, cmd_cntx.rb);
ZPopMinMaxFromArgs(std::move(args), true, cmd_cntx.tx, cmd_cntx.rb);
}

void ZSetFamily::ZPopMin(CmdArgList args, const CommandContext& cmd_cntx) {
ZPopMinMax(std::move(args), false, cmd_cntx.tx, cmd_cntx.rb);
ZPopMinMaxFromArgs(std::move(args), false, cmd_cntx.tx, cmd_cntx.rb);
}

void ZSetFamily::ZLexCount(CmdArgList args, const CommandContext& cmd_cntx) {
Expand Down Expand Up @@ -3217,6 +3362,7 @@ constexpr uint32_t kZInterStore = WRITE | SORTEDSET | SLOW;
constexpr uint32_t kZInter = READ | SORTEDSET | SLOW;
constexpr uint32_t kZInterCard = WRITE | SORTEDSET | SLOW;
constexpr uint32_t kZLexCount = READ | SORTEDSET | FAST;
constexpr uint32_t kZMPop = WRITE | SORTEDSET | SLOW;
constexpr uint32_t kZPopMax = WRITE | SORTEDSET | FAST;
constexpr uint32_t kZPopMin = WRITE | SORTEDSET | FAST;
constexpr uint32_t kZRem = WRITE | SORTEDSET | FAST;
Expand Down Expand Up @@ -3267,6 +3413,8 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZINTERCARD", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, acl::kZInterCard}.HFUNC(
ZInterCard)
<< CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, acl::kZLexCount}.HFUNC(ZLexCount)
<< CI{"ZMPOP", CO::SLOW | CO::WRITE | CO::VARIADIC_KEYS, -4, 2, 2, acl::kZMPop}.HFUNC(ZMPop)

<< CI{"ZPOPMAX", CO::FAST | CO::WRITE, -2, 1, 1, acl::kZPopMax}.HFUNC(ZPopMax)
<< CI{"ZPOPMIN", CO::FAST | CO::WRITE, -2, 1, 1, acl::kZPopMin}.HFUNC(ZPopMin)
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, acl::kZRem}.HFUNC(ZRem)
Expand Down
1 change: 1 addition & 0 deletions src/server/zset_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ZSetFamily {
static void ZInter(CmdArgList args, const CommandContext& cmd_cntx);
static void ZInterCard(CmdArgList args, const CommandContext& cmd_cntx);
static void ZLexCount(CmdArgList args, const CommandContext& cmd_cntx);
static void ZMPop(CmdArgList args, const CommandContext& cmd_cntx);
static void ZPopMax(CmdArgList args, const CommandContext& cmd_cntx);
static void ZPopMin(CmdArgList args, const CommandContext& cmd_cntx);
static void ZRange(CmdArgList args, const CommandContext& cmd_cntx);
Expand Down
Loading

0 comments on commit 413ec0a

Please sign in to comment.