LoReTTa (NeurIPS 2023)

Official repository for LoReTTa ($\textbf{L}$inking m$\textbf{O}$dalities with a t$\textbf{R}$ansitive and commutativ$\textbf{E}$ pre-$\textbf{T}$raining s$\textbf{T}$r$\textbf{A}$tegy). [arXiv] [website]

While we regret that we cannot release the full code due to company policy, we do our best to guide researchers in implementing our work. We provide the dataset and pseudocode, and we point to implementations of related work.

Method

Imagine we have two datasets, one with paired image and text and one with paired text and audio. How do we train a multimodal model that also works with paired image and audio? Here we introduce commutative and transitive pre-training (see also pseudocode.py and the sketch after the list below):

  1. Train a model to generate images from text, text from images, text from audio, and audio from text.
  2. Given a paired sample (image, text), we use the text to generate audio as a pseudo data point.
  3. The generated audio, aligned with the text, is then used as conditioning to generate an image.
  4. The generated image is compared to the original image in (image, text) to enforce consistency.
  5. This is how we connect image and audio. The same works in the other direction, starting from (text, audio).
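The snippet below is a minimal sketch of one training step on a paired (image, text) sample. It assumes a shared decoder `model` and two hypothetical helpers, `next_token_loss` and `generate`, which correspond to the causal modeling and autoregressive decoding sections further down; it illustrates the idea rather than reproducing pseudocode.py exactly.

```python
# Minimal sketch of one LoReTTa training step on a paired (image, text) sample.
# `next_token_loss` and `generate` are hypothetical helpers standing in for
# standard causal-LM loss computation and autoregressive decoding.
import torch

def loretta_step(model, image_tokens, text_tokens, cls_audio):
    # Commutative modeling on the paired sample: image -> text and text -> image.
    loss_i2t = next_token_loss(model, context=image_tokens, target=text_tokens)
    loss_t2i = next_token_loss(model, context=text_tokens, target=image_tokens)

    # Transitive modeling: generate pseudo audio from the text (text -> audio) ...
    with torch.no_grad():
        pseudo_audio = generate(model, context=text_tokens, cls_token=cls_audio)

    # ... then condition on the pseudo audio to predict the original image
    # (audio -> image), which links image and audio without paired data.
    loss_a2i = next_token_loss(model, context=pseudo_audio, target=image_tokens)

    return loss_i2t + loss_t2i + loss_a2i
```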

Models

LoReTTa is a self-supervised learning framework that works with any modality-agnostic architecture. We choose the Transformer decoder for its simplicity and scalability. For the best performance, we recommend a modern implementation such as Llama or Mistral. We also enable FlashAttention-2 to speed up training and inference. Alternative sequence models such as Hyena or Mamba can also be used.
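As a rough illustration, a small Llama-style decoder can be instantiated with the Hugging Face transformers library. The hyperparameters below are placeholders, not the configuration used in the paper, and FlashAttention-2 additionally requires a compatible GPU and the flash-attn package.

```python
# Illustrative (not the paper's) configuration of a small Llama-style decoder.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=4096,               # size of the joint multimodal token vocabulary
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=8,
    max_position_embeddings=2048,  # maximum sequence length
)
model = LlamaForCausalLM(config)

# FlashAttention-2 can be requested when loading a pre-trained checkpoint, e.g.:
# model = LlamaForCausalLM.from_pretrained(path, attn_implementation="flash_attention_2")
```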

Tokenization

The input to the Transformer is a sequence of tokens, so we first need to tokenize our data. For images, we use image patches as tokens; for text, we use subwords as tokens; and so on. Since we are modeling the data in pixel space, we can either use the raw discretized values or the discrete codes of a pre-trained VQ-VAE. It is also possible to model the data in latent space to avoid using VQ-VAEs‼️ In svl_mnist.py, we show an example using the byte values as tokens.
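For instance, byte-value tokenization of a grayscale image (as in svl_mnist.py) could look roughly like this; the helper names are illustrative:

```python
# Illustrative byte-value tokenization of an 8-bit grayscale image.
import numpy as np

def image_to_tokens(image: np.ndarray) -> np.ndarray:
    """Flatten an 8-bit image into a sequence of integer tokens in [0, 255]."""
    assert image.dtype == np.uint8
    return image.reshape(-1).astype(np.int64)

def tokens_to_image(tokens: np.ndarray, shape=(28, 28)) -> np.ndarray:
    """Invert the tokenization, e.g. to visualize generated samples."""
    return tokens.reshape(shape).astype(np.uint8)
```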

Causal modeling

The core of LoReTTa is next-token prediction (also known as causal language modeling). It is currently one of the most powerful frameworks for generative pre-training due to its data efficiency, as training can be effectively parallelized using attention masks. During training, the input and target are shifted by one token, and an upper-triangular causal attention mask ensures that each position only attends to previous tokens when predicting the next one.
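A minimal sketch of this setup, assuming a decoder `model` that takes token ids plus an additive attention mask and returns per-position logits:

```python
# Sketch of causal language modeling: shifted input/target plus a causal mask.
import torch
import torch.nn.functional as F

def causal_lm_loss(model, tokens):
    """tokens: (batch, seq_len) integer ids; model is assumed to return
    (batch, seq_len, vocab_size) logits given ids and an additive mask."""
    inputs, targets = tokens[:, :-1], tokens[:, 1:]   # shift by one position
    seq_len = inputs.size(1)
    # -inf above the diagonal: position i can only attend to positions <= i.
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf")), diagonal=1
    )
    logits = model(inputs, attn_mask=causal_mask)
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
    )
```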

Multimodality

Since language modeling only predicts the next token given the previous tokens, these tokens can theoretically come from any modality and appear in any order. This idea is explored in DALLE and MMGPT to generate images from text and more. In a nutshell, these methods model the relation $A \rightarrow B$. We go one step further and model $B \rightarrow A$ as well. In fact, if the model has enough capacity, it can handle even more modality combinations, such as $B \rightarrow C$ and $C \rightarrow B$. To help the model better distinguish between modalities, we prepend a class token (or modality token) to each modality's token sequence.
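One possible way to assemble such a sequence, with assumed reserved ids for the modality tokens:

```python
# Sketch: build a training sequence [<cls_A>, A..., <cls_B>, B...] for A -> B.
import torch

CLS_IMAGE, CLS_TEXT, CLS_AUDIO = 0, 1, 2   # assumed reserved modality-token ids

def build_sequence(src_tokens, src_cls, tgt_tokens, tgt_cls):
    """Concatenate source and target tokens, each preceded by its modality token.
    Swapping the arguments yields the commutative direction B -> A."""
    return torch.cat([
        torch.tensor([src_cls]), src_tokens,
        torch.tensor([tgt_cls]), tgt_tokens,
    ])
```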

Autoregressive decoding

Given a data point $(A, B)$, how do we generate modality $C$ conditioned on $B$ for transitive modeling $B \rightarrow C \rightarrow A$? We use autoregressive decoding. The input to the model $f$ is the tokenized context $B = [b_0, \ldots, b_n]$ together with the class token of the target modality (in this case $c_0$). The model predicts the next token $f([B, c_0]) = c_1$. We do this iteratively and get $f([B, c_0, c_1]) = c_2$, $\ldots$, $f([B, c_0, c_1, \ldots, c_{m-1}]) = c_m$.
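A greedy variant of this procedure might look as follows; the model's call signature and the use of argmax instead of sampling are simplifying assumptions:

```python
# Sketch of greedy autoregressive decoding of modality C conditioned on context B.
import torch

@torch.no_grad()
def decode(model, context_b, cls_c, max_len):
    """context_b: 1D tensor of B's tokens; cls_c: class-token id of modality C.
    The model is assumed to return (1, len, vocab) logits for a batch of ids."""
    sequence = torch.cat([context_b, torch.tensor([cls_c])])    # [B, c_0]
    for _ in range(max_len):
        logits = model(sequence.unsqueeze(0))
        next_token = logits[0, -1].argmax()                     # greedy choice
        sequence = torch.cat([sequence, next_token.unsqueeze(0)])
    return sequence[len(context_b):]                            # [c_0, c_1, ..., c_m]
```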
