From c50e8322263538a139fe67e12c7d94b5c8585bb8 Mon Sep 17 00:00:00 2001 From: Stepan Bagritsevich Date: Thu, 5 Sep 2024 15:51:54 +0200 Subject: [PATCH] fix(search_family): Fix FT.AGGREGATE GROUPBY option fixes dragonflydb#3492 Signed-off-by: Stepan Bagritsevich --- src/server/search/aggregator.h | 2 +- src/server/search/doc_accessors.cc | 41 +++++--- src/server/search/doc_accessors.h | 19 ++-- src/server/search/doc_index.cc | 17 ++- src/server/search/doc_index.h | 54 ++++++++-- src/server/search/search_family.cc | 80 +++++++++----- src/server/search/search_family_test.cc | 134 ++++++++++++++++++++++-- 7 files changed, 267 insertions(+), 80 deletions(-) diff --git a/src/server/search/aggregator.h b/src/server/search/aggregator.h index c4d4a3d45569..727c0ba96ed0 100644 --- a/src/server/search/aggregator.h +++ b/src/server/search/aggregator.h @@ -75,7 +75,7 @@ Reducer::Func FindReducerFunc(ReducerFunc name); PipelineStep MakeGroupStep(absl::Span fields, std::vector reducers); -// Make `SORYBY field [DESC]` step +// Make `SORTBY field [DESC]` step PipelineStep MakeSortStep(std::string_view field, bool descending = false); // Make `LIMIT offset num` step diff --git a/src/server/search/doc_accessors.cc b/src/server/search/doc_accessors.cc index 06e90bc14252..41e62da3d167 100644 --- a/src/server/search/doc_accessors.cc +++ b/src/server/search/doc_accessors.cc @@ -38,7 +38,14 @@ string_view SdsToSafeSv(sds str) { return str != nullptr ? string_view{str, sdslen(str)} : ""sv; } -string PrintField(search::SchemaField::FieldType type, string_view value) { +search::SortableValue FieldToSortableValue(search::SchemaField::FieldType type, string_view value) { + if (type == search::SchemaField::NUMERIC) { + double value_as_double = 0; + if (!absl::SimpleAtod(value, &value_as_double)) { // temporary convert to double + VLOG(2) << "Failed to convert " << value << " to double"; + } + return value_as_double; + } if (type == search::SchemaField::VECTOR) { auto [ptr, size] = search::BytesToFtVector(value); return absl::StrCat("[", absl::StrJoin(absl::Span{ptr.get(), size}, ","), "]"); @@ -46,23 +53,29 @@ string PrintField(search::SchemaField::FieldType type, string_view value) { return string{value}; } -string ExtractValue(const search::Schema& schema, string_view key, string_view value) { +search::SortableValue ExtractSortableValue(const search::Schema& schema, string_view key, + string_view value) { auto it = schema.fields.find(key); if (it == schema.fields.end()) - return string{value}; - - return PrintField(it->second.type, value); + return FieldToSortableValue(search::SchemaField::TEXT, value); + return FieldToSortableValue(it->second.type, value); } } // namespace SearchDocData BaseAccessor::Serialize(const search::Schema& schema, - const SearchParams::FieldReturnList& fields) const { + const SelectedFields& fields) const { + if (fields.ShouldReturnAllFields()) { + return Serialize(schema); + } + return Serialize(schema, fields.GetFields()); +} + +SearchDocData BaseAccessor::Serialize(const search::Schema& schema, + const FieldsList& fields) const { SearchDocData out{}; for (const auto& [fident, fname] : fields) { - auto it = schema.fields.find(fident); - auto type = it != schema.fields.end() ? it->second.type : search::SchemaField::TEXT; - out[fname] = PrintField(type, absl::StrJoin(GetStrings(fident), ",")); + out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ",")); } return out; } @@ -89,7 +102,7 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const { string_view v = container_utils::LpGetView(fptr, intbuf_[1].data()); fptr = lpNext(lp_, fptr); - out[k] = ExtractValue(schema, k, v); + out[k] = ExtractSortableValue(schema, k, v); } return out; @@ -108,7 +121,7 @@ BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const { SearchDocData out{}; for (const auto& [kptr, vptr] : *hset_) - out[SdsToSafeSv(kptr)] = ExtractValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr)); + out[SdsToSafeSv(kptr)] = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr)); return out; } @@ -223,16 +236,16 @@ JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) c } SearchDocData JsonAccessor::Serialize(const search::Schema& schema) const { - return {{"$", json_.to_string()}}; + return {{"$", json_.to_string()}}; // todo: doubles } SearchDocData JsonAccessor::Serialize(const search::Schema& schema, - const SearchParams::FieldReturnList& fields) const { + const FieldsList& fields) const { SearchDocData out{}; for (const auto& [ident, name] : fields) { if (auto* path = GetPath(ident); path) { if (auto res = path->Evaluate(json_); !res.empty()) - out[name] = res[0].to_string(); + out[name] = res[0].to_string(); // todo: doubles } } return out; diff --git a/src/server/search/doc_accessors.h b/src/server/search/doc_accessors.h index 1d78e512672e..19b9fa0d15b3 100644 --- a/src/server/search/doc_accessors.h +++ b/src/server/search/doc_accessors.h @@ -24,12 +24,14 @@ class StringMap; // behind a document interface for quering fields and serializing. // Field string_view's are only valid until the next is requested. struct BaseAccessor : public search::DocumentAccessor { + SearchDocData Serialize(const search::Schema& schema, const SelectedFields& fields) const; + + private: // Convert the full underlying type to a map to be sent as a reply virtual SearchDocData Serialize(const search::Schema& schema) const = 0; // Serialize selected fields - virtual SearchDocData Serialize(const search::Schema& schema, - const SearchParams::FieldReturnList& fields) const; + virtual SearchDocData Serialize(const search::Schema& schema, const FieldsList& fields) const; }; // Accessor for hashes stored with listpack @@ -41,9 +43,10 @@ struct ListPackAccessor : public BaseAccessor { StringList GetStrings(std::string_view field) const override; VectorInfo GetVector(std::string_view field) const override; - SearchDocData Serialize(const search::Schema& schema) const override; private: + SearchDocData Serialize(const search::Schema& schema) const override; + mutable std::array intbuf_[2]; LpPtr lp_; }; @@ -55,9 +58,10 @@ struct StringMapAccessor : public BaseAccessor { StringList GetStrings(std::string_view field) const override; VectorInfo GetVector(std::string_view field) const override; - SearchDocData Serialize(const search::Schema& schema) const override; private: + SearchDocData Serialize(const search::Schema& schema) const override; + StringMap* hset_; }; @@ -72,13 +76,12 @@ struct JsonAccessor : public BaseAccessor { VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; - // The JsonAccessor works with structured types and not plain strings, so an overload is needed - SearchDocData Serialize(const search::Schema& schema, - const SearchParams::FieldReturnList& fields) const override; - static void RemoveFieldFromCache(std::string_view field); private: + // The JsonAccessor works with structured types and not plain strings, so an overload is needed + SearchDocData Serialize(const search::Schema& schema, const FieldsList& fields) const override; + /// Parses `field` into a JSON path. Caches the results internally. JsonPathContainer* GetPath(std::string_view field) const; diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 98e27a102654..ed4bac86995f 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -61,7 +61,7 @@ bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const { bool SearchParams::ShouldReturnField(std::string_view field) const { auto cb = [field](const auto& entry) { return entry.first == field; }; - return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb); + return !return_fields.fields || any_of(return_fields->begin(), return_fields->end(), cb); } string_view SearchFieldTypeToString(search::SchemaField::FieldType type) { @@ -228,8 +228,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa } auto accessor = GetAccessor(op_args.db_cntx, (*it)->second); - auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields) - : accessor->Serialize(base_->schema); + auto doc_data = accessor->Serialize(base_->schema, params.return_fields); auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]); out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)}); @@ -239,19 +238,15 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa std::move(search_results.profile)}; } -vector> ShardDocIndex::SearchForAggregator( - const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const { +vector ShardDocIndex::SearchForAggregator( + const OpArgs& op_args, const AggregateParams& params, + search::SearchAlgorithm* search_algo) const { auto& db_slice = op_args.GetDbSlice(); auto search_results = search_algo->Search(&indices_); if (!search_results.error.empty()) return {}; - // Convert load_fields into return_list required by accessor interface - SearchParams::FieldReturnList return_fields; - for (string_view load_field : load_fields) - return_fields.emplace_back(indices_.GetSchema().LookupAlias(load_field), load_field); - vector> out; for (DocId doc : search_results.ids) { auto key = key_index_.Get(doc); @@ -262,7 +257,7 @@ vector> ShardDocIndex::Search auto accessor = GetAccessor(op_args.db_cntx, (*it)->second); auto extracted = indices_.ExtractStoredValues(doc); - auto loaded = accessor->Serialize(base_->schema, return_fields); + auto loaded = accessor->Serialize(base_->schema, params.load_fields); out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end())); out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end())); diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index edcd928f1825..6c96ca66d321 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -16,11 +16,12 @@ #include "core/mi_memory_resource.h" #include "core/search/search.h" #include "server/common.h" +#include "server/search/aggregator.h" #include "server/table.h" namespace dfly { -using SearchDocData = absl::flat_hash_map; +using SearchDocData = absl::flat_hash_map; std::string_view SearchFieldTypeToString(search::SchemaField::FieldType); @@ -51,27 +52,63 @@ struct SearchResult { std::optional error; }; -struct SearchParams { - using FieldReturnList = - std::vector>; +using FieldsList = std::vector>; + +struct SelectedFields { + /* + 1. If not set -> return all fields + 2. If set but empty -> no fields should be returned + 3. If set and not empty -> return only these fields + */ + std::optional fields; + bool ShouldReturnAllFields() const { + return !fields.has_value(); + } + + bool ShouldReturnNoFields() const { + return fields && fields->empty(); + } + + FieldsList* operator->() { + return &fields.value(); + } + + const FieldsList* operator->() const { + return &fields.value(); + } + + const FieldsList& GetFields() const { + return fields.value(); + } +}; + +struct SearchParams { // Parameters for "LIMIT offset total": select total amount documents with a specific offset from // the whole result set size_t limit_offset = 0; size_t limit_total = 10; // Set but empty means no fields should be returned - std::optional return_fields; + SelectedFields return_fields; std::optional sort_option; search::QueryParams query_params; bool IdsOnly() const { - return return_fields && return_fields->empty(); + return return_fields.ShouldReturnNoFields(); } bool ShouldReturnField(std::string_view field) const; }; +struct AggregateParams { + std::string_view index, query; + search::QueryParams params; + + SelectedFields load_fields; + std::vector steps; +}; + // Stores basic info about a document index. struct DocIndex { enum DataType { HASH, JSON }; @@ -126,8 +163,9 @@ class ShardDocIndex { search::SearchAlgorithm* search_algo) const; // Perform search and load requested values - note params might be interpreted differently. - std::vector> SearchForAggregator( - const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const; + std::vector SearchForAggregator(const OpArgs& op_args, + const AggregateParams& params, + search::SearchAlgorithm* search_algo) const; // Return whether base index matches bool Matches(std::string_view key, unsigned obj_code) const; diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index efe8b1d9f878..718e62118afa 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -204,14 +204,14 @@ optional ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC } else if (parser.Check("RETURN")) { // RETURN {num} [{ident} AS {name}...] size_t num_fields = parser.Next(); - params.return_fields = SearchParams::FieldReturnList{}; + params.return_fields.fields.emplace(); while (params.return_fields->size() < num_fields) { string_view ident = parser.Next(); string_view alias = parser.Check("AS") ? parser.Next() : ident; params.return_fields->emplace_back(ident, alias); } } else if (parser.Check("NOCONTENT")) { // NOCONTENT - params.return_fields = SearchParams::FieldReturnList{}; + params.return_fields.fields.emplace(); } else if (parser.Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector] params.query_params = ParseQueryParams(&parser); } else if (parser.Check("SORTBY")) { @@ -230,13 +230,15 @@ optional ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC return params; } -struct AggregateParams { - string_view index, query; - search::QueryParams params; - - vector load_fields; - vector steps; -}; +std::optional ParseField(CmdArgParser* parser, bool expect_at_sign = false) { + std::string_view field = parser->Next(); + if (field.front() == '@') { + field.remove_prefix(1); // remove leading @ + } else if (expect_at_sign) { + return std::nullopt; // if we expect @, but it's not there, return nullopt + } + return field; +} optional ParseAggregatorParamsOrReply(CmdArgParser parser, ConnectionContext* cntx) { @@ -246,17 +248,28 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, while (parser.HasNext()) { // LOAD count field [field ...] if (parser.Check("LOAD")) { - params.load_fields.resize(parser.Next()); - for (string_view& field : params.load_fields) - field = parser.Next(); + size_t num_fields = parser.Next(); + params.load_fields.fields.emplace(); + while (params.load_fields->size() < num_fields) { + string_view field = ParseField(&parser).value(); + string_view alias = parser.Check("AS") ? parser.Next() : field; + params.load_fields->emplace_back(field, alias); + } continue; } // GROUPBY nargs property [property ...] if (parser.Check("GROUPBY")) { vector fields(parser.Next()); - for (string_view& field : fields) - field = parser.Next(); + for (string_view& field : fields) { + auto parsed_field = ParseField(&parser, true); + if (!parsed_field) { + cntx->SendError(absl::StrCat("bad arguments for GROUPBY: Unknown property '", field, + "'. Did you mean '@", field, "`?")); + return nullopt; + } + field = parsed_field.value(); + } vector reducers; while (parser.Check("REDUCE")) { @@ -273,7 +286,10 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, auto func = aggregate::FindReducerFunc(*func_name); auto nargs = parser.Next(); - string source_field = nargs > 0 ? parser.Next() : ""; + string source_field; + if (nargs > 0) { + source_field = ParseField(&parser).value(); + } parser.ExpectTag("AS"); string result_field = parser.Next(); @@ -321,13 +337,23 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, return params; } +auto SortableValueSender(RedisReplyBuilder* rb) { + return Overloaded{ + [rb](monostate) { rb->SendNull(); }, + [rb](double d) { rb->SendDouble(d); }, + [rb](const string& s) { rb->SendBulkString(s); }, + }; +} + void SendSerializedDoc(const SerializedSearchDoc& doc, ConnectionContext* cntx) { auto* rb = static_cast(cntx->reply_builder()); + auto sortable_value_sender = SortableValueSender(rb); + rb->SendBulkString(doc.key); rb->StartCollection(doc.values.size(), RedisReplyBuilder::MAP); for (const auto& [k, v] : doc.values) { rb->SendBulkString(k); - rb->SendBulkString(v); + visit(sortable_value_sender, v); } } @@ -804,14 +830,14 @@ void SearchFamily::FtAggregate(CmdArgList args, ConnectionContext* cntx) { if (!search_algo.Init(params->query, ¶ms->params, nullptr)) return cntx->SendError("Query syntax error"); - using ResultContainer = - decltype(declval().SearchForAggregator(declval(), {}, &search_algo)); + using ResultContainer = decltype(declval().SearchForAggregator( + declval(), params.value(), &search_algo)); vector query_results(shard_set->size()); cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { if (auto* index = es->search_indices()->GetIndex(params->index); index) { query_results[es->shard_id()] = - index->SearchForAggregator(t->GetOpArgs(es), params->load_fields, &search_algo); + index->SearchForAggregator(t->GetOpArgs(es), params.value(), &search_algo); } return OpStatus::OK; }); @@ -826,20 +852,18 @@ void SearchFamily::FtAggregate(CmdArgList args, ConnectionContext* cntx) { if (!agg_results.has_value()) return cntx->SendError(agg_results.error()); + size_t result_size = agg_results->size(); auto* rb = static_cast(cntx->reply_builder()); - Overloaded replier{ - [rb](monostate) { rb->SendNull(); }, - [rb](double d) { rb->SendDouble(d); }, - [rb](const string& s) { rb->SendBulkString(s); }, - }; + auto sortable_value_sender = SortableValueSender(rb); + + rb->StartArray(result_size + 1); + rb->SendLong(result_size); - rb->StartArray(agg_results->size()); for (const auto& result : agg_results.value()) { - rb->StartArray(result.size()); + rb->StartArray(result.size() * 2); for (const auto& [k, v] : result) { - rb->StartArray(2); rb->SendBulkString(k); - visit(replier, v); + std::visit(sortable_value_sender, v); } } } diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 6ca38c5c6c78..d2015d51b483 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -70,6 +70,73 @@ template auto IsUnordArray(Args... args) { return RespArray(UnorderedElementsAre(std::forward(args)...)); } +MATCHER_P(IsMapMatcher, expected, "") { + if (arg.type != RespExpr::ARRAY) { + *result_listener << "Wrong response type: " << arg.type; + return false; + } + + auto result = arg.GetVec(); + if (result.size() != expected.size()) { + *result_listener << "Wrong resp array size: " << result.size(); + return false; + } + + using KeyValueArray = std::vector>; + + KeyValueArray received_pairs; + for (size_t i = 0; i < result.size(); i += 2) { + received_pairs.emplace_back(result[i].GetString(), result[i + 1].GetString()); + } + + KeyValueArray expected_pairs; + for (size_t i = 0; i < expected.size(); i += 2) { + expected_pairs.emplace_back(expected[i], expected[i + 1]); + } + + // Custom unordered comparison + std::sort(received_pairs.begin(), received_pairs.end()); + std::sort(expected_pairs.begin(), expected_pairs.end()); + + return received_pairs == expected_pairs; +} + +template auto IsMap(Matchers... matchers) { + return IsMapMatcher(std::vector{std::forward(matchers)...}); +} + +MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") { + if (arg.type != RespExpr::ARRAY) { + *result_listener << "Wrong response type: " << arg.type; + return false; + } + + auto result = arg.GetVec(); + size_t expected_size = std::tuple_size::value; + if (result.size() != expected_size + 1) { + *result_listener << "Wrong resp array size: " << result.size(); + return false; + } + + if (result[0].GetInt() != expected_size) { + *result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1); + return false; + } + + std::vector received_elements(result.begin() + 1, result.end()); + + // Create a vector of matchers from the tuple + std::vector> matchers; + std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected); + + return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements, + result_listener); +} + +template auto IsUnordArrayWithSize(Matchers... matchers) { + return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...)); +} + TEST_F(SearchFamilyTest, CreateDropListIndex) { EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK"); EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK"); @@ -707,6 +774,56 @@ TEST_F(SearchFamilyTest, SimpleExpiry) { Run({"flushall"}); } +TEST_F(SearchFamilyTest, AggregateGroupBy) { + Run({"hset", "key:1", "word", "item1", "foo", "10", "text", "\"first key\"", "non_indexed_value", + "1"}); + Run({"hset", "key:2", "word", "item2", "foo", "20", "text", "\"second key\"", "non_indexed_value", + "2"}); + Run({"hset", "key:3", "word", "item1", "foo", "40", "text", "\"third key\"", "non_indexed_value", + "3"}); + + auto resp = Run( + {"ft.create", "i1", "ON", "HASH", "SCHEMA", "word", "TAG", "foo", "NUMERIC", "text", "TEXT"}); + EXPECT_EQ(resp, "OK"); + + resp = Run( + {"ft.aggregate", "i1", "*", "GROUPBY", "1", "@word", "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("count", "2", "word", "item1"), + IsMap("word", "item2", "count", "1"))); + + resp = Run({"ft.aggregate", "i1", "*", "GROUPBY", "1", "@word", "REDUCE", "SUM", "1", "@foo", + "AS", "foo_total"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "50", "word", "item1"), + IsMap("foo_total", "20", "word", "item2"))); + + resp = Run({"ft.aggregate", "i1", "*", "GROUPBY", "1", "@word", "REDUCE", "AVG", "1", "@foo", + "AS", "foo_average"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_average", "20", "word", "item2"), + IsMap("foo_average", "25", "word", "item1"))); + + resp = Run({"ft.aggregate", "i1", "*", "GROUPBY", "2", "@word", "@text", "REDUCE", "SUM", "1", + "@foo", "AS", "foo_total"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("foo_total", "10", "word", "item1", "text", "\"first key\""), + IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""), + IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""))); + + resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "word", "GROUPBY", "1", "@word", + "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"), + IsMap("foo_total", "50", "word", "item1"))); + + /* + Temporary not supported + + resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word", + "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp, + IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second + key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""), + IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"}))); + */ +} + TEST_F(SearchFamilyTest, AggregateGroupByReduceSort) { for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value", @@ -716,7 +833,7 @@ TEST_F(SearchFamilyTest, AggregateGroupByReduceSort) { // clang-format off auto resp = Run({"ft.aggregate", "i1", "*", - "GROUPBY", "1", "even", + "GROUPBY", "1", "@even", "REDUCE", "count", "0", "as", "count", "REDUCE", "count_distinct", "1", "even", "as", "distinct_tags", "REDUCE", "count_distinct", "1", "value", "as", "distinct_vals", @@ -726,12 +843,10 @@ TEST_F(SearchFamilyTest, AggregateGroupByReduceSort) { // clang-format on EXPECT_THAT(resp, - IsArray(IsUnordArray(IsArray("even", "false"), IsArray("count", "50"), - IsArray("distinct_tags", "1"), IsArray("distinct_vals", "50"), - IsArray("max_val", "99"), IsArray("min_val", "1")), - IsUnordArray(IsArray("even", "true"), IsArray("count", "51"), - IsArray("distinct_tags", "1"), IsArray("distinct_vals", "51"), - IsArray("max_val", "100"), IsArray("min_val", "0")))); + IsUnordArrayWithSize(IsMap("even", "false", "count", "50", "distinct_tags", "1", + "distinct_vals", "50", "max_val", "99", "min_val", "1"), + IsMap("even", "true", "count", "51", "distinct_tags", "1", + "distinct_vals", "51", "max_val", "100", "min_val", "0"))); } TEST_F(SearchFamilyTest, AggregateLoadGroupBy) { @@ -744,11 +859,10 @@ TEST_F(SearchFamilyTest, AggregateLoadGroupBy) { // clang-format off auto resp = Run({"ft.aggregate", "i1", "*", "LOAD", "1", "even", - "GROUPBY", "1", "even"}); + "GROUPBY", "1", "@even"}); // clang-format on - EXPECT_THAT(resp, IsUnordArray(IsUnordArray(IsArray("even", "false")), - IsUnordArray(IsArray("even", "true")))); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("even", "false"), IsMap("even", "true"))); } TEST_F(SearchFamilyTest, Vector) {