diff --git a/README.md b/README.md
index 3ed7e447a..aa08344e6 100644
--- a/README.md
+++ b/README.md
@@ -134,7 +134,8 @@ One important aspect of deploying recommender model is efficient retrieval via A
 
 | Supported framework | Cornac wrapper | Examples |
 | :---: | :---: | :---: |
-| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb)
+| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb)
+| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb)
 
 ## Models
 
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index d0521d33b..fb316f330 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -18,6 +18,7 @@
 
 from .amr import AMR
 from .ann import HNSWLibANN
+from .ann import ScaNNANN
 from .baseline_only import BaselineOnly
 from .bivaecf import BiVAECF
 from .bpr import BPR
diff --git a/cornac/models/ann/__init__.py b/cornac/models/ann/__init__.py
index c89556c69..77fd8630b 100644
--- a/cornac/models/ann/__init__.py
+++ b/cornac/models/ann/__init__.py
@@ -1 +1,2 @@
 from .recom_ann_hnswlib import HNSWLibANN
+from .recom_ann_scann import ScaNNANN
diff --git a/cornac/models/ann/recom_ann_base.py b/cornac/models/ann/recom_ann_base.py
index e29270b0a..c256a721e 100644
--- a/cornac/models/ann/recom_ann_base.py
+++ b/cornac/models/ann/recom_ann_base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 
-
+import copy
 import numpy as np
 
 from ..recommender import Recommender
@@ -41,6 +41,11 @@ def __init__(self, model, name="BaseANN", verbose=False):
         if not is_ann_supported(model):
             raise ValueError(f"{model.name} doesn't support ANN search")
 
+        # ANN required attributes
+        self.measure = copy.deepcopy(model.get_vector_measure())
+        self.user_vectors = copy.deepcopy(model.get_user_vectors())
+        self.item_vectors = copy.deepcopy(model.get_item_vectors())
+
         # get basic attributes to be a proper recommender
         super().fit(train_set=model.train_set, val_set=model.val_set)
 
diff --git a/cornac/models/ann/recom_ann_hnswlib.py b/cornac/models/ann/recom_ann_hnswlib.py
index 5f960d628..ff4e68547 100644
--- a/cornac/models/ann/recom_ann_hnswlib.py
+++ b/cornac/models/ann/recom_ann_hnswlib.py
@@ -86,11 +86,6 @@ def __init__(
         )
         self.seed = seed
 
-        # ANN required attributes
-        self.measure = model.get_vector_measure()
-        self.user_vectors = model.get_user_vectors()
-        self.item_vectors = model.get_item_vectors()
-
         self.index = None
         self.ignored_attrs.extend(
             [
diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py
new file mode 100644
index 000000000..af116f1af
--- /dev/null
+++ b/cornac/models/ann/recom_ann_scann.py
@@ -0,0 +1,164 @@
+# Copyright 2023 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import os
+import multiprocessing
+import numpy as np
+
+from ..recommender import MEASURE_L2, MEASURE_DOT, MEASURE_COSINE
+from .recom_ann_base import BaseANN
+
+
+SUPPORTED_MEASURES = {MEASURE_L2: "squared_l2", MEASURE_DOT: "dot_product"}
+
+
+class ScaNNANN(BaseANN):
+    """Approximate Nearest Neighbor Search with ScaNN
+    (https://github.com/google-research/google-research/tree/master/scann).
+    ScaNN performs vector search in three phases: partitioning, scoring, and rescoring.
+    More details on the algorithms and parameters: https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md
+
+    Parameters
+    ----------------
+    model: object: :obj:`cornac.models.Recommender`, required
+        Trained recommender model from which to get user/item vectors.
+
+    partition_params: dict, optional
+        Parameters for the partitioning phase, to send to the tree() call in ScaNN.
+
+    score_params: dict, optional
+        Parameters for the scoring phase, to send to the score_ah() call in ScaNN.
+        score_brute_force() will be called if score_brute_force is True.
+
+    score_brute_force: bool, optional, default: False
+        Whether to call score_brute_force() for the scoring phase.
+
+    rescore_params: dict, optional
+        Parameters for the rescoring phase, to send to the reorder() call in ScaNN.
+
+    num_threads: int, optional, default: -1
+        Number of threads used for training (building the index). If num_threads = -1, all cores will be used.
+
+    seed: int, optional, default: None
+        Random seed for reproducibility.
+
+    name: str, optional, default: 'ScaNNANN'
+        Name of the recommender model.
+
+    verbose: boolean, optional, default: False
+        When True, running logs are displayed.
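+
+    Examples
+    --------
+    A minimal usage sketch. The MF hyperparameters and the ScaNN settings below
+    are illustrative only, and ``train_set`` is assumed to come from a Cornac
+    evaluation method such as :obj:`cornac.eval_methods.RatioSplit`:
+
+    >>> from cornac.models import MF, ScaNNANN
+    >>> mf = MF(k=50, max_iter=25, seed=123).fit(train_set)
+    >>> ann = ScaNNANN(
+    ...     model=mf,
+    ...     partition_params={"num_leaves": 100, "num_leaves_to_search": 50},
+    ...     score_params={"dimensions_per_block": 2},
+    ...     rescore_params={"reordering_num_neighbors": 100},
+    ...     seed=123,
+    ... )
+    >>> ann.build_index()
+    >>> neighbors, distances = ann.knn_query(ann.user_vectors[:5], k=10)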
+ """ + + def __init__( + self, + model, + partition_params=None, + score_params=None, + score_brute_force=False, + rescore_params=None, + num_threads=-1, + seed=None, + name="ScaNNANN", + verbose=False, + ): + super().__init__(model=model, name=name, verbose=verbose) + + if score_params is None: + score_params = {} + + self.model = model + self.partition_params = partition_params + self.score_params = score_params + self.score_brute_force = score_brute_force + self.rescore_params = rescore_params + self.num_threads = ( + num_threads if num_threads != -1 else multiprocessing.cpu_count() + ) + self.seed = seed + + self.index = None + self.ignored_attrs.extend( + [ + "index", # will be saved separately + "item_vectors", # redundant after index is built + ] + ) + + def build_index(self): + """Building index from the base recommender model.""" + import scann + + assert self.measure in SUPPORTED_MEASURES + + if self.measure == MEASURE_COSINE: + self.partition_params["spherical"] = True + self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[ + :, np.newaxis + ] + self.measure = MEASURE_DOT + else: + self.partition_params["spherical"] = False + + index_builder = scann.scann_ops_pybind.builder( + self.item_vectors, 10, SUPPORTED_MEASURES[self.measure] + ) + index_builder.set_n_training_threads(self.num_threads) + + # partitioning + if self.partition_params: + self.partition_params.setdefault( + "training_sample_size", self.item_vectors.shape[0] + ) + index_builder = index_builder.tree(**self.partition_params) + + # scoring + if self.score_brute_force: + index_builder = index_builder.score_brute_force(**self.score_params) + else: + index_builder = index_builder.score_ah(**self.score_params) + + # rescoring + if self.rescore_params: + index_builder = index_builder.reorder(**self.rescore_params) + + self.index = index_builder.build() + + def knn_query(self, query, k): + """Implementing ANN search for a given query. + + Returns + ------- + neighbors, distances: numpy.array and numpy.array + Array of k-nearest neighbors and corresponding distances for the given query. 
+ """ + neighbors, distances = self.index.search_batched(query, final_num_neighbors=k) + return neighbors, distances + + def save(self, save_dir=None): + saved_path = super().save(save_dir) + idx_path = saved_path + ".idx" + os.makedirs(idx_path, exist_ok=True) + self.index.searcher.serialize(idx_path) + return saved_path + + @staticmethod + def load(model_path, trainable=False): + from scann.scann_ops.py import scann_ops_pybind + + ann = BaseANN.load(model_path, trainable) + idx_path = ann.load_from + ".idx" + ann.index = scann_ops_pybind.load_searcher(idx_path) + return ann diff --git a/examples/ann_all.ipynb b/examples/ann_all.ipynb new file mode 100644 index 000000000..856adccdb --- /dev/null +++ b/examples/ann_all.ipynb @@ -0,0 +1,338 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b9a4225b-1a05-4b58-9e1d-1511650ef225", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q hnswlib scann" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74a9e78f-3e8a-4ee2-89fe-b3a3f4784b53", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import cornac\n", + "from cornac.data import Reader\n", + "from cornac.datasets import netflix\n", + "from cornac.eval_methods import RatioSplit\n", + "from cornac.models import MF" + ] + }, + { + "cell_type": "markdown", + "id": "cf6bb9a5-ffb5-4221-8122-9aa286af1d9c", + "metadata": {}, + "source": [ + "## Train a base recommender model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76a0c130-7dd7-4004-a613-5b123dcc75d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rating_threshold = 1.0\n", + "exclude_unknowns = True\n", + "---\n", + "Training data:\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 547022\n", + "Max rating = 1.0\n", + "Min rating = 1.0\n", + "Global mean = 1.0\n", + "---\n", + "Test data:\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 60747\n", + "Number of unknown users = 0\n", + "Number of unknown items = 0\n", + "---\n", + "Total users = 9986\n", + "Total items = 4921\n", + "\n", + "[MF] Training started!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e984dd7e18d74f0090247ab9e8247797", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/25 [00:00