Skip to content

Commit

Permalink
sparse: refactor approx dimension max score ratio (#1029)
Browse files Browse the repository at this point in the history
* sparse: refactor approx dimension max score ratio

1. Move the dimension max score ratio from build params to search params,
and rename it from `wand_bm25_max_score_ratio` to `dim_max_score_ratio`.

2. Remove template param `bm25` and add a new `SparseMetricType`.

3. Wrap some params of `Search()` into `InvertedIndexApproxSearchParams`.

Signed-off-by: Shawn Wang <[email protected]>

* sparse: override CheckAndAdjust for inverted_index_algo config

Signed-off-by: Shawn Wang <[email protected]>

---------

Signed-off-by: Shawn Wang <[email protected]>
  • Loading branch information
sparknack authored Jan 16, 2025
1 parent 86ca90f commit 7dc867d
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 98 deletions.
2 changes: 1 addition & 1 deletion include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ constexpr const char* BM25_K1 = "bm25_k1";
constexpr const char* BM25_B = "bm25_b";
// average document length
constexpr const char* BM25_AVGDL = "bm25_avgdl";
constexpr const char* WAND_BM25_MAX_SCORE_RATIO = "wand_bm25_max_score_ratio";
constexpr const char* DIM_MAX_SCORE_RATIO = "dim_max_score_ratio";
}; // namespace meta

namespace indexparam {
Expand Down
5 changes: 5 additions & 0 deletions include/knowhere/sparse_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@

namespace knowhere::sparse {

// Metric an inverted index is constructed with; passed to the
// sparse::InvertedIndex constructor when the index is created.
enum class SparseMetricType {
// Used for indexes created without BM25 parameters.
// NOTE(review): "IP" presumably stands for inner product — confirm.
METRIC_IP = 1,
// Used for BM25 indexes; such indexes additionally receive k1/b/avgdl
// through SetBM25Params().
METRIC_BM25 = 2,
};

// integer type in SparseRow
using table_t = uint32_t;
// type used to represent the id of a vector in the index interface.
Expand Down
54 changes: 34 additions & 20 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,39 @@ class SparseInvertedIndexNode : public IndexNode {
LOG_KNOWHERE_ERROR_ << "Could not search empty " << Type();
return expected<DataSetPtr>::Err(Status::empty_index, "index not loaded");
}

auto cfg = static_cast<const SparseInvertedIndexConfig&>(*config);

auto computer_or = index_->GetDocValueComputer(cfg);
if (!computer_or.has_value()) {
return expected<DataSetPtr>::Err(computer_or.error(), computer_or.what());
}
auto computer = computer_or.value();
auto nq = dataset->GetRows();
auto queries = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
auto k = cfg.k.value();
auto refine_factor = cfg.refine_factor.value_or(1);
auto dim_max_score_ratio = cfg.dim_max_score_ratio.value();
auto drop_ratio_search = cfg.drop_ratio_search.value_or(0.0f);
auto refine_factor = cfg.refine_factor.value_or(1);
// if no data was dropped during search, no refinement is needed.
if (drop_ratio_search == 0) {
refine_factor = 1;
}

sparse::InvertedIndexApproxSearchParams approx_params = {
.refine_factor = refine_factor,
.drop_ratio_search = drop_ratio_search,
.dim_max_score_ratio = dim_max_score_ratio,
};

auto queries = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
auto nq = dataset->GetRows();
auto k = cfg.k.value();
auto p_id = std::make_unique<sparse::label_t[]>(nq * k);
auto p_dist = std::make_unique<float[]>(nq * k);

std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int64_t idx = 0; idx < nq; ++idx) {
futs.emplace_back(search_pool_->push([&, idx = idx, p_id = p_id.get(), p_dist = p_dist.get()]() {
index_->Search(queries[idx], k, drop_ratio_search, p_dist + idx * k, p_id + idx * k, refine_factor,
bitset, computer);
index_->Search(queries[idx], k, p_dist + idx * k, p_id + idx * k, bitset, computer, approx_params);
}));
}
WaitAllSuccess(futs);
Expand Down Expand Up @@ -359,36 +371,38 @@ class SparseInvertedIndexNode : public IndexNode {
auto k1 = cfg.bm25_k1.value();
auto b = cfg.bm25_b.value();
auto avgdl = cfg.bm25_avgdl.value();
auto max_score_ratio = cfg.wand_bm25_max_score_ratio.value();

if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") {
auto index =
new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_WAND, true, mmapped>();
index->SetBM25Params(k1, b, avgdl, max_score_ratio);
auto index = new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_WAND, mmapped>(
sparse::SparseMetricType::METRIC_BM25);
index->SetBM25Params(k1, b, avgdl);
return index;
} else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") {
auto index =
new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, true, mmapped>();
index->SetBM25Params(k1, b, avgdl, max_score_ratio);
auto index = new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, mmapped>(
sparse::SparseMetricType::METRIC_BM25);
index->SetBM25Params(k1, b, avgdl);
return index;
} else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") {
auto index =
new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::TAAT_NAIVE, true, mmapped>();
index->SetBM25Params(k1, b, avgdl, max_score_ratio);
auto index = new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::TAAT_NAIVE, mmapped>(
sparse::SparseMetricType::METRIC_BM25);
index->SetBM25Params(k1, b, avgdl);
return index;
} else {
return expected<sparse::BaseInvertedIndex<T>*>::Err(Status::invalid_args,
"Invalid search algorithm for SparseInvertedIndex");
}
} else {
if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") {
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_WAND, false, mmapped>();
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_WAND, mmapped>(
sparse::SparseMetricType::METRIC_IP);
return index;
} else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") {
auto index =
new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, false, mmapped>();
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, mmapped>(
sparse::SparseMetricType::METRIC_IP);
return index;
} else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") {
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::TAAT_NAIVE, false, mmapped>();
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::TAAT_NAIVE, mmapped>(
sparse::SparseMetricType::METRIC_IP);
return index;
} else {
return expected<sparse::BaseInvertedIndex<T>*>::Err(Status::invalid_args,
Expand Down
Loading

0 comments on commit 7dc867d

Please sign in to comment.