Skip to content

Commit

Permalink
Mock register VECTOR_INT8 for all index types (#1037)
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Jan 16, 2025
1 parent a0e3ad9 commit 86ca90f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 9 deletions.
12 changes: 11 additions & 1 deletion include/knowhere/index/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#include "index_static.h"
#include "knowhere/index/index.h"
#include "knowhere/index/index_node_data_mock_wrapper.h"
#include "knowhere/operands.h"
#include "knowhere/utils.h"

namespace knowhere {
Expand Down Expand Up @@ -122,18 +124,25 @@ class IndexFactory {
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);

// register vector index supporting binary data type
#define KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__);

// register vector index supporting int8 data type
#define KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(name, index_node, features, ...) \
#define KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, int8, (features | knowhere::feature::INT8), ##__VA_ARGS__);

// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types
#define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);

// register vector index supporting int data type
#define KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(name, index_node, features, ...) \
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, int8, (features | knowhere::feature::INT8), ##__VA_ARGS__);

// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types, but mocked bf16 and fp16
#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
Expand All @@ -149,6 +158,7 @@ class IndexFactory {
std::make_unique<index_node<MockData<data_type>::type>>(version, object), thread_size)); \
}, \
data_type, typeCheck<data_type>(features), features)

#define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(table_index, name, index_table) \
static int name = []() -> int { \
auto& static_index_table = std::get<table_index>(IndexFactory::StaticIndexTableInstance()); \
Expand Down
7 changes: 7 additions & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,37 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_INT8},

{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_INT8},

// gpu index
{IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT},
Expand Down
4 changes: 4 additions & 0 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,10 @@ KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FLAT, FlatIndexNode,
knowhere::feature::MMAP,
faiss::IndexFlat);

KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(FLAT, FlatIndexNode,
knowhere::feature::NO_TRAIN | knowhere::feature::KNN | knowhere::feature::MMAP,
faiss::IndexFlat);

KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(BINFLAT, FlatIndexNode,
knowhere::feature::NO_TRAIN | knowhere::feature::KNN |
knowhere::feature::MMAP,
Expand Down
16 changes: 8 additions & 8 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2918,21 +2918,21 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED,
#else
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
#endif

KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)

} // namespace knowhere
2 changes: 2 additions & 0 deletions src/index/index_node_data_mock_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,6 @@ IndexNodeDataMockWrapper<DataType>::GetVectorByIds(const DataSetPtr dataset) con

template class knowhere::IndexNodeDataMockWrapper<knowhere::fp16>;
template class knowhere::IndexNodeDataMockWrapper<knowhere::bf16>;
template class knowhere::IndexNodeDataMockWrapper<knowhere::int8>;

} // namespace knowhere
14 changes: 14 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1261,4 +1261,18 @@ KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ8, IvfIndexNode, knowhere::f
KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ_CC, IvfIndexNode, knowhere::feature::NONE,
faiss::IndexIVFScalarQuantizerCC)

// int
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVFFLAT, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFFlat)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVF_FLAT, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFFlat)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVFFLATCC, IvfIndexNode, knowhere::feature::NONE, faiss::IndexIVFFlatCC)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVF_FLAT_CC, IvfIndexNode, knowhere::feature::NONE, faiss::IndexIVFFlatCC)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(SCANN, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexScaNN)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVFPQ, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFPQ)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVF_PQ, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFPQ)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVFSQ, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFScalarQuantizer)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVF_SQ, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFScalarQuantizer)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVF_SQ8, IvfIndexNode, knowhere::feature::MMAP, faiss::IndexIVFScalarQuantizer)
KNOWHERE_MOCK_REGISTER_DENSE_INT_GLOBAL(IVF_SQ_CC, IvfIndexNode, knowhere::feature::NONE,
faiss::IndexIVFScalarQuantizerCC)

} // namespace knowhere
11 changes: 11 additions & 0 deletions tests/ut/test_index_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,48 +31,59 @@ TEST_CASE("Test index and data type check", "[IndexCheckTest]") {
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16));
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_INT8));

// gpu index
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT16));
CHECK_FALSE(
KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_BFLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_BFLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_FLOAT));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_FLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_BFLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_INT8));

CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_FLOAT));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_FLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_BFLOAT16));
CHECK_FALSE(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_INT8));

// HNSW
CHECK(KnowhereCheck::IndexTypeAndDataTypeCheck(IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT));
Expand Down

0 comments on commit 86ca90f

Please sign in to comment.