Skip to content

Commit

Permalink
Format code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 authored and github-actions[bot] committed Sep 21, 2023
1 parent b8e9a63 commit fea514f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions cca_zoo/deep/_generative/_dccae.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def _decode(self, z, **kwargs):
return recon

def loss(self, batch, **kwargs):
z = self(batch['views'])
z = self(batch["views"])
recons = self._decode(z)
loss = dict()
loss["reconstruction"] = torch.stack(
[
self.recon_loss(x, recon, loss_type=self.recon_loss_type)
for x, recon in zip(batch['views'], recons)
for x, recon in zip(batch["views"], recons)
]
).sum()
loss["correlation"] = self.objective.loss(z)
Expand Down
8 changes: 4 additions & 4 deletions cca_zoo/deep/_generative/_dvcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,22 @@ def _decode(self, z, uncertainty=False, **kwargs):
return x

def loss(self, batch, **kwargs):
z = self(batch['views'], mle=False)
z = self(batch["views"], mle=False)
recons = self._decode(z)
loss = dict()
loss["reconstruction"] = torch.stack(
[
self.recon_loss(x, recon, loss_type=self.recon_loss_type)
for x, recon in zip(batch['views'], recons)
for x, recon in zip(batch["views"], recons)
]
).sum()
loss["kl shared"] = (
self.kl_loss(z["mu_shared"], z["logvar_shared"]) / batch['views'][0].numel()
self.kl_loss(z["mu_shared"], z["logvar_shared"]) / batch["views"][0].numel()
)
if "private" in z:
loss["kl private"] = torch.stack(
[
self.kl_loss(mu_, logvar_) / batch['views'][0].numel()
self.kl_loss(mu_, logvar_) / batch["views"][0].numel()
for mu_, logvar_ in zip(z["mu_private"], z["logvar_private"])
]
).sum()
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/deep/_generative/_splitae.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def _decode(self, z, **kwargs):
return recon

def loss(self, batch, **kwargs):
z = self(batch['views'])
z = self(batch["views"])
recons = self._decode(z)
loss = dict()
loss["reconstruction"] = torch.stack(
[
self.recon_loss(x, recon, loss_type=self.recon_loss_type)
for x, recon in zip(batch['views'], recons)
for x, recon in zip(batch["views"], recons)
]
).sum()
loss["objective"] = loss["reconstruction"]
Expand Down

0 comments on commit fea514f

Please sign in to comment.