Skip to content

Commit

Permalink
opt9: batch vle processing (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
msm-cert authored Oct 16, 2024
1 parent cdbf0c1 commit c2ffdea
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 55 deletions.
125 changes: 95 additions & 30 deletions libursa/SortedRun.cpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
#include "SortedRun.h"

#include <algorithm>
#include <cstring>
#include <numeric>
#include <stdexcept>

#include "Utils.h"

uint32_t RunIterator::current() const {
// Read VLE integer stored under pos.
uint32_t run_read(uint8_t *pos) {
uint64_t acc = 0;
uint32_t shift = 0;
for (uint8_t *it = pos_;; it++) {
for (uint8_t *it = pos;; it++) {
uint32_t next = *it;
acc += (next & 0x7FU) << shift;
shift += 7U;
if ((next & 0x80U) == 0) {
return prev_ + acc + 1;
return acc + 1;
}
}
}

uint8_t *RunIterator::nextpos() {
for (uint8_t *it = pos_;; it++) {
if ((*it & 0x80) == 0) {
return it + 1;
// Return a pointer to the next encoded integer.
uint8_t *run_forward(uint8_t *pos) {
for (;; pos++) {
if ((*pos & 0x80) == 0) {
return pos + 1;
}
}
}
Expand All @@ -42,44 +45,106 @@ std::vector<uint32_t>::iterator SortedRun::end() {
return sequence_.end();
}

RunIterator SortedRun::comp_begin() {
validate_compression(true);
return RunIterator(run_.data());
}

RunIterator SortedRun::comp_end() {
validate_compression(true);
return RunIterator(run_.data() + run_.size());
}

void SortedRun::do_or(SortedRun &other) {
// In almost every case this is already decompressed.
decompress();
other.decompress();
std::vector<FileId> new_results;
if (other.is_compressed()) {
// Unlikely case, in most cases both runs are already decompressed.
std::set_union(other.comp_begin(), other.comp_end(), begin(), end(),
std::back_inserter(new_results));
} else {
std::set_union(other.begin(), other.end(), begin(), end(),
std::back_inserter(new_results));
}
std::set_union(other.begin(), other.end(), begin(), end(),
std::back_inserter(new_results));
std::swap(new_results, sequence_);
}

// Performs one step of the merge: decodes the VLE integer at run_it_ and
// compares it with the current sequence element. Exactly one of three things
// happens:
//   * run value > sequence value: skip the sequence element,
//   * run value < sequence value: skip the run value,
//   * both equal: emit the value through seq_out_ and advance both sides.
// The caller must guarantee run_it_ < run_end_ and seq_it_ < seq_end_.
void IntersectionHelper::step_single() {
    const uint32_t candidate = prev_ + run_read(run_it_);
    const uint32_t wanted = *seq_it_;

    if (candidate > wanted) {
        // The sequence element is absent from the run - drop it.
        seq_it_++;
        return;
    }

    // candidate <= wanted, so the run value is consumed in both remaining
    // cases. Values are delta-encoded, hence prev_ has to follow along.
    prev_ = candidate;
    run_it_ = run_forward(run_it_);

    if (candidate == wanted) {
        // Common element: keep it (seq_out_ never gets ahead of seq_it_).
        *seq_out_++ = wanted;
        seq_it_++;
    }
}

// Fast path: tries to intersect the next 8 run bytes as a single batch.
//
// The batch applies only when none of the 8 bytes has the VLE continuation
// bit (0x80) set, i.e. every byte encodes exactly one small delta.
//
// Preconditions (ensured by intersect_by_8): at least 8 readable bytes at
// run_it_, and seq_it_ < seq_end_.
//
// Returns true if the batch was handled (all 8 bytes were consumed, or the
// sequence ran out first). Returns false, without consuming anything, when a
// multi-byte integer is present; the caller then has to fall back to
// step_single().
bool IntersectionHelper::step_by_8() {
    constexpr int BATCH_SIZE = 8;
    constexpr uint64_t VLE_MASK = 0x8080808080808080ULL;

    // Load the 8 bytes with memcpy rather than dereferencing a casted
    // uint64_t pointer: run_it_ has no alignment guarantee, and reading
    // uint8_t storage through a uint64_t lvalue breaks the aliasing rules.
    // Compilers turn this into a single (unaligned) load.
    uint64_t qword;
    std::memcpy(&qword, run_it_, sizeof(qword));
    if ((qword & VLE_MASK) != 0) {
        // A large byte (>= 0x80) was found, handle it in the slow path.
        return false;
    }

    // Every byte stands for "previous + byte + 1", so the value reached after
    // the whole batch is prev_ + 8 + (sum of the 8 bytes).
    uint32_t after_batch = prev_ + BATCH_SIZE;
    after_batch += std::accumulate(run_it_, run_it_ + BATCH_SIZE, 0);

    // Fast-fast path. If even the last value of the batch is still below the
    // current sequence element, nothing in the batch can match - skip it all.
    if (after_batch < *seq_it_) {
        run_it_ += BATCH_SIZE;
        prev_ = after_batch;
        return true;
    }

    // Regular batch: like step_single, but we know that we are only dealing
    // with single-byte integers, so no VLE decoding is necessary.
    for (uint8_t *end = run_it_ + BATCH_SIZE;
         run_it_ < end && seq_it_ < seq_end_;) {
        uint32_t next = prev_ + *run_it_ + 1;
        if (next < *seq_it_) {
            // Run value missing from the sequence - skip it.
            prev_ = next;
            run_it_ += 1;
            continue;
        }
        if (*seq_it_ == next) {
            // Common element - emit it and consume the run value.
            *seq_out_++ = *seq_it_;
            prev_ = next;
            run_it_ += 1;
        }
        // Reached for both "equal" and "sequence value is smaller".
        seq_it_++;
    }
    return true;
}

// Does the intersection in batches of 8 bytes at once, for as long as more
// than 8 bytes of the run are left. Whenever the batch step refuses to run
// (a multi-byte integer is in the window) a single integer is processed by
// the slow path instead. The remaining tail is left for the caller.
void IntersectionHelper::intersect_by_8() {
    constexpr int BATCH_SIZE = 8;
    // Compare the remaining length instead of computing `run_end_ - 8`: for a
    // run shorter than 8 bytes (or an empty one, where data() may be null)
    // that expression forms a pointer before the start of the buffer, which
    // is undefined behaviour and may let the loop read out of bounds.
    while (run_end_ - run_it_ > BATCH_SIZE && seq_it_ < seq_end_) {
        if (!step_by_8()) {
            step_single();
        }
    }
}

// Hand-tuned equivalent of std::set_intersection between the compressed run
// and the sequence (this can account for almost 50% of the total time).
// The bulk of the run is processed in 8-byte batches; whatever remains is
// finished one integer at a time.
void IntersectionHelper::intersect() {
    intersect_by_8();
    for (;;) {
        const bool run_exhausted = run_it_ >= run_end_;
        if (run_exhausted || seq_it_ >= seq_end_) {
            break;
        }
        step_single();
    }
}

void SortedRun::do_and(SortedRun &other) {
// Benchmarking shows that handling a situation where this->is_compressed()
// makes the code *slower*. I assume that's because of memory efficiency.
decompress();
std::vector<uint32_t>::iterator new_end;
if (other.is_compressed()) {
new_end = std::set_intersection(other.comp_begin(), other.comp_end(),
begin(), end(), begin());
IntersectionHelper helper(&sequence_, &other.run_);
helper.intersect();
new_end = begin() + helper.result_size();
} else {
new_end = std::set_intersection(other.begin(), other.end(), begin(),
end(), begin());
}
sequence_.erase(new_end, sequence_.end());
sequence_.erase(new_end, end());
}

void SortedRun::decompress() {
Expand Down
55 changes: 31 additions & 24 deletions libursa/SortedRun.h
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
#include <emmintrin.h>

#include "Core.h"

// Iterate over a compressed run representation.
// "Run" here means a sorted list of FileIDs (this name is used in the
// codebase). And a "compressed" run format is described in the documentation
// "ondiskformat.md", in the "Index" section.
class RunIterator : public std::iterator<std::forward_iterator_tag, uint32_t> {
typedef RunIterator iterator;
uint8_t *pos_;
uint32_t run_read(uint8_t *pos);
uint8_t *run_forward(uint8_t *pos);

class IntersectionHelper {
uint8_t *run_it_;
uint8_t *run_end_;
int32_t prev_;
uint32_t *seq_start_;
uint32_t *seq_it_;
uint32_t *seq_end_;
uint32_t *seq_out_;

uint32_t current() const;
uint8_t *nextpos();
bool step_by_8();
void step_single();
void intersect_by_8();

public:
RunIterator(uint8_t *run) : pos_(run), prev_(-1) {}
~RunIterator() {}

RunIterator &operator++() {
prev_ = current();
pos_ = nextpos();
return *this;
}

uint32_t operator*() const { return current(); }
bool operator!=(const iterator &rhs) const { return pos_ != rhs.pos_; }
IntersectionHelper(std::vector<uint32_t> *seq, std::vector<uint8_t> *run)
: run_it_(run->data()),
run_end_(run->data() + run->size()),
prev_(-1),
seq_start_(seq->data()),
seq_it_(seq->data()),
seq_end_(seq->data() + seq->size()),
seq_out_(seq->data()) {}

size_t result_size() const { return seq_out_ - seq_start_; }
void intersect();
};

// This class represents a "run" - a sorted list of FileIDs. This can be
Expand Down Expand Up @@ -52,10 +58,6 @@ class SortedRun {
std::vector<uint32_t>::iterator begin();
std::vector<uint32_t>::iterator end();

// Iterate over the compressed representation (throws if decompressed)
RunIterator comp_begin();
RunIterator comp_end();

SortedRun(const SortedRun &other) = default;

public:
Expand All @@ -72,11 +74,16 @@ class SortedRun {
: sequence_(other.sequence_), run_(other.run_) {}
SortedRun &operator=(SortedRun &&) = default;

// Returns true when the run holds no elements, i.e. when both the sequence_
// vector and the run_ byte buffer are empty.
bool empty() const {
    if (!sequence_.empty()) {
        return false;
    }
    return run_.empty();
}

// Does the OR operation with the other vector, overwrites this object.
void do_or(SortedRun &other);

// Does the AND operation with the other vector, overwrites this object.
void do_and(SortedRun &other);

// Does the MIN_OF operation on specified operands. Allocates a new result.
static SortedRun pick_common(int cutoff, std::vector<SortedRun *> &sources);

// When you really need to clone the run - TODO remove.
Expand Down
2 changes: 1 addition & 1 deletion libursa/Version.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ constexpr std::string_view ursadb_format_version = "1.5.0";
// Project version.
// Consider updating the version tag when doing PRs.
// clang-format off
constexpr std::string_view ursadb_version_string = "@PROJECT_VERSION@+opt8";
constexpr std::string_view ursadb_version_string = "@PROJECT_VERSION@+opt9";
// clang-format on

0 comments on commit c2ffdea

Please sign in to comment.