Skip to content

Commit

Permalink
fix(rax_tree): Fix crash caused by destructor in RaxTreeMap (#4228)
Browse files Browse the repository at this point in the history
* fix(rax_tree): Fix double raxStop call in the SeekIterator

fixes #4172

Signed-off-by: Stepan Bagritsevich <[email protected]>

* refactor(rax_tree): Address comments

Signed-off-by: Stepan Bagritsevich <[email protected]>

---------

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan authored and romange committed Dec 10, 2024
1 parent 976586c commit c49dc61
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 23 deletions.
62 changes: 39 additions & 23 deletions src/core/search/rax_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,37 @@ template <typename V> struct RaxTreeMap {

// Simple seeking iterator
struct SeekIterator {
friend struct FindIterator;

SeekIterator() {
raxStart(&it_, nullptr);
it_.node = nullptr;
}

~SeekIterator() {
raxStop(&it_);
it_.rt = nullptr;
}

SeekIterator(SeekIterator&&) = delete; // self-referential
SeekIterator(const SeekIterator&) = delete; // self-referential

SeekIterator(rax* tree, const char* op, std::string_view key) {
raxStart(&it_, tree);
raxSeek(&it_, op, to_key_ptr(key), key.size());
operator++();
if (raxSeek(&it_, op, to_key_ptr(key), key.size())) { // Successfuly seeked
operator++();
} else {
InvalidateIterator();
}
}

explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) {
}

/* Remove copy/move constructors to avoid double iterator invalidation */
SeekIterator(SeekIterator&&) = delete;
SeekIterator(const SeekIterator&) = delete;
SeekIterator& operator=(SeekIterator&&) = delete;
SeekIterator& operator=(const SeekIterator&) = delete;

~SeekIterator() {
if (IsValid()) {
InvalidateIterator();
}
}

bool operator==(const SeekIterator& rhs) const {
if (!IsValid() || !rhs.IsValid())
return !IsValid() && !rhs.IsValid();
return it_.node == rhs.it_.node;
}

Expand All @@ -56,31 +63,40 @@ template <typename V> struct RaxTreeMap {
}

SeekIterator& operator++() {
if (!raxNext(&it_)) {
raxStop(&it_);
it_.node = nullptr;
int next_result = raxNext(&it_);
if (!next_result) { // OOM or we reached the end of the tree
InvalidateIterator();
}
return *this;
}

/* After operator++() the first value (string_view) is invalid. So make sure your copied it to
* string */
std::pair<std::string_view, V&> operator*() const {
assert(IsValid() && it_.node && it_.node->iskey && it_.data);
return {std::string_view{reinterpret_cast<const char*>(it_.key), it_.key_len},
*reinterpret_cast<V*>(it_.data)};
}

bool IsValid() const {
return it_.rt;
}

private:
void InvalidateIterator() {
raxStop(&it_);
it_.rt = nullptr;
}

raxIterator it_;
};

// Result of find() call. Inherits from pair to mimic iterator interface, not incrementable.
struct FindIterator : public std::optional<std::pair<std::string, V&>> {
bool operator==(const SeekIterator& rhs) const {
if (this->has_value() != !bool(rhs.it_.flags & RAX_ITER_EOF))
return false;
if (!this->has_value())
return true;
return (*this)->first ==
std::string_view{reinterpret_cast<const char*>(rhs.it_.key), rhs.it_.key_len};
if (!this->has_value() || !rhs.IsValid())
return !this->has_value() && !rhs.IsValid();
return (*this)->first == (*rhs).first;
}

bool operator!=(const SeekIterator& rhs) const {
Expand Down Expand Up @@ -160,7 +176,7 @@ std::pair<typename RaxTreeMap<V>::FindIterator, bool> RaxTreeMap<V>::try_emplace

V* old = nullptr;
raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast<void**>(&old));
assert(old == nullptr);
assert(!old);

auto it = std::make_optional(std::pair<std::string, V&>(std::string(key), *ptr));
return std::make_pair(std::move(FindIterator{it}), true);
Expand Down
24 changes: 24 additions & 0 deletions src/core/search/rax_tree_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,28 @@ TEST_F(RaxTreeTest, Find) {
EXPECT_TRUE(map.find(string_view{}) == map.end());
}

/* Run with mimalloc to make sure there is no double free */
TEST_F(RaxTreeTest, Iterate) {
const char* kKeys[] = {
"aaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbb"
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"
"eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
};

RaxTreeMap<int> map(pmr::get_default_resource());
for (const char* key : kKeys) {
map.try_emplace(key, 2);
}

for (auto it = map.begin(); it != map.end(); ++it) {
EXPECT_EQ((*it).second, 2);
}

for (auto it = map.begin(); it != map.end(); ++it) {
EXPECT_EQ((*it).second, 2);
}
}

} // namespace dfly::search

0 comments on commit c49dc61

Please sign in to comment.