Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Dec 7, 2023
1 parent 92b7891 commit 92b29a3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,8 +941,8 @@ def ub_iter(self, batch_size=1, shuffle=False):
batch_baskets = [self.user_basket_data[uid] for uid in batch_users]
yield batch_users, batch_baskets

def ubis_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users and batch of baskets (each basket is a list of items)
def ubi_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users, basket ids, and batch of the corresponding items
Parameters
----------
Expand All @@ -953,22 +953,23 @@ def ubis_iter(self, batch_size=1, shuffle=False):
Returns
-------
iterator : batch of user indices, batch of baskets (each basket is a list of items) corresponding to user indices
iterator : batch of user indices, batch of baskets corresponding to user indices, and batch of items correponding to baskets
"""
user_indices = np.asarray(list(self.user_basket_data.keys()), dtype="int")
for batch_ids in self.idx_iter(
len(self.user_basket_data), batch_size=batch_size, shuffle=shuffle
):
batch_users = user_indices[batch_ids]
batch_baskets = [self.user_basket_data[uid] for uid in batch_users]
batch_basket_items = [
[
[self.uir_tuple[1][idx] for idx in self.baskets[bid]]
for bid in self.user_basket_data[uid]
]
for uid in batch_users
]
yield batch_users, batch_basket_items
yield batch_users, batch_baskets, batch_basket_items

def basket_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of basket indices
Expand Down
6 changes: 3 additions & 3 deletions cornac/eval_methods/next_basket_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
return item_indices, u_gt_pos_items, u_gt_neg_items

(test_user_indices, *_) = test_set.uir_tuple
for [user_idx], [(*history_baskets, gt_basket)] in tqdm(
test_set.ubis_iter(batch_size=1, shuffle=False),
for [user_idx], [bids], [(*history_baskets, gt_basket)] in tqdm(
test_set.ubi_iter(batch_size=1, shuffle=False),
total=len(set(test_user_indices)),
desc="Ranking",
disable=not verbose,
Expand All @@ -128,7 +128,7 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
user_idx,
item_indices,
history_baskets=history_baskets,
history_basket_ids=test_set.user_basket_data[user_idx][:-1],
history_basket_ids=bids[:-1],
uir_tuple=test_set.uir_tuple,
baskets=test_set.baskets,
basket_ids=test_set.basket_ids,
Expand Down

0 comments on commit 92b29a3

Please sign in to comment.