From 1eadb8361fb6db01770560b9286129880459cd3d Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong Date: Wed, 15 Nov 2023 12:19:01 -0800 Subject: [PATCH] Add ANN search using HNSWLib (#544) --- README.md | 9 + cornac/models/__init__.py | 1 + cornac/models/ann/__init__.py | 1 + cornac/models/ann/recom_ann_base.py | 147 +++++++++ cornac/models/ann/recom_ann_hnswlib.py | 155 +++++++++ cornac/models/mf/recom_mf.pyx | 53 ++- cornac/models/recommender.py | 69 +++- tutorials/README.md | 1 + tutorials/ann_hnswlib.ipynb | 431 +++++++++++++++++++++++++ 9 files changed, 861 insertions(+), 6 deletions(-) create mode 100644 cornac/models/ann/__init__.py create mode 100644 cornac/models/ann/recom_ann_base.py create mode 100644 cornac/models/ann/recom_ann_hnswlib.py create mode 100644 tutorials/ann_hnswlib.ipynb diff --git a/README.md b/README.md index 8f810d2ca..f66fc5262 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,15 @@ options: --port PORT service port ``` +## Efficient retrieval with ANN search + +One important aspect of deploying a recommender model is efficient retrieval via Approximate Nearest Neighbor (ANN) search in vector space. Cornac integrates several vector similarity search frameworks for ease of deployment. [This example](tutorials/ann_hnswlib.ipynb) demonstrates how ANN search will work seamlessly with any recommender model supporting it (e.g., MF). + +| Supported framework | Cornac wrapper | Examples | +| :---: | :---: | :---: | +| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb) | + + ## Models The recommender models supported by Cornac are listed below. Why don't you join us to lengthen the list? 
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index 91610481b..cb4c5570f 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -16,6 +16,7 @@
 from .recommender import Recommender
 
 from .amr import AMR
+from .ann import HNSWLibANN
 from .baseline_only import BaselineOnly
 from .bivaecf import BiVAECF
 from .bpr import BPR
diff --git a/cornac/models/ann/__init__.py b/cornac/models/ann/__init__.py
new file mode 100644
index 000000000..c89556c69
--- /dev/null
+++ b/cornac/models/ann/__init__.py
@@ -0,0 +1 @@
+from .recom_ann_hnswlib import HNSWLibANN
diff --git a/cornac/models/ann/recom_ann_base.py b/cornac/models/ann/recom_ann_base.py
new file mode 100644
index 000000000..e29270b0a
--- /dev/null
+++ b/cornac/models/ann/recom_ann_base.py
@@ -0,0 +1,147 @@
+# Copyright 2023 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import numpy as np
+
+from ..recommender import Recommender
+from ..recommender import is_ann_supported
+
+
+class BaseANN(Recommender):
+    """Base class for a recommender model supporting Approximate Nearest Neighbor (ANN) search.
+
+    Parameters
+    ----------------
+    model: object: :obj:`cornac.models.Recommender`, required
+        Trained recommender model from which to get user/item vectors.
+
+    name: str, optional, default: 'BaseANN'
+        Name of the recommender model.
+
+    verbose: boolean, optional, default: False
+        When True, running logs are displayed.
+    """
+
+    def __init__(self, model, name="BaseANN", verbose=False):
+        super().__init__(name=name, verbose=verbose, trainable=False)
+
+        if not is_ann_supported(model):
+            raise ValueError(f"{model.name} doesn't support ANN search")
+
+        # get basic attributes to be a proper recommender
+        super().fit(train_set=model.train_set, val_set=model.val_set)
+
+    def build_index(self):
+        """Building index from the base recommender model.
+
+        :raise NotImplementedError
+        """
+        raise NotImplementedError()
+
+    def knn_query(self, query, k):
+        """Implementing ANN search for a given query.
+
+        Returns
+        -------
+        :raise NotImplementedError
+        """
+        raise NotImplementedError()
+
+    def recommend(self, user_id, k=-1, remove_seen=False, train_set=None):
+        """Generate top-K item recommendations for a given user. Backward compatibility.
+
+        Parameters
+        ----------
+        user_id: str, required
+            The original ID of user.
+
+        k: int, optional, default=-1
+            Cut-off length for recommendations, k=-1 will return ranked list of all items.
+
+        remove_seen: bool, optional, default: False
+            Remove seen/known items during training and validation from output recommendations.
+            This might shrink the list of recommendations to be less than k.
+
+        train_set: :obj:`cornac.data.Dataset`, optional, default: None
+            Training dataset needs to be provided in order to remove seen items.
+
+        Returns
+        -------
+        recommendations: list
+            Recommended items in the form of their original IDs.
+        """
+        assert isinstance(user_id, str)
+        return self.recommend_batch(
+            batch_users=[user_id],
+            k=k,
+            remove_seen=remove_seen,
+            train_set=train_set,
+        )[0]
+
+    def recommend_batch(self, batch_users, k=-1, remove_seen=False, train_set=None):
+        """Generate top-K item recommendations for a given batch of users. This is to leverage
+        parallelization provided by some ANN frameworks.
+
+        Parameters
+        ----------
+        batch_users: list, required
+            The original ID of users.
+
+        k: int, optional, default=-1
+            Cut-off length for recommendations, k=-1 will return ranked list of all items.
+
+        remove_seen: bool, optional, default: False
+            Remove seen/known items during training and validation from output recommendations.
+            This might shrink the list of recommendations to be less than k.
+
+        train_set: :obj:`cornac.data.Dataset`, optional, default: None
+            Training dataset needs to be provided in order to remove seen items.
+
+        Returns
+        -------
+        recommendations: list of lists
+            One list per user in `batch_users`, holding recommended items as their original IDs.
+        """
+        user_idx = [self.uid_map.get(uid, -1) for uid in batch_users]
+
+        if any(i == -1 for i in user_idx):
+            raise ValueError(f"{batch_users} is unknown to the model.")
+
+        if k < -1 or k > self.total_items:
+            raise ValueError(
+                f"k={k} is invalid, there are {self.total_items} items in total."
+            )
+
+        query = self.user_vectors[user_idx]
+        knn_items, distances = self.knn_query(query, k=k)
+
+        if remove_seen:
+            if train_set is None:
+                raise ValueError("train_set must be provided to remove seen items.")
+            filtered_knn_items = []
+            for u, i in zip(user_idx, knn_items):
+                if u >= train_set.csr_matrix.shape[0]:
+                    filtered_knn_items.append(i)  # no row in train_set: keep users aligned
+                    continue
+                # mask on the retrieved item indices themselves, not on their positions
+                seen_mask = np.in1d(i, train_set.csr_matrix.getrow(u).indices)
+                filtered_knn_items.append(i[~seen_mask])
+            knn_items = filtered_knn_items
+
+        recommendations = [
+            [self.item_ids[i] for i in knn_items[u]] for u in range(len(user_idx))
+        ]
+        return recommendations
diff --git a/cornac/models/ann/recom_ann_hnswlib.py b/cornac/models/ann/recom_ann_hnswlib.py
new file mode 100644
index 000000000..5f960d628
--- /dev/null
+++ b/cornac/models/ann/recom_ann_hnswlib.py
@@ -0,0 +1,155 @@
+# Copyright 2023 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import sys
+import random
+import multiprocessing
+import numpy as np
+
+from ..recommender import MEASURE_L2, MEASURE_DOT, MEASURE_COSINE
+from .recom_ann_base import BaseANN
+
+
+SUPPORTED_MEASURES = {
+    MEASURE_L2: "l2",
+    MEASURE_DOT: "ip",
+    MEASURE_COSINE: "cosine",
+}
+
+
+class HNSWLibANN(BaseANN):
+    """Approximate Nearest Neighbor Search with HNSWLib (https://github.com/nmslib/hnswlib/).
+
+    Parameters
+    ----------------
+    model: object: :obj:`cornac.models.Recommender`, required
+        Trained recommender model from which to get user/item vectors.
+
+    M: int, optional, default: 16
+        Parameter that defines the maximum number of outgoing connections in the HNSW graph.
+        Higher M leads to higher accuracy/run_time at fixed ef/ef_construction. Reasonable range
+        for M is 2-100. Higher M works better on models with high dimensional factors, while lower M
+        works better for low dimensional factors. More details: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md.
+
+    ef_construction: int, optional, default: 100
+        Parameter that controls speed/accuracy trade-off during the index construction. Bigger ef_construction leads to longer construction, but better index quality. At some point,
+        increasing ef_construction does not improve the quality of the index.
+
+    ef: int, optional, default: 50
+        Parameter controlling query time/accuracy trade-off. Higher `ef` leads to more accurate but
+        slower search. `ef` cannot be set lower than the number of queried nearest neighbors k. The
+        value of `ef` can be anything between `k` and the total number of items.
+
+    num_threads: int, optional, default: -1
+        Default number of threads to use when querying. If num_threads = -1, all cores will be used.
+
+    seed: int, optional, default: None
+        Random seed for reproducibility.
+
+    name: str, optional, default: 'HNSWLibANN'
+        Name of the recommender model.
+
+    verbose: boolean, optional, default: False
+        When True, running logs are displayed.
+    """
+
+    def __init__(
+        self,
+        model,
+        M=16,
+        ef_construction=100,
+        ef=50,
+        num_threads=-1,
+        seed=None,
+        name="HNSWLibANN",
+        verbose=False,
+    ):
+        super().__init__(model=model, name=name, verbose=verbose)
+        self.M = M
+        self.ef_construction = ef_construction
+        self.ef = ef
+        self.num_threads = (
+            num_threads if num_threads != -1 else multiprocessing.cpu_count()
+        )
+        self.seed = seed
+
+        # ANN required attributes
+        self.measure = model.get_vector_measure()
+        self.user_vectors = model.get_user_vectors()
+        self.item_vectors = model.get_item_vectors()
+
+        self.index = None
+        self.ignored_attrs.extend(
+            [
+                "index",  # will be saved separately
+                "item_vectors",  # redundant after index is built
+            ]
+        )
+
+    def build_index(self):
+        """Building index from the base recommender model."""
+        import hnswlib
+
+        assert self.measure in SUPPORTED_MEASURES
+
+        self.index = hnswlib.Index(
+            space=SUPPORTED_MEASURES[self.measure], dim=self.item_vectors.shape[1]
+        )
+
+        self.index.init_index(
+            max_elements=self.item_vectors.shape[0],
+            ef_construction=self.ef_construction,
+            M=self.M,
+            random_seed=(  # int32 bound: sys.maxsize overflows randint where NumPy's default int is 32-bit
+                np.random.randint(np.iinfo(np.int32).max) if self.seed is None else self.seed
+            ),
+        )
+        self.index.add_items(
+            data=self.item_vectors,
+            ids=np.arange(self.item_vectors.shape[0]),
+            num_threads=(-1 if self.seed is None else 1),  # presumably 1 thread keeps a seeded build deterministic
+        )
+        self.index.set_ef(self.ef)
+        self.index.set_num_threads(self.num_threads)
+
+    def knn_query(self, query, k):
+        """Implementing ANN search for a given query.
+
+        Returns
+        -------
+        neighbors, distances: numpy.array and numpy.array
+            Array of k-nearest neighbors and corresponding distances for the given query.
+        """
+        neighbors, distances = self.index.knn_query(query, k=k)
+        return neighbors, distances
+
+    def save(self, save_dir=None):
+        saved_path = super().save(save_dir)
+        self.index.save_index(saved_path + ".idx")  # "index" is in ignored_attrs, so it is written next to the model
+        return saved_path
+
+    @staticmethod
+    def load(model_path, trainable=False):
+        import hnswlib
+
+        ann = BaseANN.load(model_path, trainable)
+        ann.index = hnswlib.Index(  # item_vectors is in ignored_attrs, so dim is read from user_vectors
+            space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1]
+        )
+        ann.index.load_index(ann.load_from + ".idx")
+        ann.index.set_ef(ann.ef)
+        ann.index.set_num_threads(ann.num_threads)
+        return ann
diff --git a/cornac/models/mf/recom_mf.pyx b/cornac/models/mf/recom_mf.pyx
index 6fc9c316e..7b54a30a9 100644
--- a/cornac/models/mf/recom_mf.pyx
+++ b/cornac/models/mf/recom_mf.pyx
@@ -28,6 +28,7 @@ cimport numpy as np
 from tqdm.auto import trange
 
 from ..recommender import Recommender
