diff --git a/cornac/data/dataset.py b/cornac/data/dataset.py index 2f16a8651..58a7c76bd 100644 --- a/cornac/data/dataset.py +++ b/cornac/data/dataset.py @@ -415,6 +415,14 @@ def num_batches(self, batch_size): """Estimate number of batches per epoch""" return estimate_batches(len(self.uir_tuple[0]), batch_size) + def num_user_batches(self, batch_size): + """Estimate number of batches per epoch iterating over users""" + return estimate_batches(self.num_users, batch_size) + + def num_item_batches(self, batch_size): + """Estimate number of batches per epoch iterating over items""" + return estimate_batches(self.num_items, batch_size) + def idx_iter(self, idx_range, batch_size=1, shuffle=False): """Create an iterator over batch of indices @@ -700,9 +708,8 @@ def __init__( def baskets(self): """A dictionary to store indices where basket ID appears in the data.""" if self.__baskets is None: - self.__baskets = OrderedDict() + self.__baskets = defaultdict(list) for idx, bid in enumerate(self.basket_ids): - self.__baskets.setdefault(bid, []) self.__baskets[bid].append(idx) return self.__baskets @@ -712,10 +719,9 @@ def user_basket_data(self): values are list of baskets purchased by corresponding users. """ if self.__user_basket_data is None: - self.__user_basket_data = defaultdict() + self.__user_basket_data = defaultdict(list) for bid, ids in self.baskets.items(): u = self.uir_tuple[0][ids[0]] - self.__user_basket_data.setdefault(u, []) self.__user_basket_data[u].append(bid) return self.__user_basket_data @@ -916,37 +922,50 @@ def from_ubitjson(cls, data, seed=None): """ return cls.build(data, fmt="UBITJson", seed=seed) - def num_batches(self, batch_size): - """Estimate number of batches per epoch""" - return estimate_batches(len(self.user_data), batch_size) + def ub_iter(self, batch_size=1, shuffle=False): + """Create an iterator over data yielding batch of users and batch of baskets - def user_basket_data_iter(self, batch_size=1, shuffle=False): - """Create an iterator over data yielding batch of basket indices and batch of baskets + Parameters + ---------- + batch_size: int, optional, default = 1 + + shuffle: bool, optional, default: False + If `True`, orders of users will be randomized. If `False`, default orders kept. + + Returns + ------- + iterator : batch of user indices, batch of baskets corresponding to user indices + + """ + for batch_users in self.user_iter(batch_size, shuffle): + batch_baskets = [self.user_basket_data[uid] for uid in batch_users] + yield batch_users, batch_baskets + + def ubi_iter(self, batch_size=1, shuffle=False): + """Create an iterator over data yielding batch of users, basket ids, and batch of the corresponding items Parameters ---------- batch_size: int, optional, default = 1 shuffle: bool, optional, default: False - If `True`, orders of triplets will be randomized. If `False`, default orders kept. + If `True`, orders of users will be randomized. If `False`, default orders kept. Returns ------- - iterator : batch of user indices, batch of user data corresponding to user indices + iterator : batch of user indices, batch of baskets corresponding to user indices, and batch of items correponding to baskets """ - user_indices = np.asarray(list(self.user_basket_data.keys()), dtype="int") - for batch_ids in self.idx_iter( - len(self.user_basket_data), batch_size=batch_size, shuffle=shuffle - ): - batch_users = user_indices[batch_ids] - batch_basket_ids = np.asarray( - [self.user_basket_data[uid] for uid in batch_users], dtype="int" - ) - yield batch_users, batch_basket_ids + _, item_indices, _ = self.uir_tuple + for batch_users, batch_baskets in self.ub_iter(batch_size, shuffle): + batch_basket_items = [ + [item_indices[self.baskets[bid]] for bid in user_baskets] + for user_baskets in batch_baskets + ] + yield batch_users, batch_baskets, batch_basket_items def basket_iter(self, batch_size=1, shuffle=False): - """Create an iterator over data yielding batch of basket indices and batch of baskets + """Create an iterator over data yielding batch of basket indices Parameters ---------- @@ -957,12 +976,9 @@ def basket_iter(self, batch_size=1, shuffle=False): Returns ------- - iterator : batch of basket indices, batch of baskets (list of list) + iterator : batch of basket indices (array of 'int') """ - basket_indices = np.array(list(self.baskets.keys())) - baskets = list(self.baskets.values()) + basket_indices = np.fromiter(set(self.baskets.keys()), dtype="int") for batch_ids in self.idx_iter(len(basket_indices), batch_size, shuffle): - batch_basket_indices = basket_indices[batch_ids] - batch_baskets = [baskets[idx] for idx in batch_ids] - yield batch_basket_indices, batch_baskets + yield basket_indices[batch_ids] diff --git a/cornac/eval_methods/next_basket_evaluation.py b/cornac/eval_methods/next_basket_evaluation.py index bd6286c6c..16a8a7768 100644 --- a/cornac/eval_methods/next_basket_evaluation.py +++ b/cornac/eval_methods/next_basket_evaluation.py @@ -108,14 +108,15 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns): item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0] return item_indices, u_gt_pos_items, u_gt_neg_items - (test_user_indices, test_item_indices, _) = test_set.uir_tuple - for user_idx in tqdm( - set(test_user_indices), desc="Ranking", disable=not verbose, miniters=100 + (test_user_indices, *_) = test_set.uir_tuple + for [user_idx], [bids], [(*history_baskets, gt_basket)] in tqdm( + test_set.ubi_iter(batch_size=1, shuffle=False), + total=len(set(test_user_indices)), + desc="Ranking", + disable=not verbose, + miniters=100, ): - [*history_bids, gt_bid] = test_set.user_basket_data[user_idx] - test_pos_items = pos_items( - [[test_item_indices[idx] for idx in test_set.baskets[gt_bid]]] - ) + test_pos_items = pos_items([gt_basket]) if len(test_pos_items) == 0: continue @@ -126,10 +127,9 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns): item_rank, item_scores = model.rank( user_idx, item_indices, - history_baskets=[ - [test_item_indices[idx] for idx in test_set.baskets[bid]] - for bid in history_bids - ], + history_baskets=history_baskets, + history_basket_ids=bids[:-1], + uir_tuple=test_set.uir_tuple, baskets=test_set.baskets, basket_ids=test_set.basket_ids, extra_data=test_set.extra_data, @@ -146,19 +146,11 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns): user_results["conventional"][i][user_idx] = mt_score history_items = set( - test_item_indices[idx] - for bid in history_bids - for idx in test_set.baskets[bid] + item_idx for basket in history_baskets for item_idx in basket ) if repetition_eval: test_repetition_pos_items = pos_items( - [ - [ - test_item_indices[idx] - for idx in test_set.baskets[gt_bid] - if test_item_indices[idx] in history_items - ] - ] + [[iid for iid in gt_basket if iid in history_items]] ) if len(test_repetition_pos_items) > 0: _, u_gt_pos_items, u_gt_neg_items = get_gt_items( @@ -176,13 +168,7 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns): if exploration_eval: test_exploration_pos_items = pos_items( - [ - [ - test_item_indices[idx] - for idx in test_set.baskets[gt_bid] - if test_item_indices[idx] not in history_items - ] - ] + [[iid for iid in gt_basket if iid not in history_items]] ) if len(test_exploration_pos_items) > 0: _, u_gt_pos_items, u_gt_neg_items = get_gt_items( @@ -200,18 +186,21 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns): # avg results of ranking metrics for i, mt in enumerate(metrics): avg_results["conventional"].append( - sum(user_results["conventional"][i].values()) - / len(user_results["conventional"][i]) + np.mean(list(user_results["conventional"][i].values())) + if len(user_results["conventional"][i]) > 0 + else 0 ) if repetition_eval: avg_results["repetition"].append( - sum(user_results["repetition"][i].values()) - / len(user_results["repetition"][i]) + np.mean(list(user_results["repetition"][i].values())) + if len(user_results["repetition"][i]) > 0 + else 0 ) if exploration_eval: avg_results["exploration"].append( - sum(user_results["exploration"][i].values()) - / len(user_results["exploration"][i]) + np.mean(list(user_results["exploration"][i].values())) + if len(user_results["repetition"][i]) > 0 + else 0 ) return avg_results, user_results @@ -365,13 +354,13 @@ def _build_datasets(self, train_data, test_data, val_data=None): print("Total items = {}".format(self.total_items)) print("Total baskets = {}".format(self.total_baskets)) - def _eval(self, model, test_set, **kwargs): + def eval(self, model, test_set, ranking_metrics, **kwargs): metric_avg_results = OrderedDict() metric_user_results = OrderedDict() avg_results, user_results = ranking_eval( model=model, - metrics=self.ranking_metrics, + metrics=ranking_metrics, train_set=self.train_set, test_set=test_set, repetition_eval=self.repetition_eval, @@ -380,12 +369,12 @@ def _eval(self, model, test_set, **kwargs): verbose=self.verbose, ) - for i, mt in enumerate(self.ranking_metrics): + for i, mt in enumerate(ranking_metrics): metric_avg_results[mt.name] = avg_results["conventional"][i] metric_user_results[mt.name] = user_results["conventional"][i] if self.repetition_eval: - for i, mt in enumerate(self.ranking_metrics): + for i, mt in enumerate(ranking_metrics): metric_avg_results["{}-rep".format(mt.name)] = avg_results[ "repetition" ][i] @@ -394,7 +383,7 @@ def _eval(self, model, test_set, **kwargs): ][i] if self.repetition_eval: - for i, mt in enumerate(self.ranking_metrics): + for i, mt in enumerate(ranking_metrics): metric_avg_results["{}-expl".format(mt.name)] = avg_results[ "exploration" ][i] diff --git a/cornac/models/gp_top/recom_gp_top.py b/cornac/models/gp_top/recom_gp_top.py index 32d2ba91b..c8e271984 100644 --- a/cornac/models/gp_top/recom_gp_top.py +++ b/cornac/models/gp_top/recom_gp_top.py @@ -33,6 +33,10 @@ class GPTop(NextBasketRecommender): use_personalized_popularity: boolean, optional, default: True When False, no item frequency from history baskets are being used. + use_quantity: boolean, optional, default: False + When True, constructing item frequency based on its quantity (getting from extra_data). + The data must be in fmt 'UBITJson'. + References ---------- Ming Li, Sami Jullien, Mozhdeh Ariannezhad, and Maarten de Rijke. 2023. @@ -42,31 +46,53 @@ class GPTop(NextBasketRecommender): """ def __init__( - self, name="GPTop", use_global_popularity=True, use_personalized_popularity=True + self, + name="GPTop", + use_global_popularity=True, + use_personalized_popularity=True, + use_quantity=False, ): super().__init__(name=name, trainable=False) self.use_global_popularity = use_global_popularity self.use_personalized_popularity = use_personalized_popularity + self.use_quantity = use_quantity self.item_freq = Counter() def fit(self, train_set, val_set=None): super().fit(train_set=train_set, val_set=val_set) if self.use_global_popularity: - self.item_freq = Counter(self.train_set.uir_tuple[1]) + if self.use_quantity: + self.item_freq = Counter() + for idx, iid in enumerate(self.train_set.uir_tuple[1]): + self.item_freq[iid] += self.train_set.extra_data[idx].get( + "quantity", 0 + ) + else: + self.item_freq = Counter(self.train_set.uir_tuple[1]) return self def score(self, user_idx, history_baskets, **kwargs): - item_scores = np.ones(self.total_items) + item_scores = np.zeros(self.total_items, dtype=np.float32) if self.use_global_popularity: - for iid, freq in self.item_freq.items(): - item_scores[iid] = freq - - if self.use_personalized_popularity: - p_item_freq = Counter([iid for iids in history_baskets for iid in iids]) - max_item_freq = ( max(self.item_freq.values()) if len(self.item_freq) > 0 else 1 ) + for iid, freq in self.item_freq.items(): + item_scores[iid] = freq / max_item_freq + + if self.use_personalized_popularity: + if self.use_quantity: + history_basket_bids = kwargs.get("history_basket_ids") + baskets = kwargs.get("baskets") + p_item_freq = Counter() + (_, item_ids, _) = kwargs.get("uir_tuple") + extra_data = kwargs.get("extra_data") + for bid in history_basket_bids: + ids = baskets[bid] + for idx in ids: + p_item_freq[item_ids[idx]] += extra_data[idx].get("quantity", 0) + else: + p_item_freq = Counter([iid for iids in history_baskets for iid in iids]) for iid, cnt in p_item_freq.most_common(): - item_scores[iid] = max_item_freq + cnt + item_scores[iid] += cnt return item_scores diff --git a/examples/gp_top_tafeng.py b/examples/gp_top_tafeng.py index efaa56859..834acf67a 100644 --- a/examples/gp_top_tafeng.py +++ b/examples/gp_top_tafeng.py @@ -30,8 +30,9 @@ ) models = [ - GPTop(name="PTop", use_global_popularity=False), GPTop(name="GTop", use_personalized_popularity=False), + GPTop(name="PTop", use_global_popularity=False), + GPTop(name="GPTop-quantity", use_quantity=True), GPTop(), ] diff --git a/tests/basket.txt b/tests/basket.txt new file mode 100644 index 000000000..40f2ec99f --- /dev/null +++ b/tests/basket.txt @@ -0,0 +1,50 @@ +1 1 1 882606572 {'quantity': 1} +1 1 6 882606572 {'quantity': 1} +1 2 2 882606573 {'quantity': 1} +1 2 4 882606573 {'quantity': 1} +1 3 3 882606574 {'quantity': 1} +1 3 7 882606574 {'quantity': 1} +2 4 1 882606575 {'quantity': 1} +2 4 2 882606575 {'quantity': 1} +2 5 3 882606576 {'quantity': 1} +2 5 6 882606576 {'quantity': 1} +2 6 4 882606577 {'quantity': 1} +2 6 7 882606577 {'quantity': 1} +3 7 1 882606578 {'quantity': 1} +3 7 3 882606578 {'quantity': 1} +3 8 2 882606579 {'quantity': 1} +3 8 5 882606579 {'quantity': 1} +3 9 6 882606580 {'quantity': 1} +3 9 7 882606580 {'quantity': 1} +4 10 3 882606581 {'quantity': 1} +4 10 5 882606581 {'quantity': 1} +4 11 4 882606582 {'quantity': 1} +4 11 5 882606582 {'quantity': 1} +5 12 5 882606583 {'quantity': 1} +5 12 6 882606583 {'quantity': 1} +5 13 6 882606584 {'quantity': 1} +5 13 7 882606584 {'quantity': 1} +6 14 4 882606585 {'quantity': 1} +6 14 6 882606585 {'quantity': 1} +6 15 3 882606586 {'quantity': 1} +6 15 5 882606586 {'quantity': 1} +7 16 3 882606587 {'quantity': 1} +7 16 6 882606587 {'quantity': 1} +7 17 2 882606588 {'quantity': 1} +7 17 6 882606588 {'quantity': 1} +8 18 3 882606589 {'quantity': 1} +8 19 7 882606590 {'quantity': 1} +9 20 3 882606591 {'quantity': 1} +9 20 7 882606591 {'quantity': 1} +9 21 4 882606592 {'quantity': 1} +9 21 5 882606592 {'quantity': 1} +10 22 1 882606593 {'quantity': 1} +10 22 3 882606593 {'quantity': 1} +10 23 2 882606594 {'quantity': 1} +10 23 3 882606594 {'quantity': 1} +10 24 4 882606595 {'quantity': 1} +10 24 5 882606595 {'quantity': 1} +10 24 7 882606595 {'quantity': 1} +10 25 1 882606595 {'quantity': 1} +10 25 2 882606595 {'quantity': 1} +10 25 5 882606595 {'quantity': 1} \ No newline at end of file diff --git a/tests/cornac/data/test_dataset.py b/tests/cornac/data/test_dataset.py index e7a308a27..161c542e7 100644 --- a/tests/cornac/data/test_dataset.py +++ b/tests/cornac/data/test_dataset.py @@ -14,13 +14,11 @@ # ============================================================================ import unittest -from collections import OrderedDict import numpy as np import numpy.testing as npt -from cornac.data import Reader -from cornac.data import Dataset +from cornac.data import BasketDataset, Dataset, Reader class TestDataset(unittest.TestCase): @@ -171,7 +169,7 @@ def test_uir_tuple(self): self.assertEqual(train_set.num_batches(batch_size=5), 2) def test_matrix(self): - from scipy.sparse import csr_matrix, csc_matrix, dok_matrix + from scipy.sparse import csc_matrix, csr_matrix, dok_matrix train_set = Dataset.from_uir(self.triplet_data) @@ -234,5 +232,34 @@ def test_chrono_item_data(self): assert True +class TestBasketDataset(unittest.TestCase): + def setUp(self): + self.basket_data = Reader().read("./tests/basket.txt", fmt="UBITJson") + + def test_init(self): + train_set = BasketDataset.from_ubi(self.basket_data) + + self.assertEqual(train_set.num_baskets, 25) + self.assertEqual(train_set.max_basket_size, 3) + self.assertEqual(train_set.min_basket_size, 1) + + self.assertEqual(train_set.num_users, 10) + self.assertEqual(train_set.num_items, 7) + + self.assertEqual(train_set.uid_map["1"], 0) + self.assertEqual(train_set.bid_map["1"], 0) + self.assertEqual(train_set.iid_map["1"], 0) + + self.assertSetEqual( + set(train_set.user_ids), + set(["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]), + ) + + self.assertSetEqual( + set(train_set.item_ids), + set(["1", "2", "3", "4", "5", "6", "7"]), + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/cornac/data/test_reader.py b/tests/cornac/data/test_reader.py index fec565a76..04639396e 100644 --- a/tests/cornac/data/test_reader.py +++ b/tests/cornac/data/test_reader.py @@ -20,24 +20,24 @@ class TestReader(unittest.TestCase): - def setUp(self): - self.data_file = './tests/data.txt' + self.data_file = "./tests/data.txt" + self.basket_file = "./tests/basket.txt" self.reader = Reader() def test_raise(self): try: - self.reader.read(self.data_file, fmt='bla bla') + self.reader.read(self.data_file, fmt="bla bla") except ValueError: assert True def test_read_ui(self): - triplets = self.reader.read(self.data_file, fmt='UI') + triplets = self.reader.read(self.data_file, fmt="UI") self.assertEqual(len(triplets), 30) - self.assertEqual(triplets[0][1], '93') + self.assertEqual(triplets[0][1], "93") self.assertEqual(triplets[1][2], 1.0) - triplets = self.reader.read(self.data_file, fmt='UI', id_inline=True) + triplets = self.reader.read(self.data_file, fmt="UI", id_inline=True) self.assertEqual(len(triplets), 40) def test_read_uir(self): @@ -45,32 +45,32 @@ def test_read_uir(self): self.assertEqual(len(triplet_data), 10) self.assertEqual(triplet_data[4][2], 3) - self.assertEqual(triplet_data[6][1], '478') - self.assertEqual(triplet_data[8][0], '543') + self.assertEqual(triplet_data[6][1], "478") + self.assertEqual(triplet_data[8][0], "543") def test_read_uirt(self): - data = self.reader.read(self.data_file, fmt='UIRT') + data = self.reader.read(self.data_file, fmt="UIRT") self.assertEqual(len(data), 10) self.assertEqual(data[4][3], 891656347) self.assertEqual(data[4][2], 3) - self.assertEqual(data[4][1], '705') - self.assertEqual(data[4][0], '329') + self.assertEqual(data[4][1], "705") + self.assertEqual(data[4][0], "329") self.assertEqual(data[9][3], 879451804) def test_read_tup(self): - tup_data = self.reader.read(self.data_file, fmt='UITup') + tup_data = self.reader.read(self.data_file, fmt="UITup") self.assertEqual(len(tup_data), 10) - self.assertEqual(tup_data[4][2], [('3',), ('891656347',)]) - self.assertEqual(tup_data[6][1], '478') - self.assertEqual(tup_data[8][0], '543') + self.assertEqual(tup_data[4][2], [("3",), ("891656347",)]) + self.assertEqual(tup_data[6][1], "478") + self.assertEqual(tup_data[8][0], "543") def test_read_review(self): - review_data = self.reader.read('./tests/review.txt', fmt='UIReview') + review_data = self.reader.read("./tests/review.txt", fmt="UIReview") self.assertEqual(len(review_data), 5) - self.assertEqual(review_data[0][2], 'Sample text 1') - self.assertEqual(review_data[1][1], '257') - self.assertEqual(review_data[4][0], '329') + self.assertEqual(review_data[0][2], "Sample text 1") + self.assertEqual(review_data[1][1], "257") + self.assertEqual(review_data[4][0], "329") def test_filter(self): reader = Reader(bin_threshold=4.0) @@ -84,19 +84,30 @@ def test_filter(self): reader = Reader(min_item_freq=2) self.assertEqual(len(reader.read(self.data_file)), 0) - reader = Reader(user_set=['76'], item_set=['93']) + reader = Reader(user_set=["76"], item_set=["93"]) self.assertEqual(len(reader.read(self.data_file)), 1) - reader = Reader(user_set=['76', '768']) + reader = Reader(user_set=["76", "768"]) self.assertEqual(len(reader.read(self.data_file)), 2) - reader = Reader(item_set=['93', '257', '795']) + reader = Reader(item_set=["93", "257", "795"]) self.assertEqual(len(reader.read(self.data_file)), 3) def test_read_text(self): self.assertEqual(len(read_text(self.data_file, sep=None)), 10) - self.assertEqual(read_text(self.data_file, sep='\t')[1][0], '76') + self.assertEqual(read_text(self.data_file, sep="\t")[1][0], "76") + + def test_read_basket(self): + self.assertEqual( + len(self.reader.read(self.basket_file, sep="\t", fmt="UBI")), 50 + ) + self.assertEqual( + len(self.reader.read(self.basket_file, sep="\t", fmt="UBIT")), 50 + ) + self.assertEqual( + len(self.reader.read(self.basket_file, sep="\t", fmt="UBITJson")), 50 + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/cornac/eval_methods/test_next_basket_evaluation.py b/tests/cornac/eval_methods/test_next_basket_evaluation.py new file mode 100644 index 000000000..00a479b97 --- /dev/null +++ b/tests/cornac/eval_methods/test_next_basket_evaluation.py @@ -0,0 +1,57 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import unittest + +from cornac.eval_methods import NextBasketEvaluation +from cornac.data import Reader +from cornac.models import GPTop +from cornac.metrics import HitRatio, Recall + + +class TestNextBasketEvaluation(unittest.TestCase): + def setUp(self): + self.data = Reader().read("./tests/basket.txt", fmt="UBITJson", sep="\t") + + def test_splits(self): + next_basket_eval = NextBasketEvaluation( + self.data, test_size=0.1, val_size=0.1, seed=123, verbose=True + ) + + self.assertTrue(next_basket_eval.train_size == 8) + self.assertTrue(next_basket_eval.test_size == 1) + self.assertTrue(next_basket_eval.val_size == 1) + + def test_evaluate(self): + next_basket_eval = NextBasketEvaluation( + self.data, exclude_unknowns=False, verbose=True + ) + next_basket_eval.evaluate( + GPTop(), [HitRatio(k=2), Recall(k=2)], user_based=True + ) + + next_basket_eval = NextBasketEvaluation( + self.data, + repetition_eval=True, + exploration_eval=True, + exclude_unknowns=False, + verbose=True, + ) + next_basket_eval.evaluate( + GPTop(), [HitRatio(k=2), Recall(k=2)], user_based=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cornac/models/test_recommender.py b/tests/cornac/models/test_recommender.py index a59e46c42..22dbf91ac 100644 --- a/tests/cornac/models/test_recommender.py +++ b/tests/cornac/models/test_recommender.py @@ -15,8 +15,8 @@ import unittest -from cornac.data import Reader, Dataset -from cornac.models import MF +from cornac.data import BasketDataset, Dataset, Reader +from cornac.models import MF, GPTop, NextBasketRecommender class TestRecommender(unittest.TestCase): @@ -51,5 +51,23 @@ def test_recommend(self): ) +class TestNextBasketRecommender(unittest.TestCase): + def setUp(self): + self.data = Reader().read("./tests/basket.txt", fmt="UBITJson") + + def test_init(self): + model = NextBasketRecommender("test") + self.assertTrue(model.name == "test") + + def test_fit(self): + dataset = BasketDataset.from_ubi(self.data) + model = NextBasketRecommender("") + model.fit(dataset) + model = GPTop() + model.fit(dataset) + model.score(0, [[]]) + model.rank(0, history_baskets=[[]]) + + if __name__ == "__main__": unittest.main()