Skip to content

Commit

Permalink
Add Faiss to the list of supported ANN frameworks (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg authored Dec 2, 2023
1 parent 3bef62d commit c4f32da
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 64 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ One important aspect of deploying recommender model is efficient retrieval via A

| Supported framework | Cornac wrapper | Examples |
| :---: | :---: | :---: |
| [meta/faiss](https://github.com/facebookresearch/faiss) | [FaissANN](cornac/models/ann/recom_ann_faiss.py) | [ann_all.ipynb](examples/ann_all.ipynb)
| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb)
| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb)

Expand Down
1 change: 1 addition & 0 deletions cornac/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .recommender import NextBasketRecommender

from .amr import AMR
from .ann import FaissANN
from .ann import HNSWLibANN
from .ann import ScaNNANN
from .baseline_only import BaselineOnly
Expand Down
1 change: 1 addition & 0 deletions cornac/models/ann/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .recom_ann_faiss import FaissANN
from .recom_ann_hnswlib import HNSWLibANN
from .recom_ann_scann import ScaNNANN
153 changes: 153 additions & 0 deletions cornac/models/ann/recom_ann_faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import multiprocessing
import numpy as np

from ..recommender import MEASURE_L2, MEASURE_DOT, MEASURE_COSINE
from .recom_ann_base import BaseANN


class FaissANN(BaseANN):
"""Approximate Nearest Neighbor Search with Faiss (https://github.com/facebookresearch/faiss).
Faiss provides both CPU and GPU implementation. More on the algorithms:
https://github.com/facebookresearch/faiss/wiki
Parameters
----------------
model: object: :obj:`cornac.models.Recommender`, required
Trained recommender model which to get user/item vectors from.
nlist: int, default: 100
The number of cells used for building the index.
nprobe: int, default: 50
The number of cells (out of nlist) that are visited to perform a search.
use_gpu : bool, optional
Whether or not to run Faiss on GPU. Requires faiss-gpu to be installed
instead of faiss-cpu.
num_threads: int, optional, default: -1
Default number of threads used for building index. If num_threads = -1,
all cores will be used.
seed: int, optional, default: None
Random seed for reproducibility.
name: str, required
Name of the recommender model.
verbose: boolean, optional, default: False
When True, running logs are displayed.
"""

def __init__(
self,
model,
nlist=100,
nprobe=50,
use_gpu=False,
num_threads=-1,
seed=None,
name="FaissANN",
verbose=False,
):
super().__init__(model=model, name=name, verbose=verbose)

self.model = model
self.nlist = nlist
self.nprobe = nprobe
self.use_gpu = use_gpu
self.num_threads = (
num_threads if num_threads != -1 else multiprocessing.cpu_count()
)
self.seed = seed

self.index = None
self.ignored_attrs.extend(
[
"index", # will be saved separately
"item_vectors", # redundant after index is built
]
)

def build_index(self):
"""Building index from the base recommender model."""
import faiss

faiss.omp_set_num_threads(self.num_threads)

SUPPORTED_MEASURES = {
MEASURE_L2: faiss.METRIC_L2,
MEASURE_DOT: faiss.METRIC_INNER_PRODUCT,
MEASURE_COSINE: faiss.METRIC_INNER_PRODUCT,
}

assert self.measure in SUPPORTED_MEASURES

if self.measure == MEASURE_COSINE:
self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[
:, np.newaxis
]

self.item_vectors = self.item_vectors.astype("float32")

self.index = faiss.IndexIVFFlat(
faiss.IndexFlat(self.item_vectors.shape[1]),
self.item_vectors.shape[1],
self.nlist,
SUPPORTED_MEASURES[self.measure],
)

if self.use_gpu:
self.index = faiss.index_cpu_to_all_gpus(self.index)

self.index.train(self.item_vectors)
self.index.add(self.item_vectors)
self.index.nprobe = self.nprobe

def knn_query(self, query, k):
"""Implementing ANN search for a given query.
Returns
-------
neighbors, distances: numpy.array and numpy.array
Array of k-nearest neighbors and corresponding distances for the given query.
"""
distances, neighbors = self.index.search(query, k)
return neighbors, distances

def save(self, save_dir=None):
import faiss

saved_path = super().save(save_dir)
idx_path = saved_path + ".index"
if self.use_gpu:
self.index = faiss.index_gpu_to_cpu(self.index)
faiss.write_index(self.index, idx_path)
return saved_path

@staticmethod
def load(model_path, trainable=False):
import faiss

ann = BaseANN.load(model_path, trainable)
idx_path = ann.load_from + ".index"
ann.index = faiss.read_index(idx_path)
if ann.use_gpu:
ann.index = faiss.index_cpu_to_all_gpus(ann.index)
return ann
4 changes: 2 additions & 2 deletions cornac/models/ann/recom_ann_hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def knn_query(self, query, k):

def save(self, save_dir=None):
saved_path = super().save(save_dir)
self.index.save_index(saved_path + ".idx")
self.index.save_index(saved_path + ".index")
return saved_path

@staticmethod
Expand All @@ -144,7 +144,7 @@ def load(model_path, trainable=False):
ann.index = hnswlib.Index(
space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1]
)
ann.index.load_index(ann.load_from + ".idx")
ann.index.load_index(ann.load_from + ".index")
ann.index.set_ef(ann.ef)
ann.index.set_num_threads(ann.num_threads)
return ann
11 changes: 7 additions & 4 deletions cornac/models/ann/recom_ann_scann.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from .recom_ann_base import BaseANN


SUPPORTED_MEASURES = {MEASURE_L2: "squared_l2", MEASURE_DOT: "dot_product"}
SUPPORTED_MEASURES = {
MEASURE_L2: "squared_l2",
MEASURE_DOT: "dot_product",
MEASURE_COSINE: "dot_product",
}


class ScaNNANN(BaseANN):
Expand Down Expand Up @@ -108,7 +112,6 @@ def build_index(self):
self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[
:, np.newaxis
]
self.measure = MEASURE_DOT
else:
self.partition_params["spherical"] = False

Expand Down Expand Up @@ -149,7 +152,7 @@ def knn_query(self, query, k):

def save(self, save_dir=None):
saved_path = super().save(save_dir)
idx_path = saved_path + ".idx"
idx_path = saved_path + ".index"
os.makedirs(idx_path, exist_ok=True)
self.index.searcher.serialize(idx_path)
return saved_path
Expand All @@ -159,6 +162,6 @@ def load(model_path, trainable=False):
from scann.scann_ops.py import scann_ops_pybind

ann = BaseANN.load(model_path, trainable)
idx_path = ann.load_from + ".idx"
idx_path = ann.load_from + ".index"
ann.index = scann_ops_pybind.load_searcher(idx_path)
return ann
Loading

0 comments on commit c4f32da

Please sign in to comment.