+from ..recommender import ANNMixin, MEASURE_DOT
 from ...exception import ScoreException
 from ...utils import fast_dot
 from ...utils import get_rng
@@ -35,7 +36,7 @@ from ...utils.init_utils import normal, zeros
 
 
 
-class MF(Recommender):
+class MF(Recommender, ANNMixin):
     """Matrix Factorization.
 
     Parameters
@@ -269,3 +270,53 @@ class MF(Recommender):
                 raise ScoreException("Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx))
             item_score = np.dot(self.u_factors[user_idx], self.i_factors[item_idx])
             return item_score
+
+    def get_vector_measure(self):
+        """Getting the vector measure used for ANN search, one of the MEASURE_* constants.
+
+        Returns
+        -------
+        measure: MEASURE_DOT
+            Dot product aka. inner product
+        """
+        return MEASURE_DOT
+
+    def get_user_vectors(self):
+        """Getting a matrix of user vectors serving as query for ANN search.
+
+        Returns
+        -------
+        out: numpy.array
+            Matrix of user vectors for all users available in the model.
+        """
+        user_vectors = self.u_factors
+        if self.use_bias:
+            user_vectors = np.concatenate(
+                (
+                    user_vectors,
+                    self.u_biases.reshape((-1, 1)),
+                    np.ones([user_vectors.shape[0], 1]),  # augmented for item bias
+                ),
+                axis=1
+            )
+        return user_vectors
+
+    def get_item_vectors(self):
+        """Getting a matrix of item vectors used for building the index for ANN search.
+
+        Returns
+        -------
+        out: numpy.array
+            Matrix of item vectors for all items available in the model.
+        """
+        item_vectors = self.i_factors
+        if self.use_bias:
+            item_vectors = np.concatenate(
+                (
+                    item_vectors,
+                    np.ones([item_vectors.shape[0], 1]),  # augmented for user bias
+                    self.i_biases.reshape((-1, 1)),
+                ),
+                axis=1
+            )
+        return item_vectors
\ No newline at end of file
diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py
index 564e9eec9..c81d7f9f3 100644
--- a/cornac/models/recommender.py
+++ b/cornac/models/recommender.py
@@ -47,15 +47,15 @@ class Recommender:
 
     num_items: int
         Number of items in training data.
-    
+
     total_users: int
-        Number of users in training, validation, and test data. 
+        Number of users in training, validation, and test data.
         In other words, this includes unknown/unseen users.
 
     total_items: int
-        Number of items in training, validation, and test data. 
+        Number of items in training, validation, and test data.
         In other words, this includes unknown/unseen items.
-    
+
     uid_map: int
         Global mapping of user ID-index.
 
@@ -77,7 +77,8 @@ def __init__(self, name, trainable=True, verbose=False): self.trainable = trainable self.verbose = verbose - self.ignored_attrs = [] # attributes to be ignored when saving model + # attributes to be ignored when saving model + self.ignored_attrs = ["train_set", "val_set", "test_set"] # useful information getting from train_set for prediction self.num_users = None @@ -249,6 +250,10 @@ def fit(self, train_set, val_set=None): self.max_rating = train_set.max_rating self.global_mean = train_set.global_mean + # just for future wrapper to call fit(), not supposed to be used during prediction + self.train_set = train_set + self.val_set = val_set + return self def knows_user(self, user_idx): @@ -511,3 +516,57 @@ def early_stop(self, train_set, val_set, min_delta=0.0, patience=0): ) return True return False + + +MEASURE_L2 = "l2 distance aka. Euclidean distance" +MEASURE_DOT = "dot product aka. inner product" +MEASURE_COSINE = "cosine similarity" + + +class ANNMixin: + """Mixin class for Approximate Nearest Neighbor Search.""" + + _ann_supported = True + + def get_vector_measure(self): + """Getting the vector measure used for ANN search, one of the MEASURE_* constants. + + Returns + ------- + :raise NotImplementedError + """ + raise NotImplementedError() + + def get_user_vectors(self): + """Getting a matrix of user vectors serving as query for ANN search. + + Returns + ------- + :raise NotImplementedError + """ + raise NotImplementedError() + + def get_item_vectors(self): + """Getting a matrix of item vectors used for building the index for ANN search. + + Returns + ------- + :raise NotImplementedError + """ + raise NotImplementedError() + + +def is_ann_supported(recom): + """Return True if the given recommender model supports ANN search. + + Parameters + ---------- + recom : recommender model + Recommender object to test. + + Returns + ------- + out : bool + True if recom supports ANN search and False otherwise. 
+ """ + return getattr(recom, "_ann_supported", False) diff --git a/tutorials/README.md b/tutorials/README.md index 4c3f4e0c3..efbb62bef 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -7,6 +7,7 @@ If you are new to Cornac, the [Getting Started](#getting-started) tutorials are - [Installation](../README.md#installation) - [Your first Cornac experiment](../README.md#getting-started-your-first-cornac-experiment) - [Hyperparameter search for VAECF](./param_search_vaecf.ipynb) +- [Approximate nearest neighbor (ANN) search](./ann_hnswlib.ipynb) - [Introduction to BPR, using Cornac with Microsoft Recommenders](https://github.com/microsoft/recommenders/blob/main/examples/02_model_collaborative_filtering/cornac_bpr_deep_dive.ipynb) - [Tutorials on recommender systems by Preferred.AI](https://github.com/PreferredAI/tutorials/tree/master/recommender-systems) - [BiVAE model with Microsoft Recommenders](https://github.com/microsoft/recommenders/blob/main/examples/02_model_collaborative_filtering/cornac_bivae_deep_dive.ipynb) diff --git a/tutorials/ann_hnswlib.ipynb b/tutorials/ann_hnswlib.ipynb new file mode 100644 index 000000000..583b929cd --- /dev/null +++ b/tutorials/ann_hnswlib.ipynb @@ -0,0 +1,431 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b9a4225b-1a05-4b58-9e1d-1511650ef225", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: hnswlib in /opt/conda/lib/python3.10/site-packages (0.7.0)\n", + "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from hnswlib) (1.26.0)\n" + ] + } + ], + "source": [ + "!pip install hnswlib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74a9e78f-3e8a-4ee2-89fe-b3a3f4784b53", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import hnswlib\n", + "import cornac\n", + "from cornac.data import Reader\n", + "from cornac.datasets import 
netflix\n", + "from cornac.eval_methods import RatioSplit\n", + "from cornac.models import MF, HNSWLibANN" + ] + }, + { + "cell_type": "markdown", + "id": "cf6bb9a5-ffb5-4221-8122-9aa286af1d9c", + "metadata": {}, + "source": [ + "## Recommender model training\n", + "\n", + "The following experiment shows how to perform ANN-search within Cornac. First, we need to train a model that supports ANN search. Here we choose MF for simple illustration purpose. Other models that support ANN search should work in a similar fashion." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76a0c130-7dd7-4004-a613-5b123dcc75d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rating_threshold = 1.0\n", + "exclude_unknowns = True\n", + "---\n", + "Training data:\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 547022\n", + "Max rating = 1.0\n", + "Min rating = 1.0\n", + "Global mean = 1.0\n", + "---\n", + "Test data:\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 60747\n", + "Number of unknown users = 0\n", + "Number of unknown items = 0\n", + "---\n", + "Total users = 9986\n", + "Total items = 4921\n", + "\n", + "[MF] Training started!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e588c08cad71410aae17e2cc84456631", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/25 [00:00