diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 51fa452dd..4e0632c2c 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -103,7 +103,7 @@ constexpr const char* BM25_K1 = "bm25_k1"; constexpr const char* BM25_B = "bm25_b"; // average document length constexpr const char* BM25_AVGDL = "bm25_avgdl"; -constexpr const char* WAND_BM25_MAX_SCORE_RATIO = "wand_bm25_max_score_ratio"; +constexpr const char* DIM_MAX_SCORE_RATIO = "dim_max_score_ratio"; }; // namespace meta namespace indexparam { diff --git a/include/knowhere/sparse_utils.h b/include/knowhere/sparse_utils.h index be069c559..9f47df999 100644 --- a/include/knowhere/sparse_utils.h +++ b/include/knowhere/sparse_utils.h @@ -29,6 +29,11 @@ namespace knowhere::sparse { +enum class SparseMetricType { + METRIC_IP = 1, + METRIC_BM25 = 2, +}; + // integer type in SparseRow using table_t = uint32_t; // type used to represent the id of a vector in the index interface. diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index 7fdd9f905..eb9e84858 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -95,18 +95,31 @@ class SparseInvertedIndexNode : public IndexNode { LOG_KNOWHERE_ERROR_ << "Could not search empty " << Type(); return expected::Err(Status::empty_index, "index not loaded"); } + auto cfg = static_cast(*config); + auto computer_or = index_->GetDocValueComputer(cfg); if (!computer_or.has_value()) { return expected::Err(computer_or.error(), computer_or.what()); } auto computer = computer_or.value(); - auto nq = dataset->GetRows(); - auto queries = static_cast*>(dataset->GetTensor()); - auto k = cfg.k.value(); - auto refine_factor = cfg.refine_factor.value_or(1); + auto dim_max_score_ratio = cfg.dim_max_score_ratio.value(); auto drop_ratio_search = cfg.drop_ratio_search.value_or(0.0f); + auto refine_factor = cfg.refine_factor.value_or(1); + // if no data was dropped during search, no refinement is needed. + if (drop_ratio_search == 0) { + refine_factor = 1; + } + + sparse::InvertedIndexApproxSearchParams approx_params = { + .refine_factor = refine_factor, + .drop_ratio_search = drop_ratio_search, + .dim_max_score_ratio = dim_max_score_ratio, + }; + auto queries = static_cast*>(dataset->GetTensor()); + auto nq = dataset->GetRows(); + auto k = cfg.k.value(); auto p_id = std::make_unique(nq * k); auto p_dist = std::make_unique(nq * k); @@ -114,8 +127,7 @@ class SparseInvertedIndexNode : public IndexNode { futs.reserve(nq); for (int64_t idx = 0; idx < nq; ++idx) { futs.emplace_back(search_pool_->push([&, idx = idx, p_id = p_id.get(), p_dist = p_dist.get()]() { - index_->Search(queries[idx], k, drop_ratio_search, p_dist + idx * k, p_id + idx * k, refine_factor, - bitset, computer); + index_->Search(queries[idx], k, p_dist + idx * k, p_id + idx * k, bitset, computer, approx_params); })); } WaitAllSuccess(futs); @@ -359,21 +371,21 @@ class SparseInvertedIndexNode : public IndexNode { auto k1 = cfg.bm25_k1.value(); auto b = cfg.bm25_b.value(); auto avgdl = cfg.bm25_avgdl.value(); - auto max_score_ratio = cfg.wand_bm25_max_score_ratio.value(); + if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") { - auto index = - new sparse::InvertedIndex(); - index->SetBM25Params(k1, b, avgdl, max_score_ratio); + auto index = new sparse::InvertedIndex( + sparse::SparseMetricType::METRIC_BM25); + index->SetBM25Params(k1, b, avgdl); return index; } else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") { - auto index = - new sparse::InvertedIndex(); - index->SetBM25Params(k1, b, avgdl, max_score_ratio); + auto index = new sparse::InvertedIndex( + sparse::SparseMetricType::METRIC_BM25); + index->SetBM25Params(k1, b, avgdl); return index; } else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") { - auto index = - new sparse::InvertedIndex(); - index->SetBM25Params(k1, b, avgdl, max_score_ratio); + auto index = new sparse::InvertedIndex( + sparse::SparseMetricType::METRIC_BM25); + index->SetBM25Params(k1, b, avgdl); return index; } else { return expected*>::Err(Status::invalid_args, @@ -381,14 +393,16 @@ class SparseInvertedIndexNode : public IndexNode { } } else { if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") { - auto index = new sparse::InvertedIndex(); + auto index = new sparse::InvertedIndex( + sparse::SparseMetricType::METRIC_IP); return index; } else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") { - auto index = - new sparse::InvertedIndex(); + auto index = new sparse::InvertedIndex( + sparse::SparseMetricType::METRIC_IP); return index; } else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") { - auto index = new sparse::InvertedIndex(); + auto index = new sparse::InvertedIndex( + sparse::SparseMetricType::METRIC_IP); return index; } else { return expected*>::Err(Status::invalid_args, diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h index 4544aeff0..4e7822464 100644 --- a/src/index/sparse/sparse_inverted_index.h +++ b/src/index/sparse/sparse_inverted_index.h @@ -41,6 +41,12 @@ enum class InvertedIndexAlgo { DAAT_MAXSCORE, }; +struct InvertedIndexApproxSearchParams { + int refine_factor; + float drop_ratio_search; + float dim_max_score_ratio; +}; + template class BaseInvertedIndex { public: @@ -49,7 +55,7 @@ class BaseInvertedIndex { virtual Status Save(MemoryIOWriter& writer) = 0; - // supplement_target_filename: when in mmap mode, we need an extra file to store the mmaped index data structure. + // supplement_target_filename: when in mmap mode, we need an extra file to store the mmapped index data structure. // this file will be created during loading and deleted in the destructor. virtual Status Load(MemoryIOReader& reader, int map_flags, const std::string& supplement_target_filename) = 0; @@ -61,8 +67,8 @@ class BaseInvertedIndex { Add(const SparseRow* data, size_t rows, int64_t dim) = 0; virtual void - Search(const SparseRow& query, size_t k, float drop_ratio_search, float* distances, label_t* labels, - size_t refine_factor, const BitsetView& bitset, const DocValueComputer& computer) const = 0; + Search(const SparseRow& query, size_t k, float* distances, label_t* labels, const BitsetView& bitset, + const DocValueComputer& computer, InvertedIndexApproxSearchParams& approx_params) const = 0; virtual std::vector GetAllDistances(const SparseRow& query, float drop_ratio_search, const BitsetView& bitset, @@ -84,10 +90,10 @@ class BaseInvertedIndex { n_cols() const = 0; }; -template +template class InvertedIndex : public BaseInvertedIndex { public: - explicit InvertedIndex() { + explicit InvertedIndex(SparseMetricType metric_type) : metric_type_(metric_type) { } ~InvertedIndex() override { @@ -112,15 +118,15 @@ class InvertedIndex : public BaseInvertedIndex { using Vector = std::conditional_t, std::vector>; void - SetBM25Params(float k1, float b, float avgdl, float max_score_ratio) { - bm25_params_ = std::make_unique(k1, b, avgdl, max_score_ratio); + SetBM25Params(float k1, float b, float avgdl) { + bm25_params_ = std::make_unique(k1, b, avgdl); } expected> GetDocValueComputer(const SparseInvertedIndexConfig& cfg) const override { // if metric_type is set in config, it must match with how the index was built. auto metric_type = cfg.metric_type; - if constexpr (!bm25) { + if (metric_type_ != SparseMetricType::METRIC_BM25) { if (metric_type.has_value() && !IsMetricType(metric_type.value(), metric::IP)) { auto msg = "metric type not match, expected: " + std::string(metric::IP) + ", got: " + metric_type.value(); @@ -239,7 +245,7 @@ class InvertedIndex : public BaseInvertedIndex { if constexpr (mmapped) { RETURN_IF_ERROR(PrepareMmap(reader, rows, map_flags, supplement_target_filename)); } else { - if constexpr (bm25) { + if (metric_type_ == SparseMetricType::METRIC_BM25) { bm25_params_->row_sums.reserve(rows); } } @@ -304,7 +310,7 @@ class InvertedIndex : public BaseInvertedIndex { if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { map_byte_size_ += max_score_in_dim_byte_size; } - if constexpr (bm25) { + if (metric_type_ == SparseMetricType::METRIC_BM25) { row_sums_byte_size = rows * sizeof(typename decltype(bm25_params_->row_sums)::value_type); map_byte_size_ += row_sums_byte_size; } @@ -355,7 +361,7 @@ class InvertedIndex : public BaseInvertedIndex { ptr += max_score_in_dim_byte_size; } - if constexpr (bm25) { + if (metric_type_ == SparseMetricType::METRIC_BM25) { bm25_params_->row_sums.initialize(ptr, row_sums_byte_size); ptr += row_sums_byte_size; } @@ -407,7 +413,7 @@ class InvertedIndex : public BaseInvertedIndex { max_dim_ = dim; } - if constexpr (bm25) { + if (metric_type_ == SparseMetricType::METRIC_BM25) { bm25_params_->row_sums.reserve(current_rows + rows); } for (size_t i = 0; i < rows; ++i) { @@ -420,8 +426,8 @@ class InvertedIndex : public BaseInvertedIndex { } void - Search(const SparseRow& query, size_t k, float drop_ratio_search, float* distances, label_t* labels, - size_t refine_factor, const BitsetView& bitset, const DocValueComputer& computer) const override { + Search(const SparseRow& query, size_t k, float* distances, label_t* labels, const BitsetView& bitset, + const DocValueComputer& computer, InvertedIndexApproxSearchParams& approx_params) const override { // initially set result distances to NaN and labels to -1 std::fill(distances, distances + k, std::numeric_limits::quiet_NaN()); std::fill(labels, labels + k, -1); @@ -429,36 +435,25 @@ class InvertedIndex : public BaseInvertedIndex { return; } - std::vector values(query.size()); - for (size_t i = 0; i < query.size(); ++i) { - values[i] = std::abs(query[i].val); - } - auto q_threshold = get_threshold(values, drop_ratio_search); - - // if no data was dropped during search, no refinement is needed. - if (drop_ratio_search == 0) { - refine_factor = 1; - } - - auto q_vec = parse_query(query, q_threshold); + auto q_vec = parse_query(query, approx_params.drop_ratio_search); if (q_vec.empty()) { return; } - MaxMinHeap heap(k * refine_factor); + MaxMinHeap heap(k * approx_params.refine_factor); // DAAT_WAND and DAAT_MAXSCORE are based on the implementation in PISA. if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) { - search_daat_wand(q_vec, heap, bitset, computer); + search_daat_wand(q_vec, heap, bitset, computer, approx_params.dim_max_score_ratio); } else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) { - search_daat_maxscore(q_vec, heap, bitset, computer); + search_daat_maxscore(q_vec, heap, bitset, computer, approx_params.dim_max_score_ratio); } else { search_taat_naive(q_vec, heap, bitset, computer); } - if (refine_factor == 1) { + if (approx_params.refine_factor == 1) { collect_result(heap, distances, labels); } else { - refine_and_collect(query, heap, k, distances, labels, computer); + refine_and_collect(query, heap, k, distances, labels, computer, approx_params); } } @@ -473,8 +468,7 @@ class InvertedIndex : public BaseInvertedIndex { for (size_t i = 0; i < query.size(); ++i) { values[i] = std::abs(query[i].val); } - auto q_threshold = get_threshold(values, drop_ratio_search); - auto q_vec = parse_query(query, q_threshold); + auto q_vec = parse_query(query, drop_ratio_search); auto distances = compute_all_distances(q_vec, computer); if (!bitset.empty()) { @@ -502,8 +496,10 @@ class InvertedIndex : public BaseInvertedIndex { auto it = std::lower_bound(plist_ids.begin(), plist_ids.end(), vec_id, [](const auto& x, table_t y) { return x < y; }); if (it != plist_ids.end() && *it == vec_id) { - distance += val * computer(inverted_index_vals_[dim_it->second][it - plist_ids.begin()], - bm25 ? bm25_params_->row_sums.at(vec_id) : 0); + distance += + val * + computer(inverted_index_vals_[dim_it->second][it - plist_ids.begin()], + metric_type_ == SparseMetricType::METRIC_BM25 ? bm25_params_->row_sums.at(vec_id) : 0); } } @@ -573,7 +569,7 @@ class InvertedIndex : public BaseInvertedIndex { // TODO: improve with SIMD for (size_t j = 0; j < plist_ids.size(); ++j) { auto doc_id = plist_ids[j]; - float val_sum = bm25 ? bm25_params_->row_sums.at(doc_id) : 0; + float val_sum = metric_type_ == SparseMetricType::METRIC_BM25 ? bm25_params_->row_sums.at(doc_id) : 0; scores[doc_id] += q_vec[i].second * computer(plist_vals[j], val_sum); } } @@ -644,7 +640,16 @@ class InvertedIndex : public BaseInvertedIndex { }; // struct Cursor std::vector> - parse_query(const SparseRow& query, DType q_threshold) const { + parse_query(const SparseRow& query, float drop_ratio_search) const { + DType q_threshold = 0; + if (drop_ratio_search != 0) { + std::vector values(query.size()); + for (size_t i = 0; i < query.size(); ++i) { + values[i] = std::abs(query[i].val); + } + q_threshold = get_threshold(values, drop_ratio_search); + } + std::vector> filtered_query; for (size_t i = 0; i < query.size(); ++i) { auto [dim, val] = query[i]; @@ -654,20 +659,22 @@ class InvertedIndex : public BaseInvertedIndex { } filtered_query.emplace_back(dim_it->second, val); } + return filtered_query; } template std::vector> make_cursors(const std::vector>& q_vec, const DocValueComputer& computer, - DocIdFilter& filter) const { + DocIdFilter& filter, float dim_max_score_ratio) const { std::vector> cursors; cursors.reserve(q_vec.size()); for (auto q_dim : q_vec) { auto& plist_ids = inverted_index_ids_[q_dim.first]; auto& plist_vals = inverted_index_vals_[q_dim.first]; - cursors.emplace_back(plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[q_dim.first] * q_dim.second, - q_dim.second, filter); + cursors.emplace_back(plist_ids, plist_vals, n_rows_internal_, + max_score_in_dim_[q_dim.first] * q_dim.second * dim_max_score_ratio, q_dim.second, + filter); } return cursors; } @@ -690,8 +697,8 @@ class InvertedIndex : public BaseInvertedIndex { template void search_daat_wand(const std::vector>& q_vec, MaxMinHeap& heap, DocIdFilter& filter, - const DocValueComputer& computer) const { - std::vector> cursors = make_cursors(q_vec, computer, filter); + const DocValueComputer& computer, float dim_max_score_ratio) const { + std::vector> cursors = make_cursors(q_vec, computer, filter, dim_max_score_ratio); std::vector*> cursor_ptrs(cursors.size()); for (size_t i = 0; i < cursors.size(); ++i) { cursor_ptrs[i] = &cursors[i]; @@ -726,7 +733,8 @@ class InvertedIndex : public BaseInvertedIndex { table_t pivot_id = cursor_ptrs[pivot]->cur_vec_id_; if (pivot_id == cursor_ptrs[0]->cur_vec_id_) { float score = 0; - float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(pivot_id) : 0; + float cur_vec_sum = + metric_type_ == SparseMetricType::METRIC_BM25 ? bm25_params_->row_sums.at(pivot_id) : 0; for (auto& cursor_ptr : cursor_ptrs) { if (cursor_ptr->cur_vec_id_ != pivot_id) { break; @@ -754,12 +762,12 @@ class InvertedIndex : public BaseInvertedIndex { template void search_daat_maxscore(std::vector>& q_vec, MaxMinHeap& heap, DocIdFilter& filter, - const DocValueComputer& computer) const { + const DocValueComputer& computer, float dim_max_score_ratio) const { std::sort(q_vec.begin(), q_vec.end(), [this](auto& a, auto& b) { return a.second * max_score_in_dim_[a.first] > b.second * max_score_in_dim_[b.first]; }); - std::vector> cursors = make_cursors(q_vec, computer, filter); + std::vector> cursors = make_cursors(q_vec, computer, filter, dim_max_score_ratio); float threshold = heap.full() ? heap.top().val : 0; @@ -802,7 +810,8 @@ class InvertedIndex : public BaseInvertedIndex { curr_cand_score = 0.0f; // update next_cand_vec_id next_cand_vec_id = n_rows_internal_; - float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(curr_cand_vec_id) : 0; + float cur_vec_sum = + metric_type_ == SparseMetricType::METRIC_BM25 ? bm25_params_->row_sums.at(curr_cand_vec_id) : 0; for (size_t i = 0; i < first_ne_idx; ++i) { if (cursors[i].cur_vec_id_ == curr_cand_vec_id) { @@ -842,7 +851,8 @@ class InvertedIndex : public BaseInvertedIndex { void refine_and_collect(const SparseRow& query, MaxMinHeap& inacc_heap, size_t k, float* distances, - label_t* labels, const DocValueComputer& computer) const { + label_t* labels, const DocValueComputer& computer, + InvertedIndexApproxSearchParams& approx_params) const { std::vector docids; MaxMinHeap heap(k); @@ -857,11 +867,14 @@ class InvertedIndex : public BaseInvertedIndex { return; } + // dim_max_score_ratio for refine process should be >= 1.0 + float dim_max_score_ratio = std::max(approx_params.dim_max_score_ratio, 1.0f); + DocIdFilterByVector filter(std::move(docids)); if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) { - search_daat_wand(q_vec, heap, filter, computer); + search_daat_wand(q_vec, heap, filter, computer, dim_max_score_ratio); } else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) { - search_daat_maxscore(q_vec, heap, filter, computer); + search_daat_maxscore(q_vec, heap, filter, computer, dim_max_score_ratio); } else { search_taat_naive(q_vec, heap, filter, computer); } @@ -884,7 +897,7 @@ class InvertedIndex : public BaseInvertedIndex { [[maybe_unused]] float row_sum = 0; for (size_t j = 0; j < row.size(); ++j) { auto [dim, val] = row[j]; - if constexpr (bm25) { + if (metric_type_ == SparseMetricType::METRIC_BM25) { row_sum += val; } // Skip values equals to or close enough to zero(which contributes @@ -895,7 +908,7 @@ class InvertedIndex : public BaseInvertedIndex { auto dim_it = dim_map_.find(dim); if (dim_it == dim_map_.cend()) { if constexpr (mmapped) { - throw std::runtime_error("unexpected vector dimension in mmaped InvertedIndex"); + throw std::runtime_error("unexpected vector dimension in mmapped InvertedIndex"); } dim_it = dim_map_.insert({dim, next_dim_id_++}).first; inverted_index_ids_.emplace_back(); @@ -908,13 +921,13 @@ class InvertedIndex : public BaseInvertedIndex { inverted_index_vals_[dim_it->second].emplace_back(get_quant_val(val)); if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { auto score = static_cast(val); - if constexpr (bm25) { - score = bm25_params_->max_score_ratio * bm25_params_->wand_max_score_computer(val, row_sum); + if (metric_type_ == SparseMetricType::METRIC_BM25) { + score = bm25_params_->max_score_computer(val, row_sum); } max_score_in_dim_[dim_it->second] = std::max(max_score_in_dim_[dim_it->second], score); } } - if constexpr (bm25) { + if (metric_type_ == SparseMetricType::METRIC_BM25) { bm25_params_->row_sums.emplace_back(row_sum); } } @@ -943,6 +956,8 @@ class InvertedIndex : public BaseInvertedIndex { Vector> inverted_index_vals_; Vector max_score_in_dim_; + SparseMetricType metric_type_; + size_t n_rows_internal_ = 0; size_t max_dim_ = 0; uint32_t next_dim_id_ = 0; @@ -958,15 +973,10 @@ class InvertedIndex : public BaseInvertedIndex { // corresponds to the document length of each doc in the BM25 formula. Vector row_sums; - // below are used only for DAAT_WAND and DAAT_MAXSCORE algorithms. - float max_score_ratio; - DocValueComputer wand_max_score_computer; + DocValueComputer max_score_computer; - BM25Params(float k1, float b, float avgdl, float max_score_ratio) - : k1(k1), - b(b), - max_score_ratio(max_score_ratio), - wand_max_score_computer(GetDocValueBM25Computer(k1, b, avgdl)) { + BM25Params(float k1, float b, float avgdl) + : k1(k1), b(b), max_score_computer(GetDocValueBM25Computer(k1, b, avgdl)) { } }; // struct BM25Params diff --git a/src/index/sparse/sparse_inverted_index_config.h b/src/index/sparse/sparse_inverted_index_config.h index 3b2fbd1be..f92b8a03c 100644 --- a/src/index/sparse/sparse_inverted_index_config.h +++ b/src/index/sparse/sparse_inverted_index_config.h @@ -22,7 +22,7 @@ class SparseInvertedIndexConfig : public BaseConfig { CFG_FLOAT drop_ratio_build; CFG_FLOAT drop_ratio_search; CFG_INT refine_factor; - CFG_FLOAT wand_bm25_max_score_ratio; + CFG_FLOAT dim_max_score_ratio; CFG_STRING inverted_index_algo; KNOHWERE_DECLARE_CONFIG(SparseInvertedIndexConfig) { // NOTE: drop_ratio_build has been deprecated, it won't change anything @@ -57,7 +57,7 @@ class SparseInvertedIndexConfig : public BaseConfig { * WAND algorithm uses the max score of each dim for pruning, which is * precomputed and cached in our implementation. The cached max score * is actually not equal to the actual max score. Instead, it is a - * scaled one based on the wand_bm25_max_score_ratio. + * scaled one based on the dim_max_score_ratio. * We should use different scale strategy for different reasons. * 1. As more documents being added to the collection, avgdl could * be changed. Re-computing such score for each segment is @@ -66,29 +66,45 @@ class SparseInvertedIndexConfig : public BaseConfig { * This will make the cached max score larger than the actual max * score, so that it makes the filtering less aggressive, but * guarantees the correctness. - * 2. In WAND searching process, we use the sum of the max scores to - * filter the candidate vectors. If the sum is smaller than the - * threshold, skip current vector. If approximate searching is - * accepted, we can make the skipping more aggressive by downscaling - * the max score with a ratio less than 1.0. Since the possibility - * that the max score of all dims in the query appears on the same - * vector is relatively small, it won't lead to a sharp decline in - * the recall rate within a certain range. + * 2. For dimension maxscore based algorithms like WAND and MaxScore, + * they use the sum of the max scores to filter the candidate + * vectors. If the sum is smaller than the threshold, skip current + * vector. If approximate searching is accepted, we can make the + * skipping more aggressive by downscaling the max score with a + * ratio less than 1.0. Since the possibility that the max score + * of all dims in the query appears on the same vector is + * relatively small, it won't lead to a sharp decline in the + * recall rate within a certain range. */ - KNOWHERE_CONFIG_DECLARE_FIELD(wand_bm25_max_score_ratio) + KNOWHERE_CONFIG_DECLARE_FIELD(dim_max_score_ratio) .set_range(0.5, 1.3) .set_default(1.05) .description("ratio to upscale/downscale the max score of each dimension") - .for_train_and_search() - .for_deserialize() - .for_deserialize_from_file(); + .for_search(); KNOWHERE_CONFIG_DECLARE_FIELD(inverted_index_algo) .description("inverted index algorithm") .set_default("DAAT_MAXSCORE") - .for_train_and_search() + .for_train() .for_deserialize() .for_deserialize_from_file(); } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + constexpr std::array legal_inverted_index_algo_list{"TAAT_NAIVE", "DAAT_WAND", + "DAAT_MAXSCORE"}; + std::string inverted_index_algo_str = inverted_index_algo.value_or(""); + if (std::find(legal_inverted_index_algo_list.begin(), legal_inverted_index_algo_list.end(), + inverted_index_algo_str) == legal_inverted_index_algo_list.end()) { + std::string msg = "sparse inverted index algo " + inverted_index_algo_str + + " not found or not supported, supported: [TAAT_NAIVE DAAT_WAND DAAT_MAXSCORE]"; + return HandleError(err_msg, msg, Status::invalid_args); + } + } + + return Status::success; + } }; // class SparseInvertedIndexConfig } // namespace knowhere