Skip to content

Commit

Permalink
Isolate faiss_hnsw and hnsw by index version (#952)
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy authored Nov 16, 2024
1 parent 6b7d756 commit 98253b8
Showing 1 changed file with 48 additions and 77 deletions.
125 changes: 48 additions & 77 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1381,20 +1381,33 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplate : public BaseFaissRegularIndexHN
// but a deserialization may override its search behavior.
// It is a concrete implementation's responsibility to initialize BaseIndex and
// FallbackSearchIndex properly.
class IndexNodeWithSearchFallback : public IndexNode {
class HNSWIndexNodeWithFallback : public IndexNode {
public:
IndexNodeWithSearchFallback(const int32_t& version, const Object& object) {
use_base_index = true;
HNSWIndexNodeWithFallback(const int32_t& version, const Object& object) {
constexpr int faiss_hnsw_support_version = 6;
if (version >= faiss_hnsw_support_version) {
use_base_index = true;
} else {
use_base_index = false;
}
}

// Trains whichever underlying index is currently selected: base_index when
// use_base_index is set, fallback_search_index otherwise. The status reported
// by the selected index is returned unchanged.
//
// The selection check must come first: an unconditional forward to base_index
// ahead of it would make the fallback branch unreachable.
Status
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
    if (use_base_index) {
        return base_index->Train(dataset, cfg);
    } else {
        return fallback_search_index->Train(dataset, cfg);
    }
}

// Adds the dataset to whichever underlying index is currently selected:
// base_index when use_base_index is set, fallback_search_index otherwise.
// The status reported by the selected index is returned unchanged.
//
// The selection check must come first: an unconditional forward to base_index
// ahead of it would make the fallback branch unreachable.
Status
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
    if (use_base_index) {
        return base_index->Add(dataset, cfg);
    } else {
        return fallback_search_index->Add(dataset, cfg);
    }
}

expected<DataSetPtr>
Expand All @@ -1408,7 +1421,29 @@ class IndexNodeWithSearchFallback : public IndexNode {

// Serializes whichever underlying index is currently selected into binset:
// base_index when use_base_index is set, fallback_search_index otherwise.
// The status reported by the selected index is returned unchanged.
//
// The selection check must come first: an unconditional forward to base_index
// ahead of it would make the fallback branch unreachable.
Status
Serialize(BinarySet& binset) const override {
    if (use_base_index) {
        return base_index->Serialize(binset);
    } else {
        return fallback_search_index->Serialize(binset);
    }
}

// Loads the index from a binary set by delegating to the selected underlying
// index: base_index when use_base_index is set, fallback_search_index otherwise.
// The delegate's status is returned as is.
Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
    return use_base_index ? base_index->Deserialize(binset, config)
                          : fallback_search_index->Deserialize(binset, config);
}

// Loads the index from a file by delegating to the selected underlying index:
// base_index when use_base_index is set, fallback_search_index otherwise.
// The delegate's status is returned as is.
Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
    return use_base_index ? base_index->DeserializeFromFile(filename, config)
                          : fallback_search_index->DeserializeFromFile(filename, config);
}

int64_t
Expand Down Expand Up @@ -1440,7 +1475,11 @@ class IndexNodeWithSearchFallback : public IndexNode {

// Reports the type name of whichever underlying index is currently selected:
// base_index when use_base_index is set, fallback_search_index otherwise.
//
// The selection check must come first: an unconditional forward to base_index
// ahead of it would make the fallback branch unreachable.
std::string
Type() const override {
    if (use_base_index) {
        return base_index->Type();
    } else {
        return fallback_search_index->Type();
    }
}

bool
Expand Down Expand Up @@ -1494,79 +1533,11 @@ class IndexNodeWithSearchFallback : public IndexNode {
std::unique_ptr<IndexNode> fallback_search_index;
};

// Index node that picks between base_index and fallback_search_index while
// loading: the base index is attempted first, and the fallback index is tried
// only when the base index reports Status::invalid_serialized_index_type (or,
// for binary sets, when no entry is stored under the base index's type name).
// use_base_index is updated only after a successful load.
class BaseFaissRegular2HnswlibIndexNode : public IndexNodeWithSearchFallback {
 public:
    BaseFaissRegular2HnswlibIndexNode(const int32_t& version, const Object& object)
        : IndexNodeWithSearchFallback(version, object) {
    }

    // Loads from a binary set.
    // Returns the base index's status unless that status is
    // Status::invalid_serialized_index_type; in that case, or when the set has
    // no entry named after the base index's type, the fallback index's status
    // is returned. Returns Status::invalid_binary_set when the set has no entry
    // for the fallback index's type either.
    Status
    Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
        // Attempt the base index only if the set carries an entry under its type name.
        BinaryPtr base_blob = binset.GetByName(base_index->Type());
        if (base_blob != nullptr) {
            auto status = base_index->Deserialize(binset, config);
            if (status == Status::success) {
                // the base index now serves requests
                use_base_index = true;
            }

            // Anything other than "wrong serialized type" is final.
            if (status != Status::invalid_serialized_index_type) {
                return status;
            }
        }

        // Neither index recognizes the set: report it as invalid.
        BinaryPtr fallback_blob = binset.GetByName(fallback_search_index->Type());
        if (fallback_blob == nullptr) {
            LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
            return Status::invalid_binary_set;
        }

        LOG_KNOWHERE_INFO_ << "The provided data does not look like a FAISS index. Falling back to hnswlib index.";
        auto status = fallback_search_index->Deserialize(binset, config);
        if (status == Status::success) {
            // the fallback index now serves requests
            use_base_index = false;
        }
        return status;
    }

    // Loads from a file.
    // Returns the base index's status unless that status is
    // Status::invalid_serialized_index_type; in that case the fallback index is
    // tried and its status is returned.
    Status
    DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
        auto status = base_index->DeserializeFromFile(filename, config);
        if (status == Status::success) {
            // the base index now serves requests
            use_base_index = true;
        }

        // Anything other than "wrong serialized type" is final.
        if (status != Status::invalid_serialized_index_type) {
            return status;
        }

        LOG_KNOWHERE_INFO_ << "The provided data does not look like a FAISS index. Falling back to hnswlib index.";
        auto retry_status = fallback_search_index->DeserializeFromFile(filename, config);
        if (retry_status == Status::success) {
            // the fallback index now serves requests
            use_base_index = false;
        }
        return retry_status;
    }
};

template <typename DataType>
class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public BaseFaissRegular2HnswlibIndexNode {
class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWIndexNodeWithFallback {
public:
BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback(const int32_t& version, const Object& object)
: BaseFaissRegular2HnswlibIndexNode(version, object) {
: HNSWIndexNodeWithFallback(version, object) {
// initialize underlying nodes
base_index = std::make_unique<BaseFaissRegularIndexHNSWFlatNodeTemplate<DataType>>(version, object);
fallback_search_index = std::make_unique<HnswIndexNode<DataType, hnswlib::QuantType::None>>(version, object);
Expand Down

0 comments on commit 98253b8

Please sign in to comment.