Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add support of logical merge in Cagra #713

Open
wants to merge 12 commits into
base: branch-25.04
Choose a base branch
from
Open
204 changes: 200 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ struct extend_params {
/**
* @brief Determines the strategy for merging CAGRA graphs.
*
* @note Currently, only the PHYSICAL strategy is supported.
*/
enum MergeStrategy {
/**
Expand All @@ -286,9 +285,16 @@ enum MergeStrategy {
* This is expensive to build but does not impact search latency or quality.
* Preferred for many smaller CAGRA graphs.
*
* @note Currently, this is the only supported strategy.
*/
PHYSICAL
PHYSICAL,
/**
* @brief Logical merge: Wraps a new index structure around existing CAGRA graphs
* and broadcasts the query to each of them.
*
* This is a fast merge but incurs a small hit in search latency.
* Preferred for fewer larger CAGRA graphs.
*/
LOGICAL
};

/**
Expand Down Expand Up @@ -565,6 +571,82 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
};
/**
* @}
*/

/**
* @defgroup cagra_cpp_composite_index CAGRA composite index type
* @{
*/

/**
* @brief Lightweight composite kNN index for CAGRA.
*
* This class aggregates logically multiple CAGRA indices into a single composite index,
* providing a unified interface for kNN search. It is a lightweight structure
* that does not own or manage the lifecycle of the underlying indices; instead,
* it holds non-owning pointers to them.
*
* All sub-indices within the composite index **must share the same distance metric
* and dimensionality**.
*
* @tparam T Data element type.
* @tparam IdxT Index type representing dataset.extent(0), used for vector indices.
*/

template <typename T, typename IdxT>
struct composite_index {
template <typename Container>
explicit composite_index(Container&& indices) : sub_indices(std::forward<Container>(indices))
{
RAFT_EXPECTS(!sub_indices.empty(), "composite_index requires at least one sub-index.");

for (auto* idx : sub_indices) {
RAFT_EXPECTS(idx != nullptr, "sub_indices contains a null pointer.");
}

auto& first_index = *sub_indices.front();
metric_ = first_index.metric();
dim_ = first_index.dim();
size_ = 0;

for (auto* idx : sub_indices) {
RAFT_EXPECTS(idx->metric() == metric_, "All sub-indices must have the same metric.");
RAFT_EXPECTS(idx->dim() == dim_, "All sub-indices must have the same dim.");
size_ += idx->size();
}
}

public:
composite_index(const composite_index& other) = default;
composite_index& operator=(const composite_index& other) = default;

composite_index(composite_index&& other) noexcept = default;
composite_index& operator=(composite_index&& other) noexcept = default;

constexpr inline auto metric() const noexcept -> cuvs::distance::DistanceType { return metric_; }

constexpr inline auto size() const noexcept -> IdxT { return size_; }

constexpr inline auto dim() const noexcept -> uint32_t { return dim_; }

constexpr inline auto graph_degree() const noexcept -> uint32_t
{
return sub_indices.front()->graph_degree();
}

constexpr inline auto num_indices() const noexcept -> uint32_t { return sub_indices.size(); }

public:
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> sub_indices;

private:
cuvs::distance::DistanceType metric_;
IdxT size_;
uint32_t dim_;
};

/**
* @}
*/
Expand Down Expand Up @@ -1125,7 +1207,6 @@ void extend(
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/

void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
Expand Down Expand Up @@ -1209,7 +1290,105 @@ void search(raft::resources const& res,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the composite cagra index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
Copy link
Member Author

@rhdong rhdong Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @cjnolet @achirkin , I had to add a new search for composite_index, the root cause is the cuve::neighbors::index is not designed to be an abstract virtual class, which makes many pathways impossible, like inheriting from index, which can avoid declaring a standalone composible_index and a new search API. I was trying to refactor it, but I found it was used widely in different algo codes than give up under a tight timeline; maybe we can make it in the long term.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's okay that we don't require a strict class hierarchy initially, so long as we have proper overloads so that from a user's perspective, they can call auto index = cuvs::neighbors::cagra::merge(...) and get back a proper object from which they can later call cuvs::neighbors::cagra::search(..., index,..)

cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<float, uint32_t>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the composite cagra index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<half, uint32_t>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the composite cagra index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<int8_t, uint32_t>& index,
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the composite cagra index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<uint8_t, uint32_t>& index,
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @}
*/
Expand Down Expand Up @@ -1985,6 +2164,23 @@ auto merge(raft::resources const& res,
const cuvs::neighbors::cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<uint8_t, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::index<uint8_t, uint32_t>;

auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<float, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<float, uint32_t>;

auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<half, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<half, uint32_t>;

auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<int8_t, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<int8_t, uint32_t>;

auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<uint8_t, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<uint8_t, uint32_t>;

/**
* @}
*/
Expand Down
32 changes: 31 additions & 1 deletion cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ void search(raft::resources const& res,
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, none_filter_type>(
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
return;
} catch (const std::bad_cast&) {
}

Expand All @@ -371,6 +370,27 @@ void search(raft::resources const& res,
}
}

template <typename T, typename IdxT>
void search(raft::resources const& res,
const search_params& params,
const composite_index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
{
try {
using expected_filter_t = cuvs::neighbors::filtering::none_sample_filter;

auto& sample_filter = dynamic_cast<const expected_filter_t&>(sample_filter_ref);
auto sample_filter_copy = sample_filter;
return cagra::detail::search_on_composite_index<T, IdxT, expected_filter_t>(
res, params, idx, queries, neighbors, distances, sample_filter_copy);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type by composite_index");
}
}

template <class T, class IdxT, class Accessor>
void extend(
raft::resources const& handle,
Expand All @@ -391,6 +411,16 @@ index<T, IdxT> merge(raft::resources const& handle,
return cagra::detail::merge<T, IdxT>(handle, params, indices);
}

template <class T, class IdxT>
composite_index<T, IdxT> make_composite_index(const cagra::merge_params& params,
std::vector<index<T, IdxT>*>& indices)
{
if (params.strategy != cagra::MergeStrategy::LOGICAL) {
RAFT_LOG_WARN("Merge strategy should be LOGICAL.");
}
return composite_index<T, IdxT>(std::move(indices));
}

/** @} */ // end group cagra

} // namespace cuvs::neighbors::cagra
20 changes: 13 additions & 7 deletions cpp/src/neighbors/cagra_merge_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

namespace cuvs::neighbors::cagra {

#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
}; \
auto make_composite_index(const cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::composite_index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::make_composite_index<T, IdxT>(params, indices); \
}

RAFT_INST_CAGRA_MERGE(float, uint32_t);
Expand Down
20 changes: 13 additions & 7 deletions cpp/src/neighbors/cagra_merge_half.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

namespace cuvs::neighbors::cagra {

#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
}; \
auto make_composite_index(const cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::composite_index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::make_composite_index<T, IdxT>(params, indices); \
}

RAFT_INST_CAGRA_MERGE(half, uint32_t);
Expand Down
20 changes: 13 additions & 7 deletions cpp/src/neighbors/cagra_merge_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

namespace cuvs::neighbors::cagra {

#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
}; \
auto make_composite_index(const cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::composite_index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::make_composite_index<T, IdxT>(params, indices); \
}

RAFT_INST_CAGRA_MERGE(int8_t, uint32_t);
Expand Down
Loading