This repository implements a VQVAE for MNIST and a colored version of MNIST, and follows up with a simple LSTM for generating digit images.
![VQVAE Video](https://private-user-images.githubusercontent.com/144267687/302552988-a411d732-8c99-41fb-b39c-dd2c3fbfa448.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1NzI4NjUsIm5iZiI6MTczOTU3MjU2NSwicGF0aCI6Ii8xNDQyNjc2ODcvMzAyNTUyOTg4LWE0MTFkNzMyLThjOTktNDFmYi1iMzljLWRkMmMzZmJmYTQ0OC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE0JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNFQyMjM2MDVaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT01MDJhMTgwNjE1NTEzNmQ5ODgxMzMxODk4NzllNTJkODQ0NTVmZTQzY2I3ZTBhNGE3YmQxZWFiNjg0Y2JhNjU3JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.86-qklX9Xa65cmf8kI5f2BPGocKSzCUL9biuCpisylA)
- Create a new conda environment with Python 3.8, then run the commands below:
git clone https://github.com/explainingai-code/VQVAE-Pytorch.git
cd VQVAE-Pytorch
pip install -r requirements.txt
- For running a simple VQVAE with minimal code to understand the basics:
python run_simple_vqvae.py
- For playing around with the VQVAE and training/inferencing the LSTM, use the commands below, passing the desired configuration file as the config argument:
  - `python -m tools.train_vqvae` for training the VQVAE
  - `python -m tools.infer_vqvae` for generating reconstructions and encoder outputs for LSTM training
  - `python -m tools.train_lstm` for training the minimal LSTM
  - `python -m tools.generate_images` for using the trained LSTM to generate some numbers
- Available configurations:
  - `config/vqvae_mnist.yaml` - VQVAE for training on black and white MNIST images
  - `config/vqvae_colored_mnist.yaml` - VQVAE with more embedding vectors for training on colored MNIST images
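If you want to inspect or tweak a configuration programmatically, a minimal sketch using PyYAML is shown below; the key mentioned in the comment is only an assumed example based on the outputs section further down, not necessarily a key in these files.

```python
# Minimal sketch: load one of the YAML configs listed above with PyYAML.
import yaml

with open('config/vqvae_mnist.yaml', 'r') as f:
    config = yaml.safe_load(f)

print(config)  # inspect keys such as the task_name used for output directories
```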
For setting up the dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation
Verify the data directory has the following structure:
    VQVAE-Pytorch/data/train/images/{0/1/.../9}
        *.png
    VQVAE-Pytorch/data/test/images/{0/1/.../9}
        *.png
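As a quick sanity check of that layout, here is a hedged sketch of a minimal dataset class that walks the same directory structure; it is an illustrative stand-in, not the dataset implementation used by this repository.

```python
# Illustrative stand-in: iterate over data/{train,test}/images/{0..9}/*.png.
import glob
import os

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T


class MnistFolderDataset(Dataset):
    def __init__(self, root='VQVAE-Pytorch/data/train/images'):
        # Collect (path, label) pairs from the 0..9 class sub-directories.
        self.samples = []
        for label in range(10):
            for path in glob.glob(os.path.join(root, str(label), '*.png')):
                self.samples.append((path, label))
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        return self.transform(Image.open(path)), label
```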
Outputs will be saved according to the configuration present in the yaml files. For every run, a directory named after the task_name key in the config will be created, and output_train_dir will be created inside it.
During training of the VQVAE the following output will be saved:
- Best model checkpoints (VQVAE and LSTM) in the task_name directory
During inference the following outputs will be saved:
- Reconstructions for a sample of the test set in task_name/output_train_dir/reconstruction.png
- Encoder outputs on the train set for LSTM training in task_name/output_train_dir/mnist_encodings.pkl
- LSTM generation output in task_name/output_train_dir/generation_results.png
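The saved encodings are what the LSTM is trained on. A hedged sketch of how they could be loaded and flattened into token sequences is below; the exact structure of mnist_encodings.pkl is an assumption here, so adjust the reshaping to match what you actually find in the file.

```python
# Illustrative only: load the saved encoder outputs and flatten them into
# 1-D token sequences for autoregressive LSTM training.
import pickle

import torch

with open('task_name/output_train_dir/mnist_encodings.pkl', 'rb') as f:
    encodings = pickle.load(f)

codes = torch.as_tensor(encodings)
# Assumption: a (num_images, H, W) grid of discrete codebook indices per image.
sequences = codes.reshape(codes.shape[0], -1)
print(sequences.shape)  # e.g. (num_images, H*W)
```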
Running run_simple_vqvae should be very quick (as it is a very simple model) and give you the reconstructions below (input on a black background and reconstruction on a white background).
![](https://private-user-images.githubusercontent.com/144267687/273195458-607fb5a8-b880-4af5-8ce0-5d7127aa66a7.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1NzI4NjUsIm5iZiI6MTczOTU3MjU2NSwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk1NDU4LTYwN2ZiNWE4LWI4ODAtNGFmNS04Y2UwLTVkNzEyN2FhNjZhNy5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE0JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNFQyMjM2MDVaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT04NWU2OTNmZjExNzhmZWNiMTUwM2YzNWQ2NTFlZTcwYmE4YzY5ZDQ0NWE0OTBjNjNlZWIxODZjYmM5N2Q5ZDdjJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.hZ28tBmHxxrqUwsHp1KblG42v5OdC2z64znrTH3KV20)
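For reference, the core idea behind such a model is the vector-quantization step: encoder outputs are snapped to their nearest codebook vectors, with a straight-through estimator so gradients still reach the encoder. Below is a self-contained sketch of that step; names, shapes, and default sizes are illustrative and not taken from the repository's code.

```python
# Minimal vector-quantization layer in the spirit of the VQVAE paper (illustrative).
import torch
import torch.nn as nn


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
        self.beta = beta  # commitment loss weight

    def forward(self, z_e):
        # z_e: encoder output of shape (B, C, H, W) with C == embedding_dim.
        b, c, h, w = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, c)            # (B*H*W, C)
        dist = torch.cdist(flat, self.codebook.weight)           # distances to all codes
        indices = dist.argmin(dim=1)                             # nearest-code lookup
        z_q = self.codebook(indices).reshape(b, h, w, c).permute(0, 3, 1, 2)
        # Codebook loss pulls codes toward encoder outputs; commitment loss does the reverse.
        loss = ((z_q - z_e.detach()) ** 2).mean() + self.beta * ((z_q.detach() - z_e) ** 2).mean()
        # Straight-through estimator: copy gradients from z_q back to z_e.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, indices.reshape(b, h, w), loss
```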
Running the default config VQVAE for MNIST should give you the reconstructions below for both versions:
![](https://private-user-images.githubusercontent.com/144267687/273195513-939f8f22-0145-467f-8cd6-4b6c6e6f315f.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1NzI4NjUsIm5iZiI6MTczOTU3MjU2NSwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk1NTEzLTkzOWY4ZjIyLTAxNDUtNDY3Zi04Y2Q2LTRiNmM2ZTZmMzE1Zi5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE0JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNFQyMjM2MDVaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT0wMWEzZjExOGE2NjY0YjllYTY1M2JhMjIyMzZiNzkwZGRjYzgzMGFmZjUyZjI5OTJiN2UzNWJjYmI0N2I4MThkJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.yuDwTcocNTLl0_Zef44GHmdmPBiaZxiq61OacOrb-qw)
![](https://private-user-images.githubusercontent.com/144267687/273195627-0e28286a-bc4c-44e3-a385-84d1ae99492c.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1NzI4NjUsIm5iZiI6MTczOTU3MjU2NSwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk1NjI3LTBlMjgyODZhLWJjNGMtNDRlMy1hMzg1LTg0ZDFhZTk5NDkyYy5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE0JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNFQyMjM2MDVaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1mMzYzNWZkZmZhOWM0ODQwYTZkODhhNWFhYTVlYzc3NTE5MzFjNGMxYjYyNTk4MzAyMDNiOTRmYTg4ZTIyZGM4JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.RVtbYiMbeRCdBpcL4UJqg-IBBffxaO9GTfD6g7fQCeA)
Sample generation output after just 10 epochs. Training the VQVAE and LSTM for longer and with more parameters (codebook size, codebook dimension, channels, LSTM hidden dimension, etc.) will give better results.
![](https://private-user-images.githubusercontent.com/144267687/273196494-688a6631-df34-4fde-9508-a05ae3c2ae91.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1NzI4NjUsIm5iZiI6MTczOTU3MjU2NSwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMTk2NDk0LTY4OGE2NjMxLWRmMzQtNGZkZS05NTA4LWEwNWFlM2MyYWU5MS5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE0JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNFQyMjM2MDVaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1jNDY5NTUxMWFhOWM3YjFlZTQzYzg1NmIyZjU5MWUyZjEyMmEyY2NlYTBkODNlNjY5MGY2YWYzN2ZkYjcxN2NiJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.8mrhfUFFLWwMcRwTZEbRnp662fWugVVVld2Pfue5tpM)
![](https://private-user-images.githubusercontent.com/144267687/273203035-187fa630-a7ef-4f0b-aef7-5c6b53019b38.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1NzI4NjUsIm5iZiI6MTczOTU3MjU2NSwicGF0aCI6Ii8xNDQyNjc2ODcvMjczMjAzMDM1LTE4N2ZhNjMwLWE3ZWYtNGYwYi1hZWY3LTVjNmI1MzAxOWIzOC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjE0JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxNFQyMjM2MDVaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02ZWU2ZTQ5YmE0NmU2ZTYzYjcyY2Q0ZDNiOTY1ZjlmNzFjZjgyNDg4MTAzNDk2OWM5MjVkNDc3MzFhNmM4NWU2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.8uFjHWENDrPfNyUSSsvpjjq56b2R9vfGl1K62a0wIjs)
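Conceptually, generation works by sampling a sequence of codebook indices from the trained LSTM and then decoding the resulting latent grid with the VQVAE decoder. The sketch below illustrates that loop; the model attributes (vqvae.codebook, vqvae.decoder), the LSTM call signature, and the 7x7 grid size are assumptions, not the repository's actual interfaces.

```python
# Illustrative generation loop: autoregressively sample code indices, then decode.
import torch


@torch.no_grad()
def generate_digit(lstm, vqvae, grid_h=7, grid_w=7, start_token=0, device='cpu'):
    tokens = [start_token]
    hidden = None
    for _ in range(grid_h * grid_w):
        inp = torch.tensor([[tokens[-1]]], device=device)   # (batch=1, seq=1) token
        logits, hidden = lstm(inp, hidden)                   # assumed signature
        probs = torch.softmax(logits[:, -1], dim=-1)
        tokens.append(torch.multinomial(probs, 1).item())    # sample next code index
    code_grid = torch.tensor(tokens[1:], device=device).reshape(1, grid_h, grid_w)
    z_q = vqvae.codebook(code_grid).permute(0, 3, 1, 2)      # indices -> codebook vectors
    return vqvae.decoder(z_q)                                # decoded image tensor
```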
@misc{oord2018neural,
title={Neural Discrete Representation Learning},
author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
year={2018},
eprint={1711.00937},
archivePrefix={arXiv},
primaryClass={cs.LG}
}