Skip to content

Commit

Permalink
add ANN support
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Nov 30, 2023
1 parent 7b2e295 commit 051a321
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions cornac/models/fm/recom_fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import numpy as np

from ..recommender import Recommender
from ..recommender import ANNMixin, MEASURE_DOT
from ...utils import get_rng
from ...utils.init_utils import zeros, normal


class FM(Recommender):
class FM(Recommender, ANNMixin):
"""Factorization Machines.
Parameters
Expand Down Expand Up @@ -222,9 +223,7 @@ def _fm_predict(self, user_idx, item_idx):
if self.k1:
score += self.w[uid] + self.w[iid]
if self.k2:
sum_ = self.v[:, uid] + self.v[:, iid]
sum_sqr_ = self.v[:, uid] ** 2 + self.v[:, iid] ** 2
score += 0.5 * (sum_**2 - sum_sqr_).sum()
score += self.v[:, uid].dot(self.v[:, iid])
return score

def _fm_predict_all(self, user_idx):
Expand Down Expand Up @@ -263,3 +262,43 @@ def score(self, user_idx, item_idx=None):
return self._fm_predict_all(user_idx)
else:
return self._fm_predict(user_idx, item_idx)

def get_vector_measure(self):
"""Getting a valid choice of vector measurement in ANNMixin._measures.
Returns
-------
measure: MEASURE_DOT
Dot product aka. inner product
"""
return MEASURE_DOT

def get_user_vectors(self):
"""Getting a matrix of user vectors serving as query for ANN search.
Returns
-------
out: numpy.array
Matrix of user vectors for all users available in the model.
"""
user_vectors = self.v[:, : self.total_users]
if self.k1: # has bias term
user_vectors = np.concatenate(
(user_vectors, np.ones([user_vectors.shape[0], 1])), axis=1
)
return user_vectors

def get_item_vectors(self):
"""Getting a matrix of item vectors used for building the index for ANN search.
Returns
-------
out: numpy.array
Matrix of item vectors for all items available in the model.
"""
item_vectors = self.v[:, self.total_users :]
if self.k1: # has bias term
item_vectors = np.concatenate(
(item_vectors, self.w[self.total_users :].reshape((-1, 1))), axis=1
)
return item_vectors

0 comments on commit 051a321

Please sign in to comment.