
# Importance Weighted Auto-Encoders in PyTorch

This code reproduces the experiments in the Importance Weighted Auto-Encoders (IWAE) paper (2016) by Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. The implementation was tested on the MNIST dataset to replicate the results reported in the paper. With this repo you can train and test VAE and IWAE models with 1 or 2 stochastic layers under different configurations of K and M.
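
For orientation, here is a minimal, hedged sketch of the K-sample IWAE bound the paper optimizes. This is illustrative PyTorch, not the repository's model code; the `encoder`/`decoder` interfaces and tensor shapes are assumptions.

```python
# Illustrative sketch of the K-sample IWAE bound (Burda et al., 2016).
# Not the repository's code: encoder/decoder interfaces are assumed.
import math
import torch
import torch.nn.functional as F

def iwae_bound(x, encoder, decoder, k):
    """Per-example K-sample bound on log p(x).

    x: (batch, 784) binarized images.
    encoder(x) -> (mu, logvar) of q(z|x); decoder(z) -> Bernoulli logits of p(x|z).
    """
    mu, logvar = encoder(x)                       # (batch, z_dim)
    std = torch.exp(0.5 * logvar)

    # K latent samples per data point: (k, batch, z_dim)
    eps = torch.randn(k, *mu.shape, device=x.device)
    z = mu + std * eps

    # log p(x|z): Bernoulli reconstruction log-likelihood
    logits = decoder(z)                           # (k, batch, 784)
    log_px_z = -F.binary_cross_entropy_with_logits(
        logits, x.expand_as(logits), reduction='none').sum(-1)

    # log p(z) and log q(z|x) under diagonal Gaussians
    log_2pi = math.log(2 * math.pi)
    log_pz = -0.5 * (z ** 2 + log_2pi).sum(-1)
    log_qz = -0.5 * (((z - mu) / std) ** 2 + logvar + log_2pi).sum(-1)

    # log (1/K) sum_k w_k, with log w_k = log p(x|z_k) + log p(z_k) - log q(z_k|x)
    log_w = log_px_z + log_pz - log_qz            # (k, batch)
    return torch.logsumexp(log_w, dim=0) - math.log(k)
```

With K=1 this reduces to the standard VAE ELBO, which is why the VAE commands below use `--num_k 1`.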

## Prerequisites for running the code

### Dataset

Download the required dataset by running the following command.

```
python download_MNIST.py
```
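
As a hedged illustration only (the actual contents of `download_MNIST.py` are not shown here), one common way to obtain and binarize MNIST in PyTorch uses torchvision:

```python
# Illustrative alternative, not the repository's download_MNIST.py:
# fetch MNIST via torchvision and binarize the pixels stochastically.
import torch
from torchvision import datasets, transforms

binarize = transforms.Compose([
    transforms.ToTensor(),                                     # floats in [0, 1]
    transforms.Lambda(lambda x: torch.bernoulli(x).view(-1)),  # 0/1 values, flattened to 784
])

train_set = datasets.MNIST('data/', train=True, download=True, transform=binarize)
test_set = datasets.MNIST('data/', train=False, download=True, transform=binarize)
```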

### Python packages

```
pytorch==1.1.0
numpy==1.14.2
```

## Running the experiments

This code allows you to train, evaluate, and compare VAE and IWAE architectures on the MNIST dataset. To train and test the models, run the following commands.

### Training the original VAE

```
python main_train.py --model VAE --num_stochastic_layers 1 --num_m 1 --num_k 1
```

### Training IWAE with 2 stochastic layers

```
python main_train.py --model IWAE --num_stochastic_layers 2 --num_m 1 --num_k 5
```
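
The `--num_m` and `--num_k` flags correspond to the M and K of the IWAE(M,K)/MIWAE-style objective. As a hedged sketch (reusing the hypothetical `iwae_bound` helper above, not the repository's training loop), the two combine as an average of M independent K-sample bounds:

```python
# Illustrative sketch: the IWAE(M,K) training loss averages M independent
# K-sample IWAE bounds; M=K=1 recovers the plain VAE ELBO.
# Continues the hypothetical iwae_bound sketch above.
import torch

def iwae_mk_loss(x, encoder, decoder, m, k):
    bounds = torch.stack([iwae_bound(x, encoder, decoder, k) for _ in range(m)])  # (m, batch)
    return -bounds.mean()   # minimize the negative bound, averaged over M and the batch
```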

### Testing the original VAE

```
python main_test.py --model VAE --num_stochastic_layers 1 --num_m 1 --num_k 1 --epoch 4999
```

### Testing IWAE with 2 stochastic layers

```
python main_test.py --model IWAE --num_stochastic_layers 2 --num_m 1 --num_k 5 --epoch 4999
```

### Testing IWAE with 2 stochastic layers on log-likelihood

```
python main_test_k.py --model IWAE --num_stochastic_layers 2 --num_m 1 --num_k 5 --num_k_test 5000 --epoch 4999
```
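
`main_test_k.py` estimates the test log-likelihood with a large number of importance samples (here `--num_k_test 5000`). A minimal sketch of such an estimator, again built on the hypothetical `iwae_bound` helper above and chunked to keep memory bounded (the chunk size and loop structure are assumptions, not the repository's code):

```python
# Illustrative NLL estimator with k_test importance samples, evaluated in chunks.
# Combining per-chunk bounds this way is exact: the logsumexp over all log-weights
# equals the logsumexp over chunks of (chunk_bound + log(chunk)).
import math
import torch

@torch.no_grad()
def estimate_nll(x, encoder, decoder, k_test=5000, chunk=100):
    assert k_test % chunk == 0
    bounds = torch.stack([iwae_bound(x, encoder, decoder, chunk)
                          for _ in range(k_test // chunk)])    # (n_chunks, batch)
    log_px = torch.logsumexp(bounds, dim=0) - math.log(k_test // chunk)
    return -log_px.mean()   # negative log-likelihood averaged over the batch
```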

See main_train.py, main_test.py, and main_test_k.py for more command-line options.

## Experimental results of this repo on the binarized MNIST dataset

| Method | NLL (this repo) | NLL (IWAE paper) | NLL (MIWAE paper) |
| --- | --- | --- | --- |
| VAE or IWAE (M=K=1) | 86.28 | 86.76 | - |
| MIWAE(1,64) | 84.62 | - | 84.52 |
| MIWAE(4,16) | 83.81 | - | 84.56 |
| MIWAE(8,8) | 84.77 | - | 84.97 |
| MIWAE(16,4) | 85.01 | - | - |
| MIWAE(64,1) | 87.15 | - | 86.21 |

| Method | IWAE(M,K) loss (this repo) | IWAE(M,K) loss (MIWAE paper) |
| --- | --- | --- |
| VAE or IWAE (M=K=1) | 90.32 | - |
| MIWAE(1,64) | 86.21 | 86.11 |
| MIWAE(4,16) | 84.92 | 85.60 |
| MIWAE(8,8) | 85.82 | 85.69 |
| MIWAE(16,4) | 86.12 | - |
| MIWAE(64,1) | 87.81 | 86.69 |