diff --git a/cornac/data/dataset.py b/cornac/data/dataset.py index 11c2b3e5..a26f6563 100644 --- a/cornac/data/dataset.py +++ b/cornac/data/dataset.py @@ -941,8 +941,8 @@ def ub_iter(self, batch_size=1, shuffle=False): batch_baskets = [self.user_basket_data[uid] for uid in batch_users] yield batch_users, batch_baskets - def ubis_iter(self, batch_size=1, shuffle=False): - """Create an iterator over data yielding batch of users and batch of baskets (each basket is a list of items) + def ubi_iter(self, batch_size=1, shuffle=False): + """Create an iterator over data yielding batch of users, basket ids, and batch of the corresponding items Parameters ---------- @@ -953,7 +953,7 @@ def ubis_iter(self, batch_size=1, shuffle=False): Returns ------- - iterator : batch of user indices, batch of baskets (each basket is a list of items) corresponding to user indices + iterator : batch of user indices, batch of baskets corresponding to user indices, and batch of items corresponding to baskets """ user_indices = np.asarray(list(self.user_basket_data.keys()), dtype="int") @@ -961,6 +961,7 @@ def ubis_iter(self, batch_size=1, shuffle=False): len(self.user_basket_data), batch_size=batch_size, shuffle=shuffle ): batch_users = user_indices[batch_ids] + batch_baskets = [self.user_basket_data[uid] for uid in batch_users] batch_basket_items = [ [ [self.uir_tuple[1][idx] for idx in self.baskets[bid]] @@ -968,7 +969,7 @@ def ubis_iter(self, batch_size=1, shuffle=False): ] for uid in batch_users ] - yield batch_users, batch_basket_items + yield batch_users, batch_baskets, batch_basket_items def basket_iter(self, batch_size=1, shuffle=False): """Create an iterator over data yielding batch of basket indices diff --git a/cornac/eval_methods/next_basket_evaluation.py b/cornac/eval_methods/next_basket_evaluation.py index 94b8926f..16a8a776 100644 --- a/cornac/eval_methods/next_basket_evaluation.py +++ b/cornac/eval_methods/next_basket_evaluation.py @@ -109,8 +109,8 @@ def get_gt_items(train_set, 
test_set, test_pos_items, exclude_unknowns): return item_indices, u_gt_pos_items, u_gt_neg_items (test_user_indices, *_) = test_set.uir_tuple - for [user_idx], [(*history_baskets, gt_basket)] in tqdm( - test_set.ubis_iter(batch_size=1, shuffle=False), + for [user_idx], [bids], [(*history_baskets, gt_basket)] in tqdm( + test_set.ubi_iter(batch_size=1, shuffle=False), total=len(set(test_user_indices)), desc="Ranking", disable=not verbose, @@ -128,7 +128,7 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns): user_idx, item_indices, history_baskets=history_baskets, - history_basket_ids=test_set.user_basket_data[user_idx][:-1], + history_basket_ids=bids[:-1], uir_tuple=test_set.uir_tuple, baskets=test_set.baskets, basket_ids=test_set.basket_ids,