Skip to content

Commit

Permalink
Add option to scoring function based on quantity provided in extra_data
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Dec 7, 2023
1 parent 89ffe48 commit 92b7891
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
33 changes: 30 additions & 3 deletions cornac/models/gp_top/recom_gp_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class GPTop(NextBasketRecommender):
use_personalized_popularity: boolean, optional, default: True
When False, no item frequency from history baskets are being used.
use_quantity: boolean, optional, default: False
When True, constructing item frequency based on its quantity (getting from extra_data).
The data must be in fmt 'UBITJson'.
References
----------
Ming Li, Sami Jullien, Mozhdeh Ariannezhad, and Maarten de Rijke. 2023.
Expand All @@ -42,17 +46,29 @@ class GPTop(NextBasketRecommender):
"""

def __init__(
self, name="GPTop", use_global_popularity=True, use_personalized_popularity=True
self,
name="GPTop",
use_global_popularity=True,
use_personalized_popularity=True,
use_quantity=False,
):
super().__init__(name=name, trainable=False)
self.use_global_popularity = use_global_popularity
self.use_personalized_popularity = use_personalized_popularity
self.use_quantity = use_quantity
self.item_freq = Counter()

def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
if self.use_global_popularity:
self.item_freq = Counter(self.train_set.uir_tuple[1])
if self.use_quantity:
self.item_freq = Counter()
for idx, iid in enumerate(self.train_set.uir_tuple[1]):
self.item_freq[iid] += self.train_set.extra_data[idx].get(
"quantity", 0
)
else:
self.item_freq = Counter(self.train_set.uir_tuple[1])
return self

def score(self, user_idx, history_baskets, **kwargs):
Expand All @@ -65,7 +81,18 @@ def score(self, user_idx, history_baskets, **kwargs):
item_scores[iid] = freq / max_item_freq

if self.use_personalized_popularity:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])
if self.use_quantity:
history_basket_bids = kwargs.get("history_basket_ids")
baskets = kwargs.get("baskets")
p_item_freq = Counter()
(_, item_ids, _) = kwargs.get("uir_tuple")
extra_data = kwargs.get("extra_data")
for bid in history_basket_bids:
ids = baskets[bid]
for idx in ids:
p_item_freq[item_ids[idx]] += extra_data[idx].get("quantity", 0)
else:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])
for iid, cnt in p_item_freq.most_common():
item_scores[iid] += cnt
return item_scores
3 changes: 2 additions & 1 deletion examples/gp_top_tafeng.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
)

models = [
GPTop(name="PTop", use_global_popularity=False),
GPTop(name="GTop", use_personalized_popularity=False),
GPTop(name="PTop", use_global_popularity=False),
GPTop(name="GPTop-quantity", use_quantity=True),
GPTop(),
]

Expand Down

0 comments on commit 92b7891

Please sign in to comment.