diff --git a/README.md b/README.md index b0500abb3..f66fc5262 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ options: ## Efficient retrieval with ANN search -One important aspect of deploying recommender model is efficient retrieval via Approximate Nearest Neighor (ANN) search in vector space. Cornac integrates several vector similarity search frameworks for the ease of deployment. [This example](tutorials/ann_hnswlib.ipynb) demonstrates how ANN search will work seamlessly as other recommender models. +One important aspect of deploying recommender model is efficient retrieval via Approximate Nearest Neighor (ANN) search in vector space. Cornac integrates several vector similarity search frameworks for the ease of deployment. [This example](tutorials/ann_hnswlib.ipynb) demonstrates how ANN search will work seamlessly with any recommender models supporting it (e.g., MF). | Supported framework | Cornac wrapper | Examples | | :---: | :---: | :---: | diff --git a/cornac/models/ann/recom_ann_base.py b/cornac/models/ann/recom_ann_base.py index defea3971..62ec031c3 100644 --- a/cornac/models/ann/recom_ann_base.py +++ b/cornac/models/ann/recom_ann_base.py @@ -41,11 +41,6 @@ def __init__(self, recom, name="BaseANN", verbose=False): if not is_ann_supported(recom): raise ValueError(f"{recom.name} doesn't support ANN search") - # ANN required attributes - self.measure = recom.get_vector_measure() - self.user_vectors = recom.get_user_vectors() - self.item_vectors = recom.get_item_vectors() - # get basic attributes to be a proper recommender super().fit(train_set=recom.train_set, val_set=recom.val_set) diff --git a/cornac/models/ann/recom_ann_hnswlib.py b/cornac/models/ann/recom_ann_hnswlib.py index 9d4720817..0d06f4d27 100644 --- a/cornac/models/ann/recom_ann_hnswlib.py +++ b/cornac/models/ann/recom_ann_hnswlib.py @@ -15,6 +15,7 @@ import sys +import random import multiprocessing import numpy as np @@ -45,7 +46,9 @@ class HNSWLibANN(BaseANN): Parameter that controls speed/accuracy trade-off during the index construction. ef: int, optional, default: 50 - Parameter controlling query time/accuracy trade-off. + Parameter controlling query time/accuracy trade-off. Higher `ef` leads to more accurate but + slower search. `ef` cannot be set lower than the number of queried nearest neighbors k. The + value of `ef` can be anything between `k` and the total number of items. num_threads: int, optional, default: -1 Default number of threads to use when querying. If num_threads = -1, all cores will be used. @@ -80,6 +83,11 @@ def __init__( ) self.seed = seed + # ANN required attributes + self.measure = recom.get_vector_measure() + self.user_vectors = recom.get_user_vectors() + self.item_vectors = recom.get_item_vectors() + self.index = None self.ignored_attrs.extend( [ @@ -97,17 +105,20 @@ def build_index(self): self.index = hnswlib.Index( space=SUPPORTED_MEASURES[self.measure], dim=self.item_vectors.shape[1] ) - random_seed = ( - self.seed if self.seed is not None else np.random.randint(sys.maxsize) - ) + self.index.init_index( max_elements=self.item_vectors.shape[0], ef_construction=self.ef_construction, M=self.M, - random_seed=random_seed, + random_seed=( + np.random.randint(sys.maxsize) if self.seed is None else self.seed + ), + ) + self.index.add_items( + data=self.item_vectors, + ids=np.arange(self.item_vectors.shape[0]), + num_threads=(-1 if self.seed is None else 1), ) - self.index.add_items(self.item_vectors, np.arange(self.item_vectors.shape[0])) - self.index.set_ef(self.ef) self.index.set_num_threads(self.num_threads) @@ -123,17 +134,19 @@ def knn_query(self, query, k): return neighbors, distances def save(self, save_dir=None): - model_file = super().save(save_dir) - self.index.save_index(model_file + ".idx") - return model_file + saved_path = super().save(save_dir) + self.index.save_index(saved_path + ".idx") + return saved_path @staticmethod def load(model_path, trainable=False): import hnswlib - model = BaseANN.load(model_path, trainable) - model.index = hnswlib.Index( - space=SUPPORTED_MEASURES[model.measure], dim=model.user_vectors.shape[1] + ann = BaseANN.load(model_path, trainable) + ann.index = hnswlib.Index( + space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1] ) - model.index.load_index(model.load_from + ".idx") - return model + ann.index.load_index(ann.load_from + ".idx") + ann.index.set_ef(ann.ef) + ann.index.set_num_threads(ann.num_threads) + return ann diff --git a/cornac/models/mf/recom_mf.pyx b/cornac/models/mf/recom_mf.pyx index 4e2a3c1e4..7b54a30a9 100644 --- a/cornac/models/mf/recom_mf.pyx +++ b/cornac/models/mf/recom_mf.pyx @@ -282,7 +282,7 @@ class MF(Recommender, ANNMixin): return MEASURE_DOT def get_user_vectors(self): - """Getting a matrix of user vectors served as query for ANN search. + """Getting a matrix of user vectors serving as query for ANN search. Returns ------- @@ -302,7 +302,7 @@ class MF(Recommender, ANNMixin): return user_vectors def get_item_vectors(self): - """Getting a matrix of item vectors used for building index for ANN search. + """Getting a matrix of item vectors used for building the index for ANN search. Returns ------- diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py index b796de52a..c81d7f9f3 100644 --- a/cornac/models/recommender.py +++ b/cornac/models/recommender.py @@ -538,7 +538,7 @@ def get_vector_measure(self): raise NotImplementedError() def get_user_vectors(self): - """Getting a matrix of user vectors served as query for ANN search. + """Getting a matrix of user vectors serving as query for ANN search. Returns ------- @@ -547,7 +547,7 @@ def get_user_vectors(self): raise NotImplementedError() def get_item_vectors(self): - """Getting a matrix of item vectors used for building index for ANN search. + """Getting a matrix of item vectors used for building the index for ANN search. Returns ------- diff --git a/tutorials/ann_hnswlib.ipynb b/tutorials/ann_hnswlib.ipynb index 980f80627..ff32bb537 100644 --- a/tutorials/ann_hnswlib.ipynb +++ b/tutorials/ann_hnswlib.ipynb @@ -59,22 +59,22 @@ "exclude_unknowns = True\n", "---\n", "Training data:\n", - "Number of users = 9985\n", - "Number of items = 4926\n", + "Number of users = 9986\n", + "Number of items = 4921\n", "Number of ratings = 547022\n", "Max rating = 1.0\n", "Min rating = 1.0\n", "Global mean = 1.0\n", "---\n", "Test data:\n", - "Number of users = 9985\n", - "Number of items = 4926\n", - "Number of ratings = 60748\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 60747\n", "Number of unknown users = 0\n", "Number of unknown items = 0\n", "---\n", - "Total users = 9985\n", - "Total items = 4926\n", + "Total users = 9986\n", + "Total items = 4921\n", "\n", "[MF] Training started!\n" ] @@ -82,7 +82,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8136f6f5bca44cde89e682ab2a62dd24", + "model_id": "d19c59bb6f934859aacdae530daeb020", "version_major": 2, "version_minor": 0 }, @@ -105,7 +105,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d917bf14c68548a490d8b9f38562320b", + "model_id": "7e8a119273404c099b8344534a702059", "version_major": 2, "version_minor": 0 }, @@ -125,7 +125,7 @@ "...\n", " | AUC | Recall@20 | Train (s) | Test (s)\n", "-- + ------ + --------- + --------- + --------\n", - "MF | 0.8517 | 0.0440 | 1.1022 | 11.8385\n", + "MF | 0.8530 | 0.0669 | 1.0909 | 11.7523\n", "\n" ] } @@ -139,6 +139,7 @@ " rating_threshold=1.0,\n", " exclude_unknowns=True,\n", " verbose=True,\n", + " seed=123,\n", ")\n", "\n", "mf = MF(\n", @@ -169,7 +170,7 @@ "source": [ "## Building index for ANN recommender\n", "\n", - "After MF model is trained, we need to wrap it with an ANN recommender. We employ Cornac built-in HNSWLibANN which implements [HNSW algorithm](https://arxiv.org/abs/1603.09320) for building index and doing approximate K-nearest neighbor search. More on how to tune the hyper-parameters at https://github.com/nmslib/hnswlib." + "After MF model is trained, we need to wrap it with an ANN recommender. We employ Cornac built-in HNSWLibANN which implements [HNSW algorithm](https://arxiv.org/abs/1603.09320) for building index and doing approximate K-nearest neighbor search. More on how to tune the hyper-parameters at https://github.com/nmslib/hnswlib and https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md." ] }, { @@ -184,6 +185,7 @@ " M=16,\n", " ef_construction=100,\n", " ef=50,\n", + " seed=123,\n", " num_threads=-1,\n", ")\n", "ann.build_index()" @@ -214,9 +216,7 @@ "source": [ "K = 20\n", "N = 10000\n", - "test_users = np.random.choice(mf.user_ids, size=N)\n", - "mf_recs = []\n", - "ann_recs = []" + "test_users = np.random.RandomState(123).choice(mf.user_ids, size=N)" ] }, { @@ -229,13 +229,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 40s, sys: 19.5 ms, total: 2min 40s\n", - "Wall time: 5.04 s\n" + "CPU times: user 2min 40s, sys: 15.4 ms, total: 2min 40s\n", + "Wall time: 5.02 s\n" ] } ], "source": [ "%%time\n", + "\n", + "mf_recs = []\n", "for uid in test_users:\n", " mf_recs.append(mf.recommend(uid, k=K))" ] @@ -250,13 +252,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 312 ms, sys: 31 µs, total: 312 ms\n", - "Wall time: 309 ms\n" + "CPU times: user 288 ms, sys: 32 µs, total: 288 ms\n", + "Wall time: 285 ms\n" ] } ], "source": [ "%%time\n", + "\n", + "ann_recs = []\n", "for uid in test_users:\n", " ann_recs.append(ann.recommend(uid, k=K))" ] @@ -266,7 +270,7 @@ "id": "e2f4f68a-c69b-4e70-b32a-81eb78d21279", "metadata": {}, "source": [ - "While it took MF 5.04s to complete the task, it's only 309ms for ANN. The speed up is about 15 times. Note that our dataset contains less than 5000 items. We will see an even bigger improvement with more items and with higher dimensions of the factors." + "While it took MF 5.02s to complete the task, it's only 285ms for ANN. The speed up is about 17 times. Note that our dataset contains less than 5000 items. We will see an even bigger improvement with more items and with high dimensional factors." ] }, { @@ -279,7 +283,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "99.75699999999999\n" + "99.87549999999999\n" ] } ], @@ -315,7 +319,7 @@ { "data": { "text/plain": [ - "'save_dir/HNSWLibANN/2023-11-09_06-15-10-660690.pkl'" + "'save_dir/HNSWLibANN/2023-11-14_00-01-22-869313.pkl'" ] }, "execution_count": 9, @@ -368,6 +372,39 @@ " loaded_ann.recommend_batch(test_users[:5], k=K),\n", ")" ] + }, + { + "cell_type": "markdown", + "id": "1699eac6-8b72-4dbc-87dd-94e68080bbb8", + "metadata": {}, + "source": [ + "One more test, the loaded ANN should achieve the same recall as the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7982d439-aadd-434a-8dd3-01d462486350", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "99.87549999999999\n" + ] + } + ], + "source": [ + "loaded_ann_recs = []\n", + "for uid in test_users:\n", + " loaded_ann_recs.append(loaded_ann.recommend(uid, k=K))\n", + " \n", + "recalls = []\n", + "for mf_rec, ann_rec in zip(mf_recs, loaded_ann_recs):\n", + " recalls.append(len(set(mf_rec) & set(ann_rec)) / len(mf_rec))\n", + "print(np.mean(recalls) * 100.0)" + ] } ], "metadata": {