Skip to content

Commit

Permalink
Update next-basket evaluation (#559)
Browse files Browse the repository at this point in the history
* Fix eval and add batch basket items iteration

* refactor code

* Fix scoring function

* refactor code

* Add unittest for NextBasketEvaluation

* Add unittest for BasketDataset

* Add test NextBasketRecommender

* Add test case reading basket data

* refactor code

* Add history basket ids for accessing extra data in scoring function

* Add option to scoring function based on quantity provided in extra_data

* refactor code

* reuse user_iter() in ub_iter()

* reuse ub_iter() in ubi_iter()

* consider using num_user_batches() instead of overwriting existing num_batches()

---------

Co-authored-by: tqtg <[email protected]>
  • Loading branch information
lthoang and tqtg authored Dec 8, 2023
1 parent 190c5a1 commit af62a20
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 107 deletions.
70 changes: 43 additions & 27 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,14 @@ def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.uir_tuple[0]), batch_size)

def num_user_batches(self, batch_size):
"""Estimate number of batches per epoch iterating over users"""
return estimate_batches(self.num_users, batch_size)

def num_item_batches(self, batch_size):
"""Estimate number of batches per epoch iterating over items"""
return estimate_batches(self.num_items, batch_size)

def idx_iter(self, idx_range, batch_size=1, shuffle=False):
"""Create an iterator over batch of indices
Expand Down Expand Up @@ -700,9 +708,8 @@ def __init__(
def baskets(self):
"""A dictionary to store indices where basket ID appears in the data."""
if self.__baskets is None:
self.__baskets = OrderedDict()
self.__baskets = defaultdict(list)
for idx, bid in enumerate(self.basket_ids):
self.__baskets.setdefault(bid, [])
self.__baskets[bid].append(idx)
return self.__baskets

Expand All @@ -712,10 +719,9 @@ def user_basket_data(self):
values are list of baskets purchased by corresponding users.
"""
if self.__user_basket_data is None:
self.__user_basket_data = defaultdict()
self.__user_basket_data = defaultdict(list)
for bid, ids in self.baskets.items():
u = self.uir_tuple[0][ids[0]]
self.__user_basket_data.setdefault(u, [])
self.__user_basket_data[u].append(bid)
return self.__user_basket_data

Expand Down Expand Up @@ -916,37 +922,50 @@ def from_ubitjson(cls, data, seed=None):
"""
return cls.build(data, fmt="UBITJson", seed=seed)

def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.user_data), batch_size)
def ub_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users and batch of baskets
def user_basket_data_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of basket indices and batch of baskets
Parameters
----------
batch_size: int, optional, default = 1
shuffle: bool, optional, default: False
If `True`, orders of users will be randomized. If `False`, default orders kept.
Returns
-------
iterator : batch of user indices, batch of baskets corresponding to user indices
"""
for batch_users in self.user_iter(batch_size, shuffle):
batch_baskets = [self.user_basket_data[uid] for uid in batch_users]
yield batch_users, batch_baskets

