Skip to content

Commit

Permalink
fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
hieuddo committed Nov 22, 2023
1 parent e510e98 commit 8846778
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions cornac/models/mf/backend_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def __init__(

def forward(self, uids, iids):
ues = self.u_factors(uids)
uis = self.i_factors(iids)
ies = self.i_factors(iids)

preds = (self.dropout(ues) * self.dropout(uis)).sum(dim=1, keepdim=True)
preds = (self.dropout(ues) * self.dropout(ies)).sum(dim=1, keepdim=True)
if self.use_bias:
preds += self.u_biases(uids) + self.i_biases(iids) + self.global_mean

Expand All @@ -85,17 +85,13 @@ def learn(
):
model = model.to(device)
criteria = nn.MSELoss(reduction="sum")
optimizer = OPTIMIZER_DICT[optimizer](
params=model.parameters(), lr=learn_rate, weight_decay=reg
)
optimizer = OPTIMIZER_DICT[optimizer](params=model.parameters(), lr=learn_rate, weight_decay=reg)

progress_bar = trange(1, n_epochs + 1, disable=not verbose)
for _ in progress_bar:
sum_loss = 0.0
count = 0
for batch_id, (u_batch, i_batch, r_batch) in enumerate(
train_set.uir_iter(batch_size, shuffle=True)
):
for batch_id, (u_batch, i_batch, r_batch) in enumerate(train_set.uir_iter(batch_size, shuffle=True)):
u_batch = torch.from_numpy(u_batch).to(device)
i_batch = torch.from_numpy(i_batch).to(device)
r_batch = torch.tensor(r_batch, dtype=torch.float).to(device)
Expand Down

0 comments on commit 8846778

Please sign in to comment.