Skip to content

Commit

Permalink
Add logq for cross-entropy loss
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Dec 29, 2023
1 parent 9def663 commit 2c50190
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion cornac/models/gru4rec/recom_gru4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
from tqdm.auto import trange
from collections import Counter

from cornac.models.recommender import NextItemRecommender

Expand Down Expand Up @@ -68,6 +69,9 @@ class GRU4Rec(NextItemRecommender):
elu_param: float, optional, default: 0.5
Elu param for 'bpr-max' loss
logq: float, optional, default: 0,
LogQ correction to offset the sampling bias affecting 'cross-entropy' loss.
device: str, optional, default: 'cpu'
Set to 'cuda' for GPU support.
Expand All @@ -82,7 +86,7 @@ class GRU4Rec(NextItemRecommender):
References
----------
Hidasi, B., Karatzoglou, A., Baltrunas, L., & Tikk, D. (2015).
Hidasi, B., Karatzoglou, A., Baltrunas, L., & Tikk, D. (2015).
Session-based recommendations with recurrent neural networks.
arXiv preprint arXiv:1511.06939.
Expand All @@ -105,6 +109,7 @@ def __init__(
n_epochs=10,
bpreg=1.0,
elu_param=0.5,
logq=0.0,
device="cpu",
trainable=True,
verbose=False,
Expand All @@ -126,6 +131,7 @@ def __init__(
self.n_epochs = n_epochs
self.bpreg = bpreg
self.elu_param = elu_param
self.logq = logq
self.device = device
self.seed = seed
self.rng = get_rng(seed)
Expand All @@ -143,6 +149,10 @@ def _set_loss_function(self, loss):
def _xe_loss_with_softmax(self, O, Y, M):
import torch

if self.logq > 0:
O = O - self.logq * torch.log(
torch.cat([self.P0[Y[:M]], self.P0[Y[M:]] ** self.sample_alpha])
)
X = torch.exp(O - O.max(dim=1, keepdim=True)[0])
X = X / X.sum(dim=1, keepdim=True)
return -torch.sum(torch.log(torch.diag(X) + 1e-24))
Expand Down Expand Up @@ -194,6 +204,14 @@ def fit(self, train_set, val_set=None):

from .gru4rec import GRU4RecModel, IndexedAdagradM, io_iter

if self.logq and self.loss == "cross-entropy":
pop = Counter(self.train_set.uir_tuple[1])
self.P0 = torch.tensor(
[pop[iid] for (_, iid) in self.train_set.iid_map.items()],
dtype=torch.float32,
device=self.device,
)

self.model = GRU4RecModel(
self.total_items,
self.layers,
Expand Down Expand Up @@ -255,6 +273,7 @@ def fit(self, train_set, val_set=None):

def score(self, user_idx, history_items, **kwargs):
from .gru4rec import score

if len(history_items) > 0:
return score(self.model, self.layers, self.device, history_items)
return np.ones(self.total_items, dtype="float")

0 comments on commit 2c50190

Please sign in to comment.