Skip to content

Commit

Permalink
fix(search_family): Fix FT.AGGREGATE GROUPBY option
Browse files Browse the repository at this point in the history
fixes #3492

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Sep 5, 2024
1 parent a1e9ee1 commit c50e832
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 80 deletions.
2 changes: 1 addition & 1 deletion src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Reducer::Func FindReducerFunc(ReducerFunc name);
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);

// Make `SORYBY field [DESC]` step
// Make `SORTBY field [DESC]` step
PipelineStep MakeSortStep(std::string_view field, bool descending = false);

// Make `LIMIT offset num` step
Expand Down
41 changes: 27 additions & 14 deletions src/server/search/doc_accessors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,44 @@ string_view SdsToSafeSv(sds str) {
return str != nullptr ? string_view{str, sdslen(str)} : ""sv;
}

string PrintField(search::SchemaField::FieldType type, string_view value) {
// Converts the raw string `value` of a field into a search::SortableValue according to `type`:
//  - NUMERIC: parsed as a double; a parse failure is logged at VLOG(2) and yields 0.
//  - VECTOR:  decoded with search::BytesToFtVector and rendered as "[f1,f2,...]".
//  - any other type: the bytes are copied into a string unchanged.
search::SortableValue FieldToSortableValue(search::SchemaField::FieldType type, string_view value) {
  if (type == search::SchemaField::NUMERIC) {
    double value_as_double = 0;
    if (!absl::SimpleAtod(value, &value_as_double)) {  // temporary convert to double
      VLOG(2) << "Failed to convert " << value << " to double";
      // SimpleAtod leaves its output in an unspecified state on failure (e.g. a partially parsed
      // prefix), so reset it to make the fallback value deterministic.
      value_as_double = 0;
    }
    return value_as_double;
  }
  if (type == search::SchemaField::VECTOR) {
    auto [ptr, size] = search::BytesToFtVector(value);
    return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
  }
  return string{value};
}

string ExtractValue(const search::Schema& schema, string_view key, string_view value) {
search::SortableValue ExtractSortableValue(const search::Schema& schema, string_view key,
string_view value) {
auto it = schema.fields.find(key);
if (it == schema.fields.end())
return string{value};

return PrintField(it->second.type, value);
return FieldToSortableValue(search::SchemaField::TEXT, value);
return FieldToSortableValue(it->second.type, value);
}

} // namespace

SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
const SearchParams::FieldReturnList& fields) const {
const SelectedFields& fields) const {
if (fields.ShouldReturnAllFields()) {
return Serialize(schema);
}
return Serialize(schema, fields.GetFields());
}

SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
const FieldsList& fields) const {
SearchDocData out{};
for (const auto& [fident, fname] : fields) {
auto it = schema.fields.find(fident);
auto type = it != schema.fields.end() ? it->second.type : search::SchemaField::TEXT;
out[fname] = PrintField(type, absl::StrJoin(GetStrings(fident), ","));
out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ","));
}
return out;
}
Expand All @@ -89,7 +102,7 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
string_view v = container_utils::LpGetView(fptr, intbuf_[1].data());
fptr = lpNext(lp_, fptr);

out[k] = ExtractValue(schema, k, v);
out[k] = ExtractSortableValue(schema, k, v);
}

return out;
Expand All @@ -108,7 +121,7 @@ BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field)
SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
SearchDocData out{};
for (const auto& [kptr, vptr] : *hset_)
out[SdsToSafeSv(kptr)] = ExtractValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));
out[SdsToSafeSv(kptr)] = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));

return out;
}
Expand Down Expand Up @@ -223,16 +236,16 @@ JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) c
}

SearchDocData JsonAccessor::Serialize(const search::Schema& schema) const {
return {{"$", json_.to_string()}};
return {{"$", json_.to_string()}}; // todo: doubles
}

SearchDocData JsonAccessor::Serialize(const search::Schema& schema,
const SearchParams::FieldReturnList& fields) const {
const FieldsList& fields) const {
SearchDocData out{};
for (const auto& [ident, name] : fields) {
if (auto* path = GetPath(ident); path) {
if (auto res = path->Evaluate(json_); !res.empty())
out[name] = res[0].to_string();
out[name] = res[0].to_string(); // todo: doubles
}
}
return out;
Expand Down
19 changes: 11 additions & 8 deletions src/server/search/doc_accessors.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class StringMap;
// behind a document interface for querying fields and serializing.
// Field string_view's are only valid until the next is requested.
struct BaseAccessor : public search::DocumentAccessor {
SearchDocData Serialize(const search::Schema& schema, const SelectedFields& fields) const;

private:
// Convert the full underlying type to a SearchDocData map (field -> SortableValue) to be sent as a reply
virtual SearchDocData Serialize(const search::Schema& schema) const = 0;

// Serialize selected fields
virtual SearchDocData Serialize(const search::Schema& schema,
const SearchParams::FieldReturnList& fields) const;
virtual SearchDocData Serialize(const search::Schema& schema, const FieldsList& fields) const;
};

// Accessor for hashes stored with listpack
Expand All @@ -41,9 +43,10 @@ struct ListPackAccessor : public BaseAccessor {

StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

private:
SearchDocData Serialize(const search::Schema& schema) const override;

mutable std::array<uint8_t, 33> intbuf_[2];
LpPtr lp_;
};
Expand All @@ -55,9 +58,10 @@ struct StringMapAccessor : public BaseAccessor {

StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

private:
SearchDocData Serialize(const search::Schema& schema) const override;

StringMap* hset_;
};

Expand All @@ -72,13 +76,12 @@ struct JsonAccessor : public BaseAccessor {
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

// The JsonAccessor works with structured types and not plain strings, so an overload is needed
SearchDocData Serialize(const search::Schema& schema,
const SearchParams::FieldReturnList& fields) const override;

static void RemoveFieldFromCache(std::string_view field);

private:
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
SearchDocData Serialize(const search::Schema& schema, const FieldsList& fields) const override;

/// Parses `field` into a JSON path. Caches the results internally.
JsonPathContainer* GetPath(std::string_view field) const;

Expand Down
17 changes: 6 additions & 11 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {

bool SearchParams::ShouldReturnField(std::string_view field) const {
auto cb = [field](const auto& entry) { return entry.first == field; };
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
return !return_fields.fields || any_of(return_fields->begin(), return_fields->end(), cb);
}

string_view SearchFieldTypeToString(search::SchemaField::FieldType type) {
Expand Down Expand Up @@ -228,8 +228,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
}

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
: accessor->Serialize(base_->schema);
auto doc_data = accessor->Serialize(base_->schema, params.return_fields);

auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]);
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)});
Expand All @@ -239,19 +238,15 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
std::move(search_results.profile)};
}

