Skip to content

Commit

Permalink
Let models use knows_user() and knows_item() instead of using train_set
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Oct 26, 2023
1 parent 62d2563 commit a429d61
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 21 deletions.
8 changes: 0 additions & 8 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,14 +578,6 @@ def item_iter(self, batch_size=1, shuffle=False):
for batch_ids in self.idx_iter(len(item_indices), batch_size, shuffle):
yield item_indices[batch_ids]

def contains_user(self, user_idx):
"""Return whether given user index is in the dataset"""
return user_idx >= 0 and user_idx < self.num_users

def contains_item(self, item_idx):
"""Return whether given item index is in the dataset"""
return item_idx >= 0 and item_idx < self.num_items

def add_modalities(self, **kwargs):
self.user_feature = kwargs.get("user_feature", None)
self.item_feature = kwargs.get("item_feature", None)
Expand Down
6 changes: 3 additions & 3 deletions cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def ranking_eval(
avg_results = []
user_results = [{} for _ in enumerate(metrics)]

gt_mat = test_set.csr_matrix
test_mat = test_set.csr_matrix
train_mat = train_set.csr_matrix
val_mat = None if val_set is None else val_set.csr_matrix

Expand All @@ -175,7 +175,7 @@ def pos_items(csr_row):
for user_idx in tqdm(
test_user_indices, desc="Ranking", disable=not verbose, miniters=100
):
test_pos_items = pos_items(gt_mat.getrow(user_idx))
test_pos_items = pos_items(test_mat.getrow(user_idx))
if len(test_pos_items) == 0:
continue

Expand All @@ -186,7 +186,7 @@ def pos_items(csr_row):
val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx))
train_pos_items = (
pos_items(train_mat.getrow(user_idx))
if train_set.contains_user(user_idx)
if user_idx < train_mat.shape[0]
else []
)

Expand Down
2 changes: 1 addition & 1 deletion cornac/models/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def recommend(self, user_id, k=-1, remove_seen=False, train_set=None):
seen_mask = np.zeros(len(item_indices), dtype="bool")
if train_set is None:
raise ValueError("train_set must be provided to remove seen items.")
if train_set.contains_user(user_idx):
if user_idx < train_set.csr_matrix.shape[0]:
seen_mask[train_set.csr_matrix.getrow(user_idx).indices] = True
item_indices = item_indices[~seen_mask]

Expand Down
6 changes: 0 additions & 6 deletions tests/cornac/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ def test_init(self):
self.assertEqual(train_set.num_users, 10)
self.assertEqual(train_set.num_items, 10)

self.assertTrue(train_set.contains_user(7))
self.assertFalse(train_set.contains_item(13))

self.assertTrue(train_set.contains_item(3))
self.assertFalse(train_set.contains_item(16))

self.assertEqual(train_set.uid_map["768"], 1)
self.assertEqual(train_set.iid_map["195"], 7)

Expand Down
14 changes: 11 additions & 3 deletions tests/cornac/models/test_recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@
import unittest

from cornac.data import Reader, Dataset
from cornac.eval_methods import RatioSplit
from cornac.models import MF
from cornac.metrics import MAE, RMSE
from cornac.experiment.experiment import Experiment


class TestRecommender(unittest.TestCase):
def setUp(self):
self.data = Reader().read("./tests/data.txt")

def test_knows_x(self):
mf = MF(1, 1, seed=123)
dataset = Dataset.from_uir(self.data)
mf.fit(dataset)

self.assertTrue(mf.knows_user(7))
self.assertFalse(mf.knows_item(13))

self.assertTrue(mf.knows_item(3))
self.assertFalse(mf.knows_item(16))

def test_recommend(self):
mf = MF(1, 1, seed=123)
dataset = Dataset.from_uir(self.data)
Expand Down

0 comments on commit a429d61

Please sign in to comment.