-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss.py
25 lines (21 loc) · 835 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
class VAELoss(torch.nn.Module):
    """
    Calculates reconstruction loss and KL divergence loss for VAE.

    The two terms are returned separately (not summed), so the caller decides
    how to weight and combine them.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, x0, mu, logvar):
        """
        Args:
            x (torch.Tensor): reconstructed input tensor. Used as the BCE
                prediction (``input``), so its values must lie in [0, 1]
                (e.g. sigmoid outputs).
            x0 (torch.Tensor): original input tensor. Used as the BCE target;
                must have the same shape as ``x``.
            mu (torch.Tensor): latent space mu
            logvar (torch.Tensor): latent space log variance, same shape as
                ``mu``
        Returns:
            bce (torch.Tensor): binary cross entropy loss (VAE recon loss),
                averaged over every element of ``x``
            kld (torch.Tensor): KL divergence loss, summed over dim 1 and
                averaged over dim 0
        """
        # F.binary_cross_entropy takes (input, target): the reconstruction is
        # the prediction and the original data is the target. With the order
        # reversed, log() would be taken of the original data (clamped at -100
        # for exact 0/1 values), giving a wrong loss and gradient.
        bce = torch.nn.functional.binary_cross_entropy(x, x0, reduction='mean')
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)) per element:
        #   -0.5 * (1 + logvar - mu^2 - exp(logvar)).
        # Assumes mu/logvar are (batch, latent_dim) -- TODO confirm: the sum
        # runs over dim 1 and the mean over dim 0.
        # NOTE(review): bce is a per-element mean while kld is a per-sample
        # sum, so the two terms are on different scales; callers presumably
        # weight them accordingly.
        kld = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1),
            dim=0,
        )
        return bce, kld