diff --git a/README.md b/README.md
index 1bfbca59..3f190cd2 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,7 @@ $ tensorboard --logdir .
 | SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
 | VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937) | ![][31] | **N/A** |
 | DIP VAE ([Code][dipvae_code], [Config][dipvae_config]) |[Link](https://arxiv.org/abs/1711.00848) | ![][36] | ![][35] |
+| VNDAE ([Code][vndae_code], [Config][vndae_config]) |[Link](https://arxiv.org/pdf/2101.11353.pdf) | ![][37] | ![][38] |
@@ -189,6 +190,7 @@ Additionally, if you would like to contribute some models, please submit a PR.
 [infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py
 [vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
 [dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py
+[vndae_code]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/models/vnd_ae.py
 
 [vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml
 [cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
@@ -208,6 +210,7 @@ Additionally, if you would like to contribute some models, please submit a PR.
 [infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml
 [vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml
 [dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml
+[vndae_config]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/configs/vndae.yaml
 
 [1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
 [2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png
@@ -244,6 +247,8 @@ Additionally, if you would like to contribute some models, please submit a PR.
 [34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png
 [35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png
 [36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DIPVAE_83.png
+[37]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/assets/recons_VNDAE_1.png
+[38]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/assets/VNDAE_1.png
 
 [python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
 [python-url]: https://www.python.org/
diff --git a/assets/VNDAE_1.png b/assets/VNDAE_1.png
new file mode 100644
index 00000000..2f8dfdf9
Binary files /dev/null and b/assets/VNDAE_1.png differ
diff --git a/assets/recons_VNDAE_1.png b/assets/recons_VNDAE_1.png
new file mode 100644
index 00000000..87e4f2b3
Binary files /dev/null and b/assets/recons_VNDAE_1.png differ
diff --git a/configs/vndae.yaml b/configs/vndae.yaml
new file mode 100644
index 00000000..20028ea5
--- /dev/null
+++ b/configs/vndae.yaml
@@ -0,0 +1,28 @@
+model_params:
+  name: 'VNDAE'
+  in_channels: 3
+  latent_dim: 128
+
+
+data_params:
+  data_path: "data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
+exp_params:
+  LR: 0.005
+  weight_decay: 0
+  scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
+
+trainer_params:
+  gpus: [0]
+  max_epochs: 50
+
+logging_params:
+  save_dir: "logs/"
+  name: "VNDAE"
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
index 3f310f29..e0e59fee 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -21,7 +21,7 @@
 from .vq_vae import *
 from .betatc_vae import *
 from .dip_vae import *
-
+from .vnd_ae import *
 
 # Aliases
 VAE = VanillaVAE
@@ -35,6 +35,7 @@
               'SWAE':SWAE,
               'MIWAE':MIWAE,
               'VQVAE':VQVAE,
+              'VNDAE':VNDAE,
               'DFCVAE':DFCVAE,
               'DIPVAE':DIPVAE,
               'BetaVAE':BetaVAE,
diff --git a/models/vnd_ae.py b/models/vnd_ae.py
new file mode 100644
index 00000000..7d105468
--- /dev/null
+++ b/models/vnd_ae.py
@@ -0,0 +1,218 @@
+import torch
+from models import BaseVAE
+from torch import nn
+from torch.nn import functional as F
+from .types_ import *
+
+TAU = 1.
+PI = 0.95
+RSV_DIM = 1
+EPS = 1e-8
+SAMPLE_LEN = 1.
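+# TAU        -- temperature of the Gumbel-Softmax relaxation in reparameterize()
+# PI         -- prior keep-probability used to build the Downhill prior p(v)
+# RSV_DIM    -- number of reserved latent dimensions that always stay active
+# EPS        -- small constant guarding the log in the KLD term
+# SAMPLE_LEN -- fraction of latent dimensions kept active in sample()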
+
+class VNDAE(BaseVAE):
+
+
+    def __init__(self,
+                 in_channels: int,
+                 latent_dim: int,
+                 hidden_dims: List = None,
+                 **kwargs) -> None:
+        super(VNDAE, self).__init__()
+
+        self.latent_dim = latent_dim
+
+        modules = []
+        if hidden_dims is None:
+            hidden_dims = [32, 64, 128, 256, 512]
+
+        # Build Encoder
+        for h_dim in hidden_dims:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(in_channels, out_channels=h_dim,
+                              kernel_size=3, stride=2, padding=1),
+                    nn.BatchNorm2d(h_dim),
+                    nn.LeakyReLU())
+            )
+            in_channels = h_dim
+
+        self.encoder = nn.Sequential(*modules)
+        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
+        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)
+        self.fc_p_vnd = nn.Linear(hidden_dims[-1] * 4, latent_dim)
+
+        # Stick-breaking construction of the Downhill prior p(v)
+        Pi = nn.Parameter(PI * torch.ones(latent_dim - RSV_DIM), requires_grad=False)
+
+        self.ZERO = nn.Parameter(torch.tensor([0.]), requires_grad=False)
+        self.ONE = nn.Parameter(torch.tensor([1.]), requires_grad=False)
+        self.pv = nn.Parameter(torch.cat([self.ONE, torch.cumprod(Pi, dim=0)])
+                               * torch.cat([1 - Pi, self.ONE]), requires_grad=False)
+
+        # Build Decoder
+        modules = []
+
+        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
+
+        hidden_dims.reverse()
+
+        for i in range(len(hidden_dims) - 1):
+            modules.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(hidden_dims[i],
+                                       hidden_dims[i + 1],
+                                       kernel_size=3,
+                                       stride=2,
+                                       padding=1,
+                                       output_padding=1),
+                    nn.BatchNorm2d(hidden_dims[i + 1]),
+                    nn.LeakyReLU())
+            )
+
+        self.decoder = nn.Sequential(*modules)
+
+        self.final_layer = nn.Sequential(
+            nn.ConvTranspose2d(hidden_dims[-1],
+                               hidden_dims[-1],
+                               kernel_size=3,
+                               stride=2,
+                               padding=1,
+                               output_padding=1),
+            nn.BatchNorm2d(hidden_dims[-1]),
+            nn.LeakyReLU(),
+            nn.Conv2d(hidden_dims[-1], out_channels=3,
+                      kernel_size=3, padding=1),
+            nn.Tanh())
+
+    @staticmethod
+    def clip_beta(tensor, to=5.):
+        """
+        Clamps all values of the tensor to the range [-to, to].
+        """
+        return torch.clamp(tensor, -to, to)
+
+    def encode(self, input: Tensor) -> List[Tensor]:
+        """
+        Encodes the input by passing it through the encoder network
+        and returns the latent codes.
+        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
+        :return: (List[Tensor]) List of latent codes
+        """
+        result = self.encoder(input)
+        result = torch.flatten(result, start_dim=1)
+
+        # Split the result into the mu, log_var and p_vnd components
+        # of the latent mixture
+        mu = self.fc_mu(result)
+        log_var = self.fc_var(result)
+        p_vnd = self.fc_p_vnd(result)
+
+        return [mu, log_var, p_vnd]
+
+    def decode(self, z: Tensor) -> Tensor:
+        """
+        Maps the given latent codes onto the image space.
+        :param z: (Tensor) [B x D]
+        :return: (Tensor) [B x C x H x W]
+        """
+        result = self.decoder_input(z)
+        result = result.view(-1, 512, 2, 2)
+        result = self.decoder(result)
+        result = self.final_layer(result)
+        return result
+
+    def reparameterize(self, mu: Tensor, logvar: Tensor, p_vnd: Tensor) -> Tensor:
+        """
+        Reparameterization trick to sample from the mixture posterior,
+        shown in Eq. 28 in [https://arxiv.org/pdf/2101.11353.pdf].
+        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
+        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
+        :param p_vnd: (Tensor) Parameters of the Downhill distribution [B x D]
+        :return: (Tensor) [B x D]
+        """
+        std = torch.exp(0.5 * logvar)
+
+        # Generate samples from the Downhill distribution
+        eps = torch.randn_like(std)
+        beta = torch.sigmoid(self.clip_beta(p_vnd[:, RSV_DIM:]))
+        ONES = torch.ones_like(beta[:, 0:1])
+        qv = torch.cat([ONES, torch.cumprod(beta, dim=1)], dim=-1) * torch.cat([1 - beta, ONES], dim=-1)
+        s_vnd = F.gumbel_softmax(qv, tau=TAU, hard=True)
+
+        # Turn the one-hot width sample into a mask that keeps every
+        # dimension up to the sampled width and zeroes out the rest
+        cumsum = torch.cumsum(s_vnd, dim=1)
+        dif = cumsum - s_vnd
+        mask0 = dif[:, 1:]
+        mask1 = 1. - mask0
+        s_vnd = torch.cat([torch.ones_like(p_vnd[:, :RSV_DIM]), mask1], dim=-1)
+
+        return (eps * std + mu) * s_vnd
+
+    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
+        mu, log_var, p_vnd = self.encode(input)
+        z = self.reparameterize(mu, log_var, p_vnd)
+        return [self.decode(z), input, mu, log_var, p_vnd]
+
+    def loss_function(self,
+                      *args,
+                      **kwargs) -> dict:
+        """
+        Computes the VNDAE loss function shown in Eq. 29 in [https://arxiv.org/pdf/2101.11353.pdf].
+        :param args:
+        :param kwargs:
+        :return:
+        """
+        recons = args[0]
+        input = args[1]
+        mu = args[2]
+        log_var = args[3]
+        p_vnd = args[4]
+
+        beta = torch.sigmoid(self.clip_beta(p_vnd[:, RSV_DIM:]))
+        ONES = torch.ones_like(beta[:, 0:1])
+        qv = torch.cat([ONES, torch.cumprod(beta, dim=1)], dim=-1) * torch.cat([1 - beta, ONES], dim=-1)
+
+        # Probability that each latent dimension is active under q(v)
+        ZEROS = torch.zeros_like(beta[:, 0:1])
+        cum_sum = torch.cat([ZEROS, torch.cumsum(qv[:, 1:], dim=1)], dim=-1)[:, :-1]
+        coef1 = torch.sum(qv, dim=1, keepdim=True) - cum_sum
+        coef1 = torch.cat([torch.ones_like(p_vnd[:, :RSV_DIM]), coef1], dim=-1)
+
+        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
+        recons_loss = F.mse_loss(recons, input)
+
+        kld_gaussian = -0.5 * (1 + log_var - mu ** 2 - log_var.exp())
+        # Gaussian KLD per dimension, weighted by the activation probabilities
+        kld_weighted_gaussian = torch.diagonal(kld_gaussian.mm(coef1.t()), 0).mean()
+
+        # KLD between the Downhill posterior q(v) and the prior p(v)
+        log_frac = torch.log(qv / self.pv + EPS)
+        kld_vnd = torch.diagonal(qv.mm(log_frac.t()), 0).mean()
+
+        kld_loss = kld_vnd + kld_weighted_gaussian
+        loss = recons_loss + kld_weight * kld_loss
+        return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach()}
+
+    def sample(self,
+               num_samples: int,
+               current_device: int, **kwargs) -> Tensor:
+        """
+        Samples from the latent space, keeping the fixed fraction SAMPLE_LEN
+        of latent dimensions active, and maps the samples to image space.
+        :param num_samples: (Int) Number of samples
+        :param current_device: (Int) Device to run the model
+        :return: (Tensor)
+        """
+        z = torch.randn(num_samples,
+                        self.latent_dim)
+
+        z = torch.cat([z[:, :int(SAMPLE_LEN * self.latent_dim)],
+                       torch.zeros_like(z[:, :int((1 - SAMPLE_LEN) * self.latent_dim)])], dim=-1)
+        z = z.to(current_device)
+
+        samples = self.decode(z)
+        return samples
+
+    def generate(self, x: Tensor, **kwargs) -> Tensor:
+        """
+        Given an input image x, returns the reconstructed image.
+        :param x: (Tensor) [B x C x H x W]
+        :return: (Tensor) [B x C x H x W]
+        """
+        return self.forward(x)[0]
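
A quick smoke test for the patch above -- a minimal sketch, not part of the diff. It assumes the patched repo root is on PYTHONPATH so `models` resolves, and substitutes a random batch for real data; `M_N` reuses the `kld_weight` value from configs/vndae.yaml.

    import torch
    from models import VNDAE

    model = VNDAE(in_channels=3, latent_dim=128)
    x = torch.randn(16, 3, 64, 64)                  # stand-in for a 64x64 RGB batch

    # forward() returns [reconstruction, input, mu, log_var, p_vnd]
    recons, inp, mu, log_var, p_vnd = model(x)
    losses = model.loss_function(recons, inp, mu, log_var, p_vnd,
                                 M_N=0.00025)       # kld_weight from configs/vndae.yaml
    print(losses['loss'], losses['Reconstruction_Loss'], losses['KLD'])

    # With SAMPLE_LEN = 1. all latent dimensions stay active when sampling
    samples = model.sample(num_samples=4, current_device='cpu')
    print(samples.shape)                            # torch.Size([4, 3, 64, 64])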