vector<absl::flat_hash_map<string, search::SortableValue>> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const {
vector<SearchDocData> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, const AggregateParams& params,
search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.GetDbSlice();
auto search_results = search_algo->Search(&indices_);

if (!search_results.error.empty())
return {};

// Convert load_fields into return_list required by accessor interface
SearchParams::FieldReturnList return_fields;
for (string_view load_field : load_fields)
return_fields.emplace_back(indices_.GetSchema().LookupAlias(load_field), load_field);

vector<absl::flat_hash_map<string, search::SortableValue>> out;
for (DocId doc : search_results.ids) {
auto key = key_index_.Get(doc);
Expand All @@ -262,7 +257,7 @@ vector<absl::flat_hash_map<string, search::SortableValue>> ShardDocIndex::Search

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_.ExtractStoredValues(doc);
auto loaded = accessor->Serialize(base_->schema, return_fields);
auto loaded = accessor->Serialize(base_->schema, params.load_fields);

out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end()));
out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end()));
Expand Down
54 changes: 46 additions & 8 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
#include "core/mi_memory_resource.h"
#include "core/search/search.h"
#include "server/common.h"
#include "server/search/aggregator.h"
#include "server/table.h"

namespace dfly {

using SearchDocData = absl::flat_hash_map<std::string /*field*/, std::string /*value*/>;
using SearchDocData = absl::flat_hash_map<std::string /*field*/, search::SortableValue /*value*/>;

std::string_view SearchFieldTypeToString(search::SchemaField::FieldType);

Expand Down Expand Up @@ -51,27 +52,63 @@ struct SearchResult {
std::optional<facade::ErrorReply> error;
};

struct SearchParams {
using FieldReturnList =
std::vector<std::pair<std::string /*identifier*/, std::string /*short name*/>>;
// List of (identifier, short name) pairs describing the fields picked for a reply.
using FieldsList = std::vector<std::pair<std::string /*identifier*/, std::string /*short name*/>>;

// Optional field selection with three distinct states:
//   - `fields` is unset           -> every field is returned
//   - `fields` is set but empty   -> no field is returned
//   - `fields` is set, non-empty  -> only the listed fields are returned
struct SelectedFields {
  std::optional<FieldsList> fields;

  // True when no explicit selection was made.
  bool ShouldReturnAllFields() const {
    return fields == std::nullopt;
  }

  // True when a selection was made and it is empty.
  bool ShouldReturnNoFields() const {
    return fields.has_value() ? fields->empty() : false;
  }

  // Arrow access to the underlying list. Goes through optional::value(), so it must only be
  // used when a selection is set.
  FieldsList* operator->() {
    FieldsList& list = fields.value();
    return &list;
  }

  const FieldsList* operator->() const {
    const FieldsList& list = fields.value();
    return &list;
  }

  // Returns the selected list; like operator->, requires that a selection is set.
  const FieldsList& GetFields() const {
    return fields.value();
  }
};

struct SearchParams {
// Parameters for "LIMIT offset total": select `total` documents, starting at `offset`, from
// the whole result set
size_t limit_offset = 0;
size_t limit_total = 10;

// Set but empty means no fields should be returned
std::optional<FieldReturnList> return_fields;
SelectedFields return_fields;
std::optional<search::SortOption> sort_option;
search::QueryParams query_params;

bool IdsOnly() const {
return return_fields && return_fields->empty();
return return_fields.ShouldReturnNoFields();
}

bool ShouldReturnField(std::string_view field) const;
};

// Parameter bundle accepted by ShardDocIndex::SearchForAggregator.
struct AggregateParams {
  // NOTE(review): presumably the index name and the query text. Both are non-owning views, so
  // the referenced buffers have to outlive this struct — confirm against the caller.
  std::string_view index;
  std::string_view query;
  search::QueryParams params;

  // Field selection handed to the document accessor's Serialize() for each matched document.
  SelectedFields load_fields;
  // Aggregation pipeline steps — presumably applied in order; confirm at the call site.
  std::vector<aggregate::PipelineStep> steps;
};

// Stores basic info about a document index.
struct DocIndex {
enum DataType { HASH, JSON };
Expand Down Expand Up @@ -126,8 +163,9 @@ class ShardDocIndex {
search::SearchAlgorithm* search_algo) const;

// Perform search and load requested values - note params might be interpreted differently.
std::vector<absl::flat_hash_map<std::string, search::SortableValue>> SearchForAggregator(
const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const;
std::vector<SearchDocData> SearchForAggregator(const OpArgs& op_args,
const AggregateParams& params,
search::SearchAlgorithm* search_algo) const;

// Return whether base index matches
bool Matches(std::string_view key, unsigned obj_code) const;
Expand Down
Loading

0 comments on commit c50e832

Please sign in to comment.