def ubi_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users, basket ids, and batch of the corresponding items
Parameters
----------
batch_size: int, optional, default = 1
shuffle: bool, optional, default: False
If `True`, orders of triplets will be randomized. If `False`, default orders kept.
If `True`, orders of users will be randomized. If `False`, default orders kept.
Returns
-------
iterator : batch of user indices, batch of user data corresponding to user indices
iterator : batch of user indices, batch of baskets corresponding to user indices, and batch of items correponding to baskets
"""
user_indices = np.asarray(list(self.user_basket_data.keys()), dtype="int")
for batch_ids in self.idx_iter(
len(self.user_basket_data), batch_size=batch_size, shuffle=shuffle
):
batch_users = user_indices[batch_ids]
batch_basket_ids = np.asarray(
[self.user_basket_data[uid] for uid in batch_users], dtype="int"
)
yield batch_users, batch_basket_ids
_, item_indices, _ = self.uir_tuple
for batch_users, batch_baskets in self.ub_iter(batch_size, shuffle):
batch_basket_items = [
[item_indices[self.baskets[bid]] for bid in user_baskets]
for user_baskets in batch_baskets
]
yield batch_users, batch_baskets, batch_basket_items

def basket_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of basket indices and batch of baskets
"""Create an iterator over data yielding batch of basket indices
Parameters
----------
Expand All @@ -957,12 +976,9 @@ def basket_iter(self, batch_size=1, shuffle=False):
Returns
-------
iterator : batch of basket indices, batch of baskets (list of list)
iterator : batch of basket indices (array of 'int')
"""
basket_indices = np.array(list(self.baskets.keys()))
baskets = list(self.baskets.values())
basket_indices = np.fromiter(set(self.baskets.keys()), dtype="int")
for batch_ids in self.idx_iter(len(basket_indices), batch_size, shuffle):
batch_basket_indices = basket_indices[batch_ids]
batch_baskets = [baskets[idx] for idx in batch_ids]
yield batch_basket_indices, batch_baskets
yield basket_indices[batch_ids]
67 changes: 28 additions & 39 deletions cornac/eval_methods/next_basket_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,15 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]
return item_indices, u_gt_pos_items, u_gt_neg_items

(test_user_indices, test_item_indices, _) = test_set.uir_tuple
for user_idx in tqdm(
set(test_user_indices), desc="Ranking", disable=not verbose, miniters=100
(test_user_indices, *_) = test_set.uir_tuple
for [user_idx], [bids], [(*history_baskets, gt_basket)] in tqdm(
test_set.ubi_iter(batch_size=1, shuffle=False),
total=len(set(test_user_indices)),
desc="Ranking",
disable=not verbose,
miniters=100,
):
[*history_bids, gt_bid] = test_set.user_basket_data[user_idx]
test_pos_items = pos_items(
[[test_item_indices[idx] for idx in test_set.baskets[gt_bid]]]
)
test_pos_items = pos_items([gt_basket])
if len(test_pos_items) == 0:
continue

Expand All @@ -126,10 +127,9 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_baskets=[
[test_item_indices[idx] for idx in test_set.baskets[bid]]
for bid in history_bids
],
history_baskets=history_baskets,
history_basket_ids=bids[:-1],
uir_tuple=test_set.uir_tuple,
baskets=test_set.baskets,
basket_ids=test_set.basket_ids,
extra_data=test_set.extra_data,
Expand All @@ -146,19 +146,11 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
user_results["conventional"][i][user_idx] = mt_score

history_items = set(
test_item_indices[idx]
for bid in history_bids
for idx in test_set.baskets[bid]
item_idx for basket in history_baskets for item_idx in basket
)
if repetition_eval:
test_repetition_pos_items = pos_items(
[
[
test_item_indices[idx]
for idx in test_set.baskets[gt_bid]
if test_item_indices[idx] in history_items
]
]
[[iid for iid in gt_basket if iid in history_items]]
)
if len(test_repetition_pos_items) > 0:
_, u_gt_pos_items, u_gt_neg_items = get_gt_items(
Expand All @@ -176,13 +168,7 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):

if exploration_eval:
test_exploration_pos_items = pos_items(
[
[
test_item_indices[idx]
for idx in test_set.baskets[gt_bid]
if test_item_indices[idx] not in history_items
]
]
[[iid for iid in gt_basket if iid not in history_items]]
)
if len(test_exploration_pos_items) > 0:
_, u_gt_pos_items, u_gt_neg_items = get_gt_items(
Expand All @@ -200,18 +186,21 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
# avg results of ranking metrics
for i, mt in enumerate(metrics):
avg_results["conventional"].append(
sum(user_results["conventional"][i].values())
/ len(user_results["conventional"][i])
np.mean(list(user_results["conventional"][i].values()))
if len(user_results["conventional"][i]) > 0
else 0
)
if repetition_eval:
avg_results["repetition"].append(
sum(user_results["repetition"][i].values())
/ len(user_results["repetition"][i])
np.mean(list(user_results["repetition"][i].values()))
if len(user_results["repetition"][i]) > 0
else 0
)
if exploration_eval:
avg_results["exploration"].append(
sum(user_results["exploration"][i].values())
/ len(user_results["exploration"][i])
np.mean(list(user_results["exploration"][i].values()))
if len(user_results["repetition"][i]) > 0
else 0
)

return avg_results, user_results
Expand Down Expand Up @@ -365,13 +354,13 @@ def _build_datasets(self, train_data, test_data, val_data=None):
print("Total items = {}".format(self.total_items))
print("Total baskets = {}".format(self.total_baskets))

def _eval(self, model, test_set, **kwargs):
def eval(self, model, test_set, ranking_metrics, **kwargs):
metric_avg_results = OrderedDict()
metric_user_results = OrderedDict()

avg_results, user_results = ranking_eval(
model=model,
metrics=self.ranking_metrics,
metrics=ranking_metrics,
train_set=self.train_set,
test_set=test_set,
repetition_eval=self.repetition_eval,
Expand All @@ -380,12 +369,12 @@ def _eval(self, model, test_set, **kwargs):
verbose=self.verbose,
)

for i, mt in enumerate(self.ranking_metrics):
for i, mt in enumerate(ranking_metrics):
metric_avg_results[mt.name] = avg_results["conventional"][i]
metric_user_results[mt.name] = user_results["conventional"][i]

if self.repetition_eval:
for i, mt in enumerate(self.ranking_metrics):
for i, mt in enumerate(ranking_metrics):
metric_avg_results["{}-rep".format(mt.name)] = avg_results[
"repetition"
][i]
Expand All @@ -394,7 +383,7 @@ def _eval(self, model, test_set, **kwargs):
][i]

if self.repetition_eval:
for i, mt in enumerate(self.ranking_metrics):
for i, mt in enumerate(ranking_metrics):
metric_avg_results["{}-expl".format(mt.name)] = avg_results[
"exploration"
][i]
Expand Down
46 changes: 36 additions & 10 deletions cornac/models/gp_top/recom_gp_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class GPTop(NextBasketRecommender):
use_personalized_popularity: boolean, optional, default: True
When False, no item frequency from history baskets are being used.
use_quantity: boolean, optional, default: False
When True, constructing item frequency based on its quantity (getting from extra_data).
The data must be in fmt 'UBITJson'.
References
----------
Ming Li, Sami Jullien, Mozhdeh Ariannezhad, and Maarten de Rijke. 2023.
Expand All @@ -42,31 +46,53 @@ class GPTop(NextBasketRecommender):
"""

