Skip to content

Commit

Permalink
fix(search): Support indexing array paths (#2074)
Browse files Browse the repository at this point in the history
* fix(search): Support indexing array paths

Signed-off-by: Vladislav Oleshko <[email protected]>


---------

Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg authored Oct 29, 2023
1 parent 47d92fb commit 04cd2ff
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 39 deletions.
5 changes: 4 additions & 1 deletion src/core/search/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/container/inlined_vector.h>

#include <cstdint>
#include <memory>
Expand Down Expand Up @@ -57,9 +58,11 @@ using ResultScore = std::variant<std::monostate, float, double, WrappedStrPtr>;
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector;
using StringList = absl::InlinedVector<std::string_view, 1>;

virtual ~DocumentAccessor() = default;
virtual std::string_view GetString(std::string_view active_field) const = 0;

virtual StringList GetStrings(std::string_view active_field) const = 0;
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
};

Expand Down
44 changes: 30 additions & 14 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>
#include <absl/strings/ascii.h>
#include <absl/strings/numbers.h>
#include <absl/strings/str_join.h>
#include <absl/strings/str_split.h>

#define UNI_ALGO_DISABLE_NFKC_NFKD
Expand Down Expand Up @@ -59,15 +60,19 @@ NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
}

void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
double num;
if (absl::SimpleAtod(doc->GetString(field), &num))
entries_.emplace(num, id);
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.emplace(num, id);
}
}

void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
int64_t num;
if (absl::SimpleAtoi(doc->GetString(field), &num))
entries_.erase({num, id});
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.erase({num, id});
}
}

vector<DocId> NumericIndex::Range(double l, double r) const {
Expand All @@ -79,6 +84,7 @@ vector<DocId> NumericIndex::Range(double l, double r) const {
out.push_back(it->second);

sort(out.begin(), out.end());
out.erase(unique(out.begin(), out.end()), out.end());
return out;
}

Expand All @@ -104,17 +110,27 @@ CompressedSortedSet* BaseStringIndex::GetOrCreate(string_view word) {
}

void BaseStringIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
for (const auto& word : Tokenize(doc->GetString(field)))
GetOrCreate(word)->Insert(id);
absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
tokens.merge(Tokenize(str));

for (string_view token : tokens)
GetOrCreate(token)->Insert(id);
}

void BaseStringIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
for (const auto& word : Tokenize(doc->GetString(field))) {
if (auto it = entries_.find(word); it != entries_.end()) {
it->second.Remove(id);
if (it->second.Size() == 0)
entries_.erase(it);
}
absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
tokens.merge(Tokenize(str));

for (const auto& token : tokens) {
auto it = entries_.find(token);
if (it == entries_.end())
continue;

it->second.Remove(id);
if (it->second.Size() == 0)
entries_.erase(it);
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ struct MockedDocument : public DocumentAccessor {
MockedDocument(std::string test_field) : fields_{{"field", test_field}} {
}

string_view GetString(string_view field) const override {
StringList GetStrings(string_view field) const override {
auto it = fields_.find(field);
return it != fields_.end() ? string_view{it->second} : "";
return {it != fields_.end() ? string_view{it->second} : ""};
}

VectorInfo GetVector(string_view field) const override {
return BytesToFtVector(GetString(field));
return BytesToFtVector(GetStrings(field).front());
}

string DebugFormat() {
Expand Down
12 changes: 10 additions & 2 deletions src/core/search/sort_indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,22 @@ template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>;

double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
auto str = doc->GetStrings(field);
if (str.empty())
return 0;

double v;
if (!absl::SimpleAtod(doc->GetString(field), &v))
if (!absl::SimpleAtod(str.front(), &v))
return 0;
return v;
}

PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
return PMR_NS::string{doc->GetString(field), GetMemRes()};
auto str = doc->GetStrings(field);
if (str.empty())
return "";

return PMR_NS::string{str.front(), GetMemRes()};
}

} // namespace dfly::search
55 changes: 40 additions & 15 deletions src/server/search/doc_accessors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,19 @@ SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
for (const auto& [fident, fname] : fields) {
auto it = schema.fields.find(fident);
auto type = it != schema.fields.end() ? it->second.type : search::SchemaField::TEXT;
out[fname] = PrintField(type, GetString(fident));
out[fname] = PrintField(type, absl::StrJoin(GetStrings(fident), ","));
}
return out;
}

string_view ListPackAccessor::GetString(string_view active_field) const {
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const {
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
return strsv.has_value() ? StringList{*strsv} : StringList{};
}

BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
return search::BytesToFtVector(GetString(active_field));
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
}

SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
Expand All @@ -86,13 +88,14 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
return out;
}

string_view StringMapAccessor::GetString(string_view active_field) const {
BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const {
auto it = hset_->Find(active_field);
return it != hset_->end() ? it->second : ""sv;
return it != hset_->end() ? StringList{it->second} : StringList{};
}

BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
return search::BytesToFtVector(GetString(active_field));
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
}

SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
Expand All @@ -106,13 +109,35 @@ SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
struct JsonAccessor::JsonPathContainer : public jsoncons::jsonpath::jsonpath_expression<JsonType> {
};

