Official repository for LoReTTa (NeurIPS 2023)
While we regret that we cannot release the full code due to company policy, we do our best to guide researchers in implementing our work. We provide the dataset, pseudocode, and point to implementations of related work.
Imagine we have two datasets, one with paired image and text, and one with paired text and audio. How do we train a multimodal model that also works with paired image and audio? Here we introduce commutative and transitive pre-training (see also pseudocode.py):
- Train a model to generate image from text, text from image, text from audio, and audio from text.
- Given a paired sample (image, text), we use the text to generate audio as a pseudo data point.
- The generated audio, aligned with the text, is then used as conditioning to generate an image.
- The generated image is compared to the original image in (image, text) to enforce consistency.
- This is how we connect image and audio. The same procedure works in the other direction, starting from a (text, audio) pair; see the sketch below.
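The sketch below shows one training step of this loop in PyTorch-style pseudocode (see pseudocode.py for the full version). The helpers `tokenize`, `generate`, and `lm_loss` are placeholders for the modality tokenizers, autoregressive sampling, and the next-token loss described further down; they are not part of a released API.

```python
import torch

def loretta_step(model, optimizer, image, text, tokenize, generate, lm_loss):
    """One commutative + transitive training step on a paired (image, text) sample.

    `tokenize`, `generate`, and `lm_loss` are placeholders: a modality-aware
    tokenizer, autoregressive sampling from the model, and the shifted
    next-token cross-entropy on the concatenated [source, target] sequence.
    """
    img_tokens = tokenize(image, modality="image")
    txt_tokens = tokenize(text, modality="text")

    # Commutative modeling: learn both directions within the observed pair.
    loss = lm_loss(model, source=txt_tokens, target=img_tokens)         # text  -> image
    loss = loss + lm_loss(model, source=img_tokens, target=txt_tokens)  # image -> text

    # Transitive modeling: bridge to the unseen (image, audio) pairing.
    with torch.no_grad():
        aud_tokens = generate(model, source=txt_tokens, target_modality="audio")  # text -> pseudo audio
    loss = loss + lm_loss(model, source=aud_tokens, target=img_tokens)             # pseudo audio -> image

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```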
LoReTTa is a self-supervised learning framework that works with any modality-agnostic architecture. We choose the Transformer decoder for its simplicity and scalability. For the best performance, we recommend using a modern implementation based on Llama or Mistral. We also enable FlashAttention-2 to speed up training and inference. Alternative sequence models such as Hyena or Mamba can also be used.
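For illustration, a small Llama-style decoder can be set up with Hugging Face transformers roughly as follows; the hyperparameters are placeholders, and the `attn_implementation` argument assumes a recent transformers release with the flash-attn package installed.

```python
from transformers import AutoModelForCausalLM, LlamaConfig

# Illustrative configuration; the vocabulary covers the tokens of all modalities.
config = LlamaConfig(
    vocab_size=16384,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    max_position_embeddings=2048,
)

# FlashAttention-2 needs the flash-attn package and a supported GPU;
# drop the argument to fall back to the default attention implementation.
model = AutoModelForCausalLM.from_config(config, attn_implementation="flash_attention_2")
```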
The input to the Transformer is a sequence of tokens, so we need to tokenize our data. For images, we use image patches as tokens; for text, we use subwords as tokens; and so on. Since we are modeling the data in pixel space, we can either use the raw discretized values or the discrete codes of a pre-trained VQ-VAE as tokens. It is also possible to model the data in latent space to avoid using VQ-VAEs.
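As a purely illustrative example of such a shared token space, each modality can be mapped into its own id range and marked with a modality token. The vocabulary sizes below are assumptions, and the subword tokenizer or VQ-VAE producing the ids is not shown.

```python
import torch

# Assumed vocabulary sizes for each modality's tokenizer / codebook.
TEXT_VOCAB = 8000    # subword tokenizer
IMAGE_CODES = 1024   # VQ-VAE codebook (or number of discretized pixel values)
AUDIO_CODES = 1024   # VQ-VAE codebook for audio

# Modality tokens placed after all content tokens.
BOS_TEXT = TEXT_VOCAB + IMAGE_CODES + AUDIO_CODES
BOS_IMAGE = BOS_TEXT + 1
BOS_AUDIO = BOS_TEXT + 2
VOCAB_SIZE = BOS_AUDIO + 1

def to_sequence(text_ids: torch.Tensor, image_ids: torch.Tensor) -> torch.Tensor:
    """Concatenate two tokenized modalities into one decoder input sequence.

    `text_ids` come from a subword tokenizer and `image_ids` from a VQ-VAE
    encoder or raw pixel discretization; both are 1D LongTensors.
    """
    image_tokens = image_ids + TEXT_VOCAB  # shift into the image id range
    return torch.cat([
        torch.tensor([BOS_TEXT]), text_ids,
        torch.tensor([BOS_IMAGE]), image_tokens,
    ])
```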
The core of LoReTTa is next token prediction (also known as causal language modeling). It is currently one of the most powerful frameworks for generative pre-training due to its data efficiency, as training can be effectively parallelized using attention masks. During training, the input and target are shifted by one token, and an upper-triangular causal attention mask ensures that only the previous tokens can be used to predict the next one.
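A minimal sketch of this objective is shown below, assuming a decoder `model(inputs)` that returns logits of shape `(batch, length, vocab)` and applies the causal mask internally (with a Hugging Face causal LM, passing `labels=tokens` computes the same loss).

```python
import torch
import torch.nn.functional as F

def causal_mask(length: int) -> torch.Tensor:
    """Upper-triangular mask that is True for future positions (to be masked out)."""
    return torch.triu(torch.ones(length, length), diagonal=1).bool()

def next_token_loss(model, tokens: torch.Tensor) -> torch.Tensor:
    """Next-token prediction on a (batch, length) tensor of token ids."""
    inputs, targets = tokens[:, :-1], tokens[:, 1:]  # shift input and target by one
    logits = model(inputs)                           # (batch, length - 1, vocab)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```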
Since language modeling only predicts the next token given the previous tokens, these tokens can theoretically come from any modality, in any order. This idea is explored in DALLE and MMGPT to generate images from text and more. In a nutshell, these methods model the relation between two modalities, for example text → image.
Given a data point