Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Nov 14, 2023
1 parent 8f2527f commit ad05429
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ options:

## Efficient retrieval with ANN search

One important aspect of deploying recommender model is efficient retrieval via Approximate Nearest Neighor (ANN) search in vector space. Cornac integrates several vector similarity search frameworks for the ease of deployment. [This example](tutorials/ann_hnswlib.ipynb) demonstrates how ANN search will work seamlessly as other recommender models.
One important aspect of deploying recommender model is efficient retrieval via Approximate Nearest Neighor (ANN) search in vector space. Cornac integrates several vector similarity search frameworks for the ease of deployment. [This example](tutorials/ann_hnswlib.ipynb) demonstrates how ANN search will work seamlessly with any recommender models supporting it (e.g., MF).

| Supported framework | Cornac wrapper | Examples |
| :---: | :---: | :---: |
Expand Down
5 changes: 0 additions & 5 deletions cornac/models/ann/recom_ann_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ def __init__(self, recom, name="BaseANN", verbose=False):
if not is_ann_supported(recom):
raise ValueError(f"{recom.name} doesn't support ANN search")

# ANN required attributes
self.measure = recom.get_vector_measure()
self.user_vectors = recom.get_user_vectors()
self.item_vectors = recom.get_item_vectors()

# get basic attributes to be a proper recommender
super().fit(train_set=recom.train_set, val_set=recom.val_set)

Expand Down
43 changes: 28 additions & 15 deletions cornac/models/ann/recom_ann_hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import sys
import random
import multiprocessing
import numpy as np

Expand Down Expand Up @@ -45,7 +46,9 @@ class HNSWLibANN(BaseANN):
Parameter that controls speed/accuracy trade-off during the index construction.
ef: int, optional, default: 50
Parameter controlling query time/accuracy trade-off.
Parameter controlling query time/accuracy trade-off. Higher `ef` leads to more accurate but
slower search. `ef` cannot be set lower than the number of queried nearest neighbors k. The
value of `ef` can be anything between `k` and the total number of items.
num_threads: int, optional, default: -1
Default number of threads to use when querying. If num_threads = -1, all cores will be used.
Expand Down Expand Up @@ -80,6 +83,11 @@ def __init__(
)
self.seed = seed

# ANN required attributes
self.measure = recom.get_vector_measure()
self.user_vectors = recom.get_user_vectors()
self.item_vectors = recom.get_item_vectors()

