diff --git a/cca_zoo/deep/_generative/_dccae.py b/cca_zoo/deep/_generative/_dccae.py index 4c69b480..30bc19bc 100644 --- a/cca_zoo/deep/_generative/_dccae.py +++ b/cca_zoo/deep/_generative/_dccae.py @@ -70,13 +70,13 @@ def _decode(self, z, **kwargs): return recon def loss(self, batch, **kwargs): - z = self(batch['views']) + z = self(batch["views"]) recons = self._decode(z) loss = dict() loss["reconstruction"] = torch.stack( [ self.recon_loss(x, recon, loss_type=self.recon_loss_type) - for x, recon in zip(batch['views'], recons) + for x, recon in zip(batch["views"], recons) ] ).sum() loss["correlation"] = self.objective.loss(z) diff --git a/cca_zoo/deep/_generative/_dvcca.py b/cca_zoo/deep/_generative/_dvcca.py index ded05e8c..591ff241 100644 --- a/cca_zoo/deep/_generative/_dvcca.py +++ b/cca_zoo/deep/_generative/_dvcca.py @@ -114,22 +114,22 @@ def _decode(self, z, uncertainty=False, **kwargs): return x def loss(self, batch, **kwargs): - z = self(batch['views'], mle=False) + z = self(batch["views"], mle=False) recons = self._decode(z) loss = dict() loss["reconstruction"] = torch.stack( [ self.recon_loss(x, recon, loss_type=self.recon_loss_type) - for x, recon in zip(batch['views'], recons) + for x, recon in zip(batch["views"], recons) ] ).sum() loss["kl shared"] = ( - self.kl_loss(z["mu_shared"], z["logvar_shared"]) / batch['views'][0].numel() + self.kl_loss(z["mu_shared"], z["logvar_shared"]) / batch["views"][0].numel() ) if "private" in z: loss["kl private"] = torch.stack( [ - self.kl_loss(mu_, logvar_) / batch['views'][0].numel() + self.kl_loss(mu_, logvar_) / batch["views"][0].numel() for mu_, logvar_ in zip(z["mu_private"], z["logvar_private"]) ] ).sum() diff --git a/cca_zoo/deep/_generative/_splitae.py b/cca_zoo/deep/_generative/_splitae.py index 330ed2dd..d8d81d36 100644 --- a/cca_zoo/deep/_generative/_splitae.py +++ b/cca_zoo/deep/_generative/_splitae.py @@ -62,13 +62,13 @@ def _decode(self, z, **kwargs): return recon def loss(self, batch, **kwargs): - z = self(batch['views']) + z = self(batch["views"]) recons = self._decode(z) loss = dict() loss["reconstruction"] = torch.stack( [ self.recon_loss(x, recon, loss_type=self.recon_loss_type) - for x, recon in zip(batch['views'], recons) + for x, recon in zip(batch["views"], recons) ] ).sum() loss["objective"] = loss["reconstruction"]