Skip to content

Commit

Permalink
PreferredAI#641 Fix BiVAECF expected dense matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
dvquy13 committed Aug 12, 2024
1 parent 706ce7a commit cd6184f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
8 changes: 4 additions & 4 deletions cornac/models/bivaecf/bivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def learn(
i_count = 0
for i_ids in train_set.item_iter(batch_size, shuffle=False):
i_batch = tx[i_ids, :]
i_batch = i_batch.A
i_batch = i_batch.todense().A
i_batch = torch.tensor(i_batch, dtype=dtype, device=device)

# Reconstructed batch
Expand Down Expand Up @@ -228,7 +228,7 @@ def learn(
u_count = 0
for u_ids in train_set.user_iter(batch_size, shuffle=False):
u_batch = x[u_ids, :]
u_batch = u_batch.A
u_batch = u_batch.todense().A
u_batch = torch.tensor(u_batch, dtype=dtype, device=device)

# Reconstructed batch
Expand Down Expand Up @@ -259,7 +259,7 @@ def learn(
# infer mu_beta
for i_ids in train_set.item_iter(batch_size, shuffle=False):
i_batch = tx[i_ids, :]
i_batch = i_batch.A
i_batch = i_batch.todense().A
i_batch = torch.tensor(i_batch, dtype=dtype, device=device)

beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta)
Expand All @@ -268,7 +268,7 @@ def learn(
# infer mu_theta
for u_ids in train_set.user_iter(batch_size, shuffle=False):
u_batch = x[u_ids, :]
u_batch = u_batch.A
u_batch = u_batch.todense().A
u_batch = torch.tensor(u_batch, dtype=dtype, device=device)

theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta)
Expand Down
15 changes: 15 additions & 0 deletions tests/cornac/models/bivae/test_recommender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import unittest

from cornac.data import Dataset, Reader
from cornac.models import BiVAECF


class TestRecommender(unittest.TestCase):
def setUp(self):
self.data = Reader().read("./tests/data.txt")

def test_run(self):
bivae = BiVAECF(k=1, seed=123)
dataset = Dataset.from_uir(self.data)
# Assert runs without error
bivae.fit(dataset)

0 comments on commit cd6184f

Please sign in to comment.