def __init__(
self, name="GPTop", use_global_popularity=True, use_personalized_popularity=True
self,
name="GPTop",
use_global_popularity=True,
use_personalized_popularity=True,
use_quantity=False,
):
super().__init__(name=name, trainable=False)
self.use_global_popularity = use_global_popularity
self.use_personalized_popularity = use_personalized_popularity
self.use_quantity = use_quantity
self.item_freq = Counter()

def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
if self.use_global_popularity:
self.item_freq = Counter(self.train_set.uir_tuple[1])
if self.use_quantity:
self.item_freq = Counter()
for idx, iid in enumerate(self.train_set.uir_tuple[1]):
self.item_freq[iid] += self.train_set.extra_data[idx].get(
"quantity", 0
)
else:
self.item_freq = Counter(self.train_set.uir_tuple[1])
return self

def score(self, user_idx, history_baskets, **kwargs):
item_scores = np.ones(self.total_items)
item_scores = np.zeros(self.total_items, dtype=np.float32)
if self.use_global_popularity:
for iid, freq in self.item_freq.items():
item_scores[iid] = freq

if self.use_personalized_popularity:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])

max_item_freq = (
max(self.item_freq.values()) if len(self.item_freq) > 0 else 1
)
for iid, freq in self.item_freq.items():
item_scores[iid] = freq / max_item_freq

if self.use_personalized_popularity:
if self.use_quantity:
history_basket_bids = kwargs.get("history_basket_ids")
baskets = kwargs.get("baskets")
p_item_freq = Counter()
(_, item_ids, _) = kwargs.get("uir_tuple")
extra_data = kwargs.get("extra_data")
for bid in history_basket_bids:
ids = baskets[bid]
for idx in ids:
p_item_freq[item_ids[idx]] += extra_data[idx].get("quantity", 0)
else:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])
for iid, cnt in p_item_freq.most_common():
item_scores[iid] = max_item_freq + cnt
item_scores[iid] += cnt
return item_scores
3 changes: 2 additions & 1 deletion examples/gp_top_tafeng.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
)

models = [
GPTop(name="PTop", use_global_popularity=False),
GPTop(name="GTop", use_personalized_popularity=False),
GPTop(name="PTop", use_global_popularity=False),
GPTop(name="GPTop-quantity", use_quantity=True),
GPTop(),
]

Expand Down
Loading

0 comments on commit af62a20

Please sign in to comment.