From a429d61f27f63c20e55d18f309a5539b8a037cf1 Mon Sep 17 00:00:00 2001 From: tqtg Date: Thu, 26 Oct 2023 19:33:39 +0000 Subject: [PATCH] Let models use knows_user() and knows_item() instead of using train_set --- cornac/data/dataset.py | 8 -------- cornac/eval_methods/base_method.py | 6 +++--- cornac/models/recommender.py | 2 +- tests/cornac/data/test_dataset.py | 6 ------ tests/cornac/models/test_recommender.py | 14 +++++++++++--- 5 files changed, 15 insertions(+), 21 deletions(-) diff --git a/cornac/data/dataset.py b/cornac/data/dataset.py index 6691bd086..59deb9e09 100644 --- a/cornac/data/dataset.py +++ b/cornac/data/dataset.py @@ -578,14 +578,6 @@ def item_iter(self, batch_size=1, shuffle=False): for batch_ids in self.idx_iter(len(item_indices), batch_size, shuffle): yield item_indices[batch_ids] - def contains_user(self, user_idx): - """Return whether given user index is in the dataset""" - return user_idx >= 0 and user_idx < self.num_users - - def contains_item(self, item_idx): - """Return whether given item index is in the dataset""" - return item_idx >= 0 and item_idx < self.num_items - def add_modalities(self, **kwargs): self.user_feature = kwargs.get("user_feature", None) self.item_feature = kwargs.get("item_feature", None) diff --git a/cornac/eval_methods/base_method.py b/cornac/eval_methods/base_method.py index eb11926a0..0bc29f683 100644 --- a/cornac/eval_methods/base_method.py +++ b/cornac/eval_methods/base_method.py @@ -160,7 +160,7 @@ def ranking_eval( avg_results = [] user_results = [{} for _ in enumerate(metrics)] - gt_mat = test_set.csr_matrix + test_mat = test_set.csr_matrix train_mat = train_set.csr_matrix val_mat = None if val_set is None else val_set.csr_matrix @@ -175,7 +175,7 @@ def pos_items(csr_row): for user_idx in tqdm( test_user_indices, desc="Ranking", disable=not verbose, miniters=100 ): - test_pos_items = pos_items(gt_mat.getrow(user_idx)) + test_pos_items = pos_items(test_mat.getrow(user_idx)) if len(test_pos_items) == 0: continue 
@@ -186,7 +186,7 @@ def pos_items(csr_row): val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx)) train_pos_items = ( pos_items(train_mat.getrow(user_idx)) - if train_set.contains_user(user_idx) + if user_idx < train_mat.shape[0] else [] ) diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py index 4dcf051f8..948d49b2e 100644 --- a/cornac/models/recommender.py +++ b/cornac/models/recommender.py @@ -421,7 +421,7 @@ def recommend(self, user_id, k=-1, remove_seen=False, train_set=None): seen_mask = np.zeros(len(item_indices), dtype="bool") if train_set is None: raise ValueError("train_set must be provided to remove seen items.") - if train_set.contains_user(user_idx): + if user_idx < train_set.csr_matrix.shape[0]: seen_mask[train_set.csr_matrix.getrow(user_idx).indices] = True item_indices = item_indices[~seen_mask] diff --git a/tests/cornac/data/test_dataset.py b/tests/cornac/data/test_dataset.py index e7ecddcc9..e7a308a27 100644 --- a/tests/cornac/data/test_dataset.py +++ b/tests/cornac/data/test_dataset.py @@ -40,12 +40,6 @@ def test_init(self): self.assertEqual(train_set.num_users, 10) self.assertEqual(train_set.num_items, 10) - self.assertTrue(train_set.contains_user(7)) - self.assertFalse(train_set.contains_item(13)) - - self.assertTrue(train_set.contains_item(3)) - self.assertFalse(train_set.contains_item(16)) - self.assertEqual(train_set.uid_map["768"], 1) self.assertEqual(train_set.iid_map["195"], 7) diff --git a/tests/cornac/models/test_recommender.py b/tests/cornac/models/test_recommender.py index af7accee3..a59e46c42 100644 --- a/tests/cornac/models/test_recommender.py +++ b/tests/cornac/models/test_recommender.py @@ -16,16 +16,24 @@ import unittest from cornac.data import Reader, Dataset -from cornac.eval_methods import RatioSplit from cornac.models import MF -from cornac.metrics import MAE, RMSE -from cornac.experiment.experiment import Experiment class TestRecommender(unittest.TestCase): def setUp(self): 
self.data = Reader().read("./tests/data.txt") + def test_knows_x(self): + mf = MF(1, 1, seed=123) + dataset = Dataset.from_uir(self.data) + mf.fit(dataset) + + self.assertTrue(mf.knows_user(7)) + self.assertFalse(mf.knows_user(13)) + + self.assertTrue(mf.knows_item(3)) + self.assertFalse(mf.knows_item(16)) + def test_recommend(self): mf = MF(1, 1, seed=123) dataset = Dataset.from_uir(self.data)