Skip to content

Commit

Permalink
feat(search): Prefix search* (#3913)
Browse files Browse the repository at this point in the history
  • Loading branch information
dranikpg authored Oct 14, 2024
1 parent 588d6cc commit f455981
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/core/search/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ cxx_test(compressed_sorted_set_test query_parser LABELS DFLY)
cxx_test(block_list_test query_parser LABELS DFLY)
cxx_test(rax_tree_test redis_test_lib LABELS DFLY)
cxx_test(search_parser_test query_parser LABELS DFLY)
cxx_test(search_test query_parser LABELS DFLY)
cxx_test(search_test redis_test_lib query_parser LABELS DFLY)
6 changes: 5 additions & 1 deletion src/core/search/ast_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ using namespace std;

namespace dfly::search {

AstTermNode::AstTermNode(string term) : term{term} {
AstTermNode::AstTermNode(string term) : term{std::move(term)} {
}

AstPrefixNode::AstPrefixNode(string prefix) : prefix{std::move(prefix)} {
this->prefix.pop_back();
}

AstRangeNode::AstRangeNode(double lo, bool lo_excl, double hi, bool hi_excl)
Expand Down
12 changes: 9 additions & 3 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,17 @@ struct AstStarNode {};

// Matches terms in text fields
struct AstTermNode {
AstTermNode(std::string term);
explicit AstTermNode(std::string term);

std::string term;
};

struct AstPrefixNode {
explicit AstPrefixNode(std::string prefix);

std::string prefix;
};

// Matches numeric range
struct AstRangeNode {
AstRangeNode(double lo, bool lo_excl, double hi, bool hi_excl);
Expand Down Expand Up @@ -97,8 +103,8 @@ struct AstSortNode {
};

using NodeVariants =
std::variant<std::monostate, AstStarNode, AstTermNode, AstRangeNode, AstNegateNode,
AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode, AstSortNode>;
std::variant<std::monostate, AstStarNode, AstTermNode, AstPrefixNode, AstRangeNode,
AstNegateNode, AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode, AstSortNode>;

struct AstNode : public NodeVariants {
using variant::variant;
Expand Down
9 changes: 9 additions & 0 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ const typename BaseStringIndex<C>::Container* BaseStringIndex<C>::Matching(strin
return (it != entries_.end()) ? &it->second : nullptr;
}

template <typename C>
void BaseStringIndex<C>::MatchingPrefix(std::string_view prefix,
absl::FunctionRef<void(const Container*)> cb) const {
for (auto it = entries_.lower_bound(prefix);
it != entries_.end() && (*it).first.rfind(prefix, 0) == 0; ++it) {
cb(&(*it).second);
}
}

template <typename C>
typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_view word) {
auto* mr = entries_.get_allocator().resource();
Expand Down
30 changes: 6 additions & 24 deletions src/core/search/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
#include <optional>
#include <vector>

#include "absl/functional/function_ref.h"
#include "base/pmr/memory_resource.h"
#include "core/search/base.h"
#include "core/search/block_list.h"
#include "core/search/compressed_sorted_set.h"
#include "core/search/rax_tree.h"

// TODO: move core field definitions out of big header
#include "core/search/search.h"
Expand Down Expand Up @@ -51,37 +53,17 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
// Pointer is valid as long as index is not mutated. Nullptr if not found
const Container* Matching(std::string_view str) const;

// Iterate over all Machting on prefix.
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;

// Returns all the terms that appear as keys in the reverse index.
std::vector<std::string> GetTerms() const;

protected:
Container* GetOrCreate(std::string_view word);

struct PmrEqual {
using is_transparent = void;
bool operator()(const PMR_NS::string& lhs, const PMR_NS::string& rhs) const {
return lhs == rhs;
}
bool operator()(const PMR_NS::string& lhs, const std::string_view& rhs) const {
return lhs == rhs;
}
};

struct PmrHash {
using is_transparent = void;
size_t operator()(const std::string_view& sv) const {
return absl::Hash<std::string_view>()(sv);
}
size_t operator()(const PMR_NS::string& pmrs) const {
return operator()(std::string_view{pmrs.data(), pmrs.size()});
}
};

bool case_sensitive_ = false;

absl::flat_hash_map<PMR_NS::string, Container, PmrHash, PmrEqual,
PMR_NS::polymorphic_allocator<std::pair<PMR_NS::string, Container>>>
entries_;
search::RaxTreeMap<Container> entries_;
};

// Index for text fields.
Expand Down
1 change: 1 addition & 0 deletions src/core/search/lexer.lex
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ tag_val_char {term_char}|\\[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]

"$"{term_char}+ return ParseParam(str(), loc());
"@"{term_char}+ return Parser::make_FIELD(str(), loc());
{term_char}+"*" return Parser::make_PREFIX(str(), loc());

{term_char}+ return Parser::make_TERM(str(), loc());
{tag_val_char}+ return make_TagVal(str(), loc());
Expand Down
3 changes: 2 additions & 1 deletion src/core/search/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ double toDouble(string_view src);

// Needed 0 at the end to satisfy bison 3.5.1
%token YYEOF 0
%token <std::string> TERM "term" TAG_VAL "tag_val" PARAM "param" FIELD "field"
%token <std::string> TERM "term" TAG_VAL "tag_val" PARAM "param" FIELD "field" PREFIX "prefix"

%precedence TERM TAG_VAL
%left OR_OP
Expand Down Expand Up @@ -132,6 +132,7 @@ search_unary_expr:
LPAREN search_expr RPAREN { $$ = std::move($2); }
| NOT_OP search_unary_expr { $$ = AstNegateNode(std::move($2)); }
| TERM { $$ = AstTermNode(std::move($1)); }
| PREFIX { $$ = AstPrefixNode(std::move($1)); }
| UINT32 { $$ = AstTermNode(std::move($1)); }
| FIELD COLON field_cond { $$ = AstFieldNode(std::move($1), std::move($3)); }

Expand Down
26 changes: 15 additions & 11 deletions src/core/search/rax_tree.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#pragma once

#include <absl/types/span.h>

#include <cstdio>
#include <cassert>
#include <optional>
#include <string_view>
#include <utility>

#include "base/pmr/memory_resource.h"

Expand All @@ -17,6 +16,7 @@ namespace dfly::search {
// absl::flat_hash_map/std::unordered_map compatible tree map based on rax tree.
// Allocates all objects on heap (with custom memory resource) as rax tree operates fully on
// pointers.
// TODO: Add full support for polymorphic allocators, including rax trie node allocations
template <typename V> struct RaxTreeMap {
struct FindIterator;

Expand Down Expand Up @@ -87,7 +87,7 @@ template <typename V> struct RaxTreeMap {
};

public:
explicit RaxTreeMap(PMR_NS::memory_resource* mr) : tree_(raxNew()), mr_(mr) {
explicit RaxTreeMap(PMR_NS::memory_resource* mr) : tree_(raxNew()), alloc_(mr) {
}

size_t size() const {
Expand Down Expand Up @@ -119,7 +119,12 @@ template <typename V> struct RaxTreeMap {
V* old = nullptr;
raxRemove(tree_, to_key_ptr(it->first.data()), it->first.size(),
reinterpret_cast<void**>(&old));
mr_->deallocate(old, sizeof(V), alignof(V));
alloc_.destroy(old);
alloc_.deallocate(old, 1);
}

auto& get_allocator() const {
return alloc_;
}

private:
Expand All @@ -128,7 +133,7 @@ template <typename V> struct RaxTreeMap {
}

rax* tree_;
PMR_NS::memory_resource* mr_;
PMR_NS::polymorphic_allocator<V> alloc_;
};

template <typename V>
Expand All @@ -138,15 +143,14 @@ std::pair<typename RaxTreeMap<V>::FindIterator, bool> RaxTreeMap<V>::try_emplace
if (auto it = find(key); it)
return {it, false};

void* ptr = mr_->allocate(sizeof(V), alignof(V));
V* data = new (ptr) V(std::forward<Args>(args)...);
assert(uint64_t(ptr) == uint64_t(data)); // we free by the latter
V* ptr = alloc_.allocate(1);
alloc_.construct(ptr, std::forward<Args>(args)...);

V* old = nullptr;
raxInsert(tree_, to_key_ptr(key), key.size(), data, reinterpret_cast<void**>(&old));
raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast<void**>(&old));
assert(old == nullptr);

auto it = std::make_optional(std::pair<std::string_view, V&>(key, *data));
auto it = std::make_optional(std::pair<std::string_view, V&>(key, *ptr));
return std::make_pair(FindIterator{it}, true);
}

Expand Down
23 changes: 23 additions & 0 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ struct ProfileBuilder {
Overloaded node_info{
[](monostate) -> string { return ""s; },
[](const AstTermNode& n) { return absl::StrCat("Term{", n.term, "}"); },
[](const AstPrefixNode& n) { return absl::StrCat("Prefix{", n.prefix, "}"); },
[](const AstRangeNode& n) { return absl::StrCat("Range{", n.lo, "<>", n.hi, "}"); },
[](const AstLogicalNode& n) {
auto op = n.op == AstLogicalNode::AND ? "and" : "or";
Expand Down Expand Up @@ -270,6 +271,28 @@ struct BasicSearch {
return UnifyResults(GetSubResults(selected_indices, mapping), LogicOp::OR);
}

IndexResult Search(const AstPrefixNode& node, string_view active_field) {
vector<TextIndex*> indices;
if (!active_field.empty()) {
if (auto* index = GetIndex<TextIndex>(active_field); index)
indices = {index};
else
return IndexResult{};
} else {
indices = indices_->GetAllTextIndices();
}

auto mapping = [&node, this](TextIndex* index) {
IndexResult result{};
index->MatchingPrefix(node.prefix, [&result, this](const auto* c) {
Merge(IndexResult{c}, &result, LogicOp::OR);
});
return result;
};

return UnifyResults(GetSubResults(indices, mapping), LogicOp::OR);
}

// [range]: access field's numeric index
IndexResult Search(const AstRangeNode& node, string_view active_field) {
DCHECK(!active_field.empty());
Expand Down
40 changes: 39 additions & 1 deletion src/core/search/search_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ TEST_F(SearchParserTest, Scanner) {
NEXT_EQ(TOK_TERM, string, "cd");
NEXT_TOK(TOK_YYEOF);

SetInput("(5a 6) ");
SetInput("*");
NEXT_TOK(TOK_STAR);

SetInput("(5a 6) ");
NEXT_TOK(TOK_LPAREN);
NEXT_EQ(TOK_TERM, string, "5a");
NEXT_EQ(TOK_UINT32, string, "6");
Expand Down Expand Up @@ -151,6 +153,36 @@ TEST_F(SearchParserTest, Scanner) {
NEXT_EQ(TOK_TAG_VAL, string, "blue]1#-");
NEXT_TOK(TOK_RCURLBR);

// Prefix simple
SetInput("pre*");
NEXT_EQ(TOK_PREFIX, string, "pre*");

// TODO: uncomment when we support escaped terms
// Prefix escaped (redis doesn't support quoted prefix matches)
// SetInput("pre\\**");
// NEXT_EQ(TOK_PREFIX, string, "pre*");

// Prefix in tag
SetInput("@color:{prefix*}");
NEXT_EQ(TOK_FIELD, string, "@color");
NEXT_TOK(TOK_COLON);
NEXT_TOK(TOK_LCURLBR);
NEXT_EQ(TOK_PREFIX, string, "prefix*");
NEXT_TOK(TOK_RCURLBR);

// Prefix escaped star
SetInput("@color:{\"prefix*\"}");
NEXT_EQ(TOK_FIELD, string, "@color");
NEXT_TOK(TOK_COLON);
NEXT_TOK(TOK_LCURLBR);
NEXT_EQ(TOK_TERM, string, "prefix*");
NEXT_TOK(TOK_RCURLBR);

// Prefix spaced with star
SetInput("pre *");
NEXT_EQ(TOK_TERM, string, "pre");
NEXT_TOK(TOK_STAR);

SetInput("почтальон Печкин");
NEXT_EQ(TOK_TERM, string, "почтальон");
NEXT_EQ(TOK_TERM, string, "Печкин");
Expand All @@ -172,6 +204,12 @@ TEST_F(SearchParserTest, Parse) {
EXPECT_EQ(1, Parse(" foo:bar "));
EXPECT_EQ(1, Parse(" @foo:@bar "));
EXPECT_EQ(1, Parse(" @foo: "));

// We don't support suffix/any other position for now
EXPECT_EQ(1, Parse("*pre"));
EXPECT_EQ(1, Parse("*pre*"));

EXPECT_EQ(1, Parse("pre***"));
}

TEST_F(SearchParserTest, ParseParams) {
Expand Down
29 changes: 29 additions & 0 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <absl/strings/str_split.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mimalloc.h>

#include <algorithm>
#include <memory_resource>
Expand All @@ -22,6 +23,10 @@
#include "core/search/query_driver.h"
#include "core/search/vector_utils.h"

extern "C" {
#include "redis/zmalloc.h"
}

namespace dfly {
namespace search {

Expand Down Expand Up @@ -80,6 +85,11 @@ Schema MakeSimpleSchema(initializer_list<pair<string_view, SchemaField::FieldTyp

class SearchTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
auto* tlh = mi_heap_get_backing();
init_zmalloc_threadlocal(tlh);
}

SearchTest() {
PrepareSchema({{"field", SchemaField::TEXT}});
}
Expand Down Expand Up @@ -260,6 +270,25 @@ TEST_F(SearchTest, CheckParenthesisPriority) {
}
}

TEST_F(SearchTest, CheckPrefix) {
{
PrepareQuery("pre*");

ExpectAll("pre", "prepre", "preachers", "prepared", "pRetty", "PRedators", "prEcisely!");
ExpectNone("pristine", "represent", "repair", "depreciation");

EXPECT_TRUE(Check()) << GetError();
}
{
PrepareQuery("new*");

ExpectAll("new", "New York", "Newham", "newbie", "news", "Welcome to Newark!");
ExpectNone("ne", "renew", "nev", "ne-w", "notnew", "casino in neVada");

EXPECT_TRUE(Check()) << GetError();
}
}

using Map = MockedDocument::Map;

TEST_F(SearchTest, MatchField) {
Expand Down

0 comments on commit f455981

Please sign in to comment.