Skip to content

Commit

Permalink
feat(search): Prefix search for tags (#3972)
Browse files Browse the repository at this point in the history
  • Loading branch information
dranikpg authored Oct 25, 2024
1 parent 0a62f6b commit eb9b689
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 24 deletions.
8 changes: 6 additions & 2 deletions src/core/search/ast_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ AstFieldNode::AstFieldNode(string field, AstNode&& node)
: field{field.substr(1)}, node{make_unique<AstNode>(std::move(node))} {
}

AstTagsNode::AstTagsNode(std::string tag) {
AstTagsNode::AstTagsNode(TagValue tag) {
tags = {std::move(tag)};
}

AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
AstTagsNode::AstTagsNode(AstExpr&& l, TagValue tag) {
DCHECK(holds_alternative<AstTagsNode>(l));
auto& tags_node = get<AstTagsNode>(l);

Expand Down Expand Up @@ -82,4 +82,8 @@ namespace std {
// Stream inserter for optional<size_t> that intentionally writes nothing and
// hands the stream back unchanged.
// NOTE(review): presumably exists only so generated parser code can stream
// this semantic value type -- confirm. Adding overloads to namespace std is
// formally undefined behaviour; it is kept here so the call is found by ADL.
ostream& operator<<(ostream& os, optional<size_t> /*o*/) {
  return os;
}

// Stream inserter for the tag-value proxy that intentionally writes nothing
// and hands the stream back unchanged.
// NOTE(review): presumably exists only so generated parser code can stream
// this semantic value type -- confirm.
ostream& operator<<(ostream& os, dfly::search::AstTagsNode::TagValueProxy /*o*/) {
  return os;
}
} // namespace std
23 changes: 18 additions & 5 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,22 @@ struct AstFieldNode {

// Stores the list of tag values (exact terms or prefixes) for a tag query
struct AstTagsNode {
AstTagsNode(std::string tag);
AstTagsNode(AstNode&& l, std::string tag);

std::vector<std::string> tags;
using TagValue = std::variant<AstTermNode, AstPrefixNode>;

struct TagValueProxy
: public AstTagsNode::TagValue { // bison needs it to be default constructible
TagValueProxy() : AstTagsNode::TagValue(AstTermNode("")) {
}
TagValueProxy(AstPrefixNode tv) : AstTagsNode::TagValue(std::move(tv)) {
}
TagValueProxy(AstTermNode tv) : AstTagsNode::TagValue(std::move(tv)) {
}
};

AstTagsNode(TagValue);
AstTagsNode(AstNode&& l, TagValue);

std::vector<TagValue> tags;
};

// Applies nearest neighbor search to the final result set
Expand Down Expand Up @@ -125,4 +137,5 @@ using AstExpr = AstNode;

namespace std {
ostream& operator<<(ostream& os, optional<size_t> o);
}
ostream& operator<<(ostream& os, dfly::search::AstTagsNode::TagValueProxy o);
} // namespace std
11 changes: 6 additions & 5 deletions src/core/search/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ double toDouble(string_view src);
%nterm <bool> opt_lparen
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr numeric_filter_expr
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
%nterm <std::string> tag_list_element
%nterm <AstTagsNode::TagValueProxy> tag_list_element

%nterm <AstKnnNode> knn_query
%nterm <std::string> opt_knn_alias
Expand Down Expand Up @@ -179,10 +179,11 @@ tag_list:
| tag_list OR_OP tag_list_element { $$ = AstTagsNode(std::move($1), std::move($3)); }

tag_list_element:
TERM { $$ = std::move($1); }
| UINT32 { $$ = std::move($1); }
| DOUBLE { $$ = std::move($1); }
| TAG_VAL { $$ = std::move($1); }
TERM { $$ = AstTermNode(std::move($1)); }
| PREFIX { $$ = AstPrefixNode(std::move($1)); }
| UINT32 { $$ = AstTermNode(std::move($1)); }
| DOUBLE { $$ = AstTermNode(std::move($1)); }
| TAG_VAL { $$ = AstTermNode(std::move($1)); }


%%
Expand Down
47 changes: 35 additions & 12 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ struct IndexResult {

struct ProfileBuilder {
string GetNodeInfo(const AstNode& node) {
// Formatter handed to absl::StrJoin when printing AstTagsNode::tags: appends
// the textual payload of one tag value to `out`.
struct NodeFormatter {
  // Entry point for a variant element: forwards to the overload that matches
  // the alternative currently held.
  void operator()(std::string* out, const AstTagsNode::TagValue& value) const {
    const NodeFormatter& self = *this;
    visit([&self, out](const auto& alternative) { self(out, alternative); }, value);
  }

  // Exact-match tag: append the term text.
  void operator()(std::string* out, const AstTermNode& node) const {
    out->append(node.term);
  }

  // Prefix tag: append the stored prefix text.
  void operator()(std::string* out, const AstPrefixNode& node) const {
    out->append(node.prefix);
  }
};
Overloaded node_info{
[](monostate) -> string { return ""s; },
[](const AstTermNode& n) { return absl::StrCat("Term{", n.term, "}"); },
Expand All @@ -125,7 +136,9 @@ struct ProfileBuilder {
auto op = n.op == AstLogicalNode::AND ? "and" : "or";
return absl::StrCat("Logical{n=", n.nodes.size(), ",o=", op, "}");
},
[](const AstTagsNode& n) { return absl::StrCat("Tags{", absl::StrJoin(n.tags, ","), "}"); },
[](const AstTagsNode& n) {
return absl::StrCat("Tags{", absl::StrJoin(n.tags, ",", NodeFormatter()), "}");
},
[](const AstFieldNode& n) { return absl::StrCat("Field{", n.field, "}"); },
[](const AstKnnNode& n) { return absl::StrCat("KNN{l=", n.limit, "}"); },
[](const AstNegateNode& n) { return absl::StrCat("Negate{}"); },
Expand Down Expand Up @@ -248,6 +261,14 @@ struct BasicSearch {
return out;
}

// Runs a prefix lookup on `index` and folds every entry reported by
// MatchingPrefix into one result using LogicOp::OR. Returns a
// default-constructed IndexResult if the callback is never invoked.
template <typename C>
IndexResult CollectPrefixMatches(BaseStringIndex<C>* index, std::string_view prefix) {
  IndexResult merged{};

  // Invoked once per matching entry; accumulates into `merged`.
  auto absorb = [this, &merged](const auto* entry) {
    Merge(IndexResult{entry}, &merged, LogicOp::OR);
  };

  index->MatchingPrefix(prefix, absorb);
  return merged;
}

IndexResult Search(monostate, string_view) {
return vector<DocId>{};
}
Expand Down Expand Up @@ -283,13 +304,8 @@ struct BasicSearch {
}

auto mapping = [&node, this](TextIndex* index) {
IndexResult result{};
index->MatchingPrefix(node.prefix, [&result, this](const auto* c) {
Merge(IndexResult{c}, &result, LogicOp::OR);
});
return result;
return CollectPrefixMatches(index, node.prefix);
};

return UnifyResults(GetSubResults(indices, mapping), LogicOp::OR);
}

Expand Down Expand Up @@ -330,11 +346,18 @@ struct BasicSearch {

// {tags | ...}: Unify results for all tags
IndexResult Search(const AstTagsNode& node, string_view active_field) {
if (auto* tag_index = GetIndex<TagIndex>(active_field); tag_index) {
auto mapping = [tag_index](string_view tag) { return tag_index->Matching(tag); };
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}
return IndexResult{};
auto* tag_index = GetIndex<TagIndex>(active_field);
if (!tag_index)
return IndexResult{};

Overloaded ov{[tag_index](const AstTermNode& term) -> IndexResult {
return tag_index->Matching(term.term);
},
[tag_index, this](const AstPrefixNode& prefix) {
return CollectPrefixMatches(tag_index, prefix.prefix);
}};
auto mapping = [ov](const auto& tag) { return visit(ov, tag); };
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}

// SORTBY field [DESC]: Sort by field. Part of params and not "core query".
Expand Down
13 changes: 13 additions & 0 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,19 @@ TEST_F(SearchTest, CheckTag) {
EXPECT_TRUE(Check()) << GetError();
}

// Verifies prefix matching inside a tag list: `green*` and `yellow*` are
// prefix elements, while `orange` is an exact-match element.
TEST_F(SearchTest, CheckTagPrefix) {
PrepareSchema({{"color", SchemaField::TAG}});
PrepareQuery("@color:{green* | orange | yellow*}");

// Expected hits: values that start with "green" or "yellow", plus the exact
// value "orange".
ExpectAll(Map{{"color", "green"}}, Map{{"color", "yellow"}}, Map{{"color", "greenish"}},
Map{{"color", "yellowish"}}, Map{{"color", "green-forestish"}},
Map{{"color", "yellowsunish"}}, Map{{"color", "orange"}});
// Expected misses: unrelated colors, "orangeish" (orange has no trailing `*`,
// so it must not behave as a prefix), and values where green/yellow appear
// only in the middle or at the end.
ExpectNone(Map{{"color", "red"}}, Map{{"color", "blue"}}, Map{{"color", "orangeish"}},
Map{{"color", "darkgreen"}}, Map{{"color", "light-yellow"}});

// Check() presumably runs the prepared query against the documents registered
// above -- confirm against the fixture; GetError() supplies the failure text.
EXPECT_TRUE(Check()) << GetError();
}

TEST_F(SearchTest, IntegerTerms) {
PrepareSchema({{"status", SchemaField::TAG}, {"title", SchemaField::TEXT}});

Expand Down

0 comments on commit eb9b689

Please sign in to comment.