Skip to content

Commit

Permalink
Add recommend API in Recommender
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Oct 25, 2023
1 parent 139f0eb commit 89fa6d2
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions cornac/models/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,53 @@ def rank(self, user_idx, item_indices=None):

return item_rank, item_scores

def recommend(self, user_id, k=-1, remove_seen=False):
"""Generate top-K item recommendations for a given user. Key difference between
this function and rank() function is that rank() function works with mapped
user/item index while this function works with original user/item ID. This helps
hide the abstraction of ID-index mapping, and make model usage and deployment cleaner.
Parameters
----------
user_id: str, required
The original ID of the user.
k: int, optional, default=-1
Cut-off length for recommendations, k=-1 will return ranked list of all items.
remove_seen: bool, optional, default: False
Remove seen/known items during training and validation from output recommendations.
Returns
-------
output: list
Recommended items in the form of their original IDs.
"""
user_idx = self.train_set.uid_map.get(user_id, -1)
if user_idx == -1:
raise ValueError(f"{user_id} is unknown to the model.")

if k < -1 or k > self.train_set.total_items:
raise ValueError(
f"k={k} is invalid, there are {self.train_set.total_users} users in total."
)

item_indices = np.arange(self.train_set.total_items)
if remove_seen:
unk_mask = np.ones(len(item_indices), dtype="bool")
if not self.train_set.is_unk_user(user_idx):
unk_mask[self.train_set.csr_matrix.getrow(user_idx).indices] = False
if not self.val_set is None and not self.val_set.is_unk_user(user_idx):
unk_mask[self.val_set.csr_matrix.getrow(user_idx).indices] = False
item_indices = item_indices[unk_mask]

item_rank, _ = self.rank(user_idx, item_indices)
if k != -1:
item_rank = item_rank[:k]
output = [self.train_set.item_ids[i] for i in item_rank]

return output

def monitor_value(self):
"""Calculating monitored value used for early stopping on validation set (`val_set`).
This function will be called by `early_stop()` function.
Expand Down

0 comments on commit 89fa6d2

Please sign in to comment.