self.index = None
self.ignored_attrs.extend(
[
Expand All @@ -97,17 +105,20 @@ def build_index(self):
self.index = hnswlib.Index(
space=SUPPORTED_MEASURES[self.measure], dim=self.item_vectors.shape[1]
)
random_seed = (
self.seed if self.seed is not None else np.random.randint(sys.maxsize)
)

self.index.init_index(
max_elements=self.item_vectors.shape[0],
ef_construction=self.ef_construction,
M=self.M,
random_seed=random_seed,
random_seed=(
np.random.randint(sys.maxsize) if self.seed is None else self.seed
),
)
self.index.add_items(
data=self.item_vectors,
ids=np.arange(self.item_vectors.shape[0]),
num_threads=(-1 if self.seed is None else 1),
)
self.index.add_items(self.item_vectors, np.arange(self.item_vectors.shape[0]))

self.index.set_ef(self.ef)
self.index.set_num_threads(self.num_threads)

Expand All @@ -123,17 +134,19 @@ def knn_query(self, query, k):
return neighbors, distances

def save(self, save_dir=None):
model_file = super().save(save_dir)
self.index.save_index(model_file + ".idx")
return model_file
saved_path = super().save(save_dir)
self.index.save_index(saved_path + ".idx")
return saved_path

@staticmethod
def load(model_path, trainable=False):
import hnswlib

model = BaseANN.load(model_path, trainable)
model.index = hnswlib.Index(
space=SUPPORTED_MEASURES[model.measure], dim=model.user_vectors.shape[1]
ann = BaseANN.load(model_path, trainable)
ann.index = hnswlib.Index(
space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1]
)
model.index.load_index(model.load_from + ".idx")
return model
ann.index.load_index(ann.load_from + ".idx")
ann.index.set_ef(ann.ef)
ann.index.set_num_threads(ann.num_threads)
return ann
4 changes: 2 additions & 2 deletions cornac/models/mf/recom_mf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ class MF(Recommender, ANNMixin):
return MEASURE_DOT

def get_user_vectors(self):
"""Getting a matrix of user vectors served as query for ANN search.
"""Getting a matrix of user vectors serving as query for ANN search.
Returns
-------
Expand All @@ -302,7 +302,7 @@ class MF(Recommender, ANNMixin):
return user_vectors

def get_item_vectors(self):
"""Getting a matrix of item vectors used for building index for ANN search.
"""Getting a matrix of item vectors used for building the index for ANN search.
Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions cornac/models/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def get_vector_measure(self):
raise NotImplementedError()

def get_user_vectors(self):
"""Getting a matrix of user vectors served as query for ANN search.
"""Getting a matrix of user vectors serving as query for ANN search.
Returns
-------
Expand All @@ -547,7 +547,7 @@ def get_user_vectors(self):
raise NotImplementedError()

def get_item_vectors(self):
"""Getting a matrix of item vectors used for building index for ANN search.
"""Getting a matrix of item vectors used for building the index for ANN search.
Returns
-------
Expand Down
79 changes: 58 additions & 21 deletions tutorials/ann_hnswlib.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,30 @@
"exclude_unknowns = True\n",
"---\n",
"Training data:\n",
"Number of users = 9985\n",
"Number of items = 4926\n",
"Number of users = 9986\n",
"Number of items = 4921\n",
"Number of ratings = 547022\n",
"Max rating = 1.0\n",
"Min rating = 1.0\n",
"Global mean = 1.0\n",
"---\n",
"Test data:\n",
"Number of users = 9985\n",
"Number of items = 4926\n",
"Number of ratings = 60748\n",
"Number of users = 9986\n",
"Number of items = 4921\n",
"Number of ratings = 60747\n",
"Number of unknown users = 0\n",
"Number of unknown items = 0\n",
"---\n",
"Total users = 9985\n",
"Total items = 4926\n",
"Total users = 9986\n",
"Total items = 4921\n",
"\n",
"[MF] Training started!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8136f6f5bca44cde89e682ab2a62dd24",
"model_id": "d19c59bb6f934859aacdae530daeb020",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -105,7 +105,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d917bf14c68548a490d8b9f38562320b",
"model_id": "7e8a119273404c099b8344534a702059",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -125,7 +125,7 @@
"...\n",
" | AUC | Recall@20 | Train (s) | Test (s)\n",
"-- + ------ + --------- + --------- + --------\n",
"MF | 0.8517 | 0.0440 | 1.1022 | 11.8385\n",
"MF | 0.8530 | 0.0669 | 1.0909 | 11.7523\n",
"\n"
]
}
Expand All @@ -139,6 +139,7 @@
" rating_threshold=1.0,\n",
" exclude_unknowns=True,\n",
" verbose=True,\n",
" seed=123,\n",
")\n",
"\n",
"mf = MF(\n",
Expand Down Expand Up @@ -169,7 +170,7 @@
"source": [
"## Building index for ANN recommender\n",
"\n",
"After MF model is trained, we need to wrap it with an ANN recommender. We employ Cornac built-in HNSWLibANN which implements [HNSW algorithm](https://arxiv.org/abs/1603.09320) for building index and doing approximate K-nearest neighbor search. More on how to tune the hyper-parameters at https://github.com/nmslib/hnswlib."
"After MF model is trained, we need to wrap it with an ANN recommender. We employ Cornac built-in HNSWLibANN which implements [HNSW algorithm](https://arxiv.org/abs/1603.09320) for building index and doing approximate K-nearest neighbor search. More on how to tune the hyper-parameters at https://github.com/nmslib/hnswlib and https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md."
]
},
{
Expand All @@ -184,6 +185,7 @@
" M=16,\n",
" ef_construction=100,\n",
" ef=50,\n",
" seed=123,\n",
" num_threads=-1,\n",
")\n",
"ann.build_index()"
Expand Down Expand Up @@ -214,9 +216,7 @@
"source": [
"K = 20\n",
"N = 10000\n",
"test_users = np.random.choice(mf.user_ids, size=N)\n",
"mf_recs = []\n",
"ann_recs = []"
"test_users = np.random.RandomState(123).choice(mf.user_ids, size=N)"
]
},
{
Expand All @@ -229,13 +229,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2min 40s, sys: 19.5 ms, total: 2min 40s\n",
"Wall time: 5.04 s\n"
"CPU times: user 2min 40s, sys: 15.4 ms, total: 2min 40s\n",
"Wall time: 5.02 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"mf_recs = []\n",
"for uid in test_users:\n",
" mf_recs.append(mf.recommend(uid, k=K))"
]
Expand All @@ -250,13 +252,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 312 ms, sys: 31 µs, total: 312 ms\n",
"Wall time: 309 ms\n"
"CPU times: user 288 ms, sys: 32 µs, total: 288 ms\n",
"Wall time: 285 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"ann_recs = []\n",
"for uid in test_users:\n",
" ann_recs.append(ann.recommend(uid, k=K))"
]
Expand All @@ -266,7 +270,7 @@
"id": "e2f4f68a-c69b-4e70-b32a-81eb78d21279",
"metadata": {},
"source": [
"While it took MF 5.04s to complete the task, it's only 309ms for ANN. The speed up is about 15 times. Note that our dataset contains less than 5000 items. We will see an even bigger improvement with more items and with higher dimensions of the factors."
"While it took MF 5.02s to complete the task, it's only 285ms for ANN. The speed up is about 17 times. Note that our dataset contains less than 5000 items. We will see an even bigger improvement with more items and with high dimensional factors."
]
},
{
Expand All @@ -279,7 +283,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"99.75699999999999\n"
"99.87549999999999\n"
]
}
],
Expand Down Expand Up @@ -315,7 +319,7 @@
{
"data": {
"text/plain": [
"'save_dir/HNSWLibANN/2023-11-09_06-15-10-660690.pkl'"
"'save_dir/HNSWLibANN/2023-11-14_00-01-22-869313.pkl'"
]
},
"execution_count": 9,
Expand Down Expand Up @@ -368,6 +372,39 @@
" loaded_ann.recommend_batch(test_users[:5], k=K),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1699eac6-8b72-4dbc-87dd-94e68080bbb8",
"metadata": {},
"source": [
"One more test, the loaded ANN should achieve the same recall as the original one."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7982d439-aadd-434a-8dd3-01d462486350",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"99.87549999999999\n"
]
}
],
"source": [
"loaded_ann_recs = []\n",
"for uid in test_users:\n",
" loaded_ann_recs.append(loaded_ann.recommend(uid, k=K))\n",
" \n",
"recalls = []\n",
"for mf_rec, ann_rec in zip(mf_recs, loaded_ann_recs):\n",
" recalls.append(len(set(mf_rec) & set(ann_rec)) / len(mf_rec))\n",
"print(np.mean(recalls) * 100.0)"
]
}
],
"metadata": {
Expand Down

0 comments on commit ad05429

Please sign in to comment.