string_view JsonAccessor::GetString(string_view active_field) const {
auto res = GetPath(active_field)->evaluate(json_);
DCHECK(res.is_array());
if (res.empty())
return "";
buf_ = res[0].as_string();
return buf_;
BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const {
auto path_res = GetPath(active_field)->evaluate(json_);
DCHECK(path_res.is_array()); // json path always returns arrays

if (path_res.empty())
return {};

if (path_res.size() == 1) {
buf_ = path_res[0].as_string();
return {buf_};
}

// First, grow buffer and compute string sizes
vector<size_t> sizes;
for (auto element : path_res.array_range()) {
size_t start = buf_.size();
buf_ += element.as_string();
sizes.push_back(buf_.size() - start);
}

// Reposition start pointers to the most recent allocation of buf
StringList out(sizes.size());
size_t start = 0;
for (size_t i = 0; i < out.size(); i++) {
out[i] = string_view{buf_}.substr(start, sizes[i]);
start += sizes[i];
}

return out;
}

BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
Expand Down Expand Up @@ -156,7 +181,7 @@ SearchDocData JsonAccessor::Serialize(const search::Schema& schema,
const SearchParams::FieldReturnList& fields) const {
SearchDocData out{};
for (const auto& [ident, name] : fields)
out[name] = GetString(ident);
out[name] = GetPath(ident)->evaluate(json_).to_string();
return out;
}

Expand Down
6 changes: 3 additions & 3 deletions src/server/search/doc_accessors.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct ListPackAccessor : public BaseAccessor {
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
}

std::string_view GetString(std::string_view field) const override;
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

Expand All @@ -53,7 +53,7 @@ struct StringMapAccessor : public BaseAccessor {
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
}

std::string_view GetString(std::string_view field) const override;
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

Expand All @@ -68,7 +68,7 @@ struct JsonAccessor : public BaseAccessor {
explicit JsonAccessor(const JsonType* json) : json_{*json} {
}

std::string_view GetString(std::string_view field) const override;
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

Expand Down
61 changes: 60 additions & 1 deletion src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ TEST_F(SearchFamilyTest, Json) {
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
}

TEST_F(SearchFamilyTest, AttributesJsonPaths) {
TEST_F(SearchFamilyTest, JsonAttributesPaths) {
Run({"json.set", "k1", ".", R"( {"nested": {"value": "no"}} )"});
Run({"json.set", "k2", ".", R"( {"nested": {"value": "yes"}} )"});
Run({"json.set", "k3", ".", R"( {"nested": {"value": "maybe"}} )"});
Expand All @@ -193,6 +193,65 @@ TEST_F(SearchFamilyTest, AttributesJsonPaths) {
EXPECT_THAT(Run({"ft.search", "i1", "yes"}), AreDocIds("k2"));
}

TEST_F(SearchFamilyTest, JsonArrayValues) {
string_view D1 = R"(
{
"name": "Alex",
"plays" : [
{"game": "Pacman", "score": 10},
{"game": "Tetris", "score": 15}
],
"areas": ["EU-west", "EU-central"]
}
)";
string_view D2 = R"(
{
"name": "Bob",
"plays" : [
{"game": "Pacman", "score": 15},
{"game": "Mario", "score": 7}
],
"areas": "US-central"
}
)";
string_view D3 = R"(
{
"name": "Caren",
"plays" : [
{"game": "Mario", "score": 9},
{"game": "Doom", "score": 20}
],
"areas": ["EU-central", "EU-east"]
}
)";

Run({"json.set", "k1", ".", D1});
Run({"json.set", "k2", ".", D2});
Run({"json.set", "k3", ".", D3});

auto resp = Run({"ft.create", "i1", "on", "json", "schema", "$.name", "text", "$.plays[*].game",
"as", "games", "tag", "$.plays[*].score", "as", "scores", "numeric",
"$.areas[*]", "as", "areas", "tag"});
EXPECT_EQ(resp, "OK");

EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("k1", "k2", "k3"));

// Find players by games
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Tetris | Mario | Doom}"}),
AreDocIds("k1", "k2", "k3"));
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Pacman}"}), AreDocIds("k1", "k2"));
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Mario}"}), AreDocIds("k2", "k3"));

// Find players by scores
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[15 15]"}), AreDocIds("k1", "k2"));
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[0 (10]"}), AreDocIds("k2", "k3"));
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[(15 20]"}), AreDocIds("k3"));

// Find platers by areas
EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"EU-central\"}"}), AreDocIds("k1", "k3"));
EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"US-central\"}"}), AreDocIds("k2"));
}

TEST_F(SearchFamilyTest, Tags) {
Run({"hset", "d:1", "color", "red, green"});
Run({"hset", "d:2", "color", "green, blue"});
Expand Down
25 changes: 25 additions & 0 deletions tests/dragonfly/search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,31 @@ async def test_basic(async_client: aioredis.Redis, index_type):
await i1.dropindex()


@dfly_args({"proactor_threads": 4})
async def test_big_json(async_client: aioredis.Redis):
i1 = async_client.ft("i1")
gen_arr = lambda base: {"blob": [base + str(i) for i in range(100)]}

await async_client.json().set("k1", "$", gen_arr("alex"))
await async_client.json().set("k2", "$", gen_arr("bob"))

await i1.create_index(
[TextField(name="$.blob", as_name="items")],
definition=IndexDefinition(index_type=IndexType.JSON),
)

res = await i1.search("alex55")
assert res.docs[0].id == "k1"

res = await i1.search("bob77")
assert res.docs[0].id == "k2"

res = await i1.search("alex11 | bob22")
assert res.total == 2

await i1.dropindex()


async def knn_query(idx, query, vector):
params = {"vec": np.array(vector, dtype=np.float32).tobytes()}
result = await idx.search(query, params)
Expand Down

0 comments on commit 04cd2ff

Please sign in to comment.