This is an implementation of OpenAI's DALL-E 2 [Link] [Paper] in PyTorch. It is suitable for simple text-to-image generation tasks.
Generated samples on the CIFAR-10 dataset:
Generated samples on a custom geometric shapes dataset:
![image](https://private-user-images.githubusercontent.com/23311201/248138240-d608873e-36cc-4342-93cb-c5e21fe4c03b.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkxNjA2MzgsIm5iZiI6MTczOTE2MDMzOCwicGF0aCI6Ii8yMzMxMTIwMS8yNDgxMzgyNDAtZDYwODg3M2UtMzZjYy00MzQyLTkzY2ItYzVlMjFmZTRjMDNiLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTAlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjEwVDA0MDUzOFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWY1ZDY1NjUyZTAyZmJiOGVjMjE3NGRlZWRmNDZiOTgzNTJiOGE4ZjdhMThhMDA3MDU5MzU1N2Y2NDA2MTU2OGMmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.U8rqmPJMoW-rT6Sow8f8HlwoMByvjp3h9CO6ifm2xsU)
The full pipeline consists of three models: CLIP [Paper], the DALL-E 2 prior, and the DALL-E 2 decoder.
CLIP is a zero-shot model that learns a shared, multimodal latent representation of text captions and images. Unlike standard image classification models, which combine a feature extraction network with a final linear classification layer, CLIP trains an image encoder and a text encoder that map images and captions to paired embeddings in the shared latent space.
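To illustrate the idea (not the exact architecture used in this repo), here is a minimal sketch of the contrastive objective with stand-in encoders: matching image/caption pairs are pulled together in the shared embedding space, while mismatched pairs are pushed apart.

```python
import torch
import torch.nn.functional as F

batch = 8
images = torch.randn(batch, 3, 32, 32)        # toy batch of images
tokens = torch.randint(0, 1000, (batch, 16))  # toy batch of tokenised captions

# Stand-in encoders; the repo's CLIP uses its own image and text encoders.
image_encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 512))
token_embedding = torch.nn.Embedding(1000, 512)

image_emb = F.normalize(image_encoder(images), dim=-1)               # (batch, 512)
text_emb = F.normalize(token_embedding(tokens).mean(dim=1), dim=-1)  # (batch, 512)

# Symmetric contrastive loss: the i-th image should match the i-th caption.
logits = image_emb @ text_emb.t() / 0.07  # temperature-scaled cosine similarities
targets = torch.arange(batch)
loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
```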
To train DALL-E 2, you need to train CLIP first. To train CLIP, run
```
python clip/train.py
```
You have to specify the dataset path and the path where the final model is saved in `model_config.yml`.
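If you are unsure which fields the training scripts expect, one way to check is to load the config and inspect it (this assumes `model_config.yml` is standard YAML and that PyYAML is installed):

```python
import yaml

# Load and print the config to see which dataset/model paths need to be set.
with open("model_config.yml") as f:
    config = yaml.safe_load(f)
print(config)
```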
The prior generates the CLIP image embedding based on the text caption.
To train the prior, run
```
python dalle2/train_prior.py
```
As with CLIP, you have to specify the dataset path and the model saving path in `model_config.yml`.
The DALL-E 2 decoder is used to generate images conditioned on CLIP image embeddings and text captions.
To train the decoder, run
```
python dalle2/train_decoder.py
```
Do not forget to specify the paths in `model_config.yml`.
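Sampling uses classifier-free guidance (the `cf_guidance_scale` argument in the example below). As a rough sketch, assuming a noise-prediction (ε) parameterisation, one denoising step combines the conditional and unconditional predictions like this:

```python
import torch

def guided_noise(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    """Classifier-free guidance: extrapolate from the unconditional prediction
    towards the conditional one. scale = 1 recovers the conditional prediction;
    larger values follow the text caption more strictly at the cost of diversity."""
    return eps_uncond + scale * (eps_cond - eps_uncond)
```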
The example below shows how to sample images from text captions:
```python
import torch
# Import CLIP, Prior, Decoder and DALLE2 from the corresponding modules in this repo.

# Initialise and load CLIP
clip = CLIP(...)
clip_path = ...
clip.load_state_dict(torch.load(clip_path))

# Initialise and load the prior
prior = Prior(...)
prior_path = ...
prior.load_state_dict(torch.load(prior_path))

# Initialise and load the decoder
decoder = Decoder(...)
decoder_path = ...
decoder.load_state_dict(torch.load(decoder_path))

# Initialise DALL-E 2
dalle2 = DALLE2(clip, prior, decoder)

# Set DALL-E 2 to evaluation mode
dalle2.val_mode()

# Sample an image from a text caption; cf_guidance_scale is the classifier-free guidance scale
image_size = (3, 32, 32)
image = dalle2(image_size, text="a small black square and a large gold pentagon", cf_guidance_scale=2)
```
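Assuming the returned `image` is a single tensor of shape `(3, 32, 32)` with pixel values roughly in `[0, 1]` (check the decoder's actual output range), it can be written to disk with torchvision:

```python
from torchvision.utils import save_image

# Clamp defensively in case the decoder output slightly leaves [0, 1].
save_image(image.clamp(0, 1), "sample.png")
```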