
Official JAX implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States


test-time-training/ttt-lm-jax


Learning to (Learn at Test Time): RNNs with Expressive Hidden States

Paper | PyTorch Codebase | Setup | Replicating Experiments | Model Docs | Dataset Preparation | Inference Benchmark

Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning.

Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively.
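To make the idea concrete, here is a minimal, naive sketch of a TTT-Linear-style layer in JAX, written under stated assumptions rather than taken from this repository. The hidden state is a linear model `W`. Each token triggers one gradient step on a self-supervised reconstruction loss, and the output is the updated model applied to the token. The projection names `theta_k`, `theta_v`, `theta_q`, the zero initial state, and the fixed learning rate `lr` are all hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def ttt_linear(xs, theta_k, theta_v, theta_q, lr=0.1):
    """Naive, token-by-token TTT-Linear forward pass (illustrative sketch only).

    xs: (seq_len, d) input tokens.
    theta_k, theta_v, theta_q: (d, d) projections (hypothetical names).
    Returns per-token outputs (seq_len, d) and the final hidden state W (d, d).
    """
    d = xs.shape[-1]

    def inner_loss(W, x):
        # Self-supervised reconstruction loss: the hidden-state model W
        # should map the "key" view of x onto its "value" view.
        k = x @ theta_k
        v = x @ theta_v
        return 0.5 * jnp.sum((k @ W - v) ** 2)

    def step(W, x):
        # Update rule: one gradient step on the self-supervised loss,
        # taken on the current token, including at test time.
        W = W - lr * jax.grad(inner_loss)(W, x)
        # Output rule: apply the freshly updated model to the "query" view.
        z = (x @ theta_q) @ W
        return W, z

    W0 = jnp.zeros((d, d), dtype=xs.dtype)  # assumption: zero initial hidden state
    W_final, zs = jax.lax.scan(step, W0, xs)
    return zs, W_final
```

Each step costs O(d²) regardless of position, so the whole pass is linear in sequence length. This loop only illustrates the update rule and likely differs from how the repository implements its layers; a TTT-MLP variant would replace the linear map `k @ W` with a two-layer MLP.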

Setup

This codebase is implemented in JAX and has been tested on both GPUs and Cloud TPU VMs with Python 3.11.

For a PyTorch model definition, please refer to our PyTorch codebase. For inference kernels, or to replicate the speed benchmarks from our paper, please see our kernel implementations.

Environment Installation

To set up and run our code on a local GPU machine, we highly recommend using Anaconda to install the Python dependencies. Install the GPU requirements with:

cd requirements
pip install -r gpu_requirements.txt

For TPUs, please refer to Google Cloud's documentation for guidance on creating Cloud TPU VMs. Then run:

cd requirements
pip install -r tpu_requirements.txt
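After installing either set of requirements, a quick way to confirm that JAX can see your accelerators is a short script like the one below. It is not part of this repository; it relies only on the public `jax.default_backend` and `jax.devices` functions.

```python
import jax


def describe_devices():
    """Return the active JAX backend name and the list of visible devices."""
    backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
    devices = jax.devices()
    return backend, devices


if __name__ == "__main__":
    backend, devices = describe_devices()
    print(f"backend={backend}, device_count={len(devices)}")
    for device in devices:
        print(device)
```

If this reports `cpu` on a GPU or TPU machine, the accelerator build of JAX was most likely not installed correctly.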

WandB Login

We use WandB to log training metrics and TTT statistics. After installing the requirements, log in to WandB using:

wandb login

Dataset Download

Our Llama-2-tokenized datasets are available for download from Google Cloud Storage buckets:

gsutil -m cp -r gs://llama-2-pile/* llama-2-pile/
gsutil -m cp -r gs://llama-2-books3/* llama-2-books3/

Once the download finishes, set the dataset_path flag in train.py to the directory that contains the tokenizer_name-meta-llama folder so that the dataloader can locate the data.

Alternatively, to tokenize the datasets yourself, refer to the Dataset Preparation docs.
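If you are unsure which directory to pass as dataset_path, a small helper can search a download location for the tokenizer_name-meta-llama folder. The function `find_dataset_path` below is hypothetical and not part of the repository; only the folder name comes from the instructions above.

```python
from pathlib import Path
from typing import Optional

# Folder name from the README; the helper itself is hypothetical.
MARKER = "tokenizer_name-meta-llama"


def find_dataset_path(root: str) -> Optional[Path]:
    """Return the first directory under `root` (including `root` itself)
    that directly contains a folder named MARKER, or None if there is none."""
    root_path = Path(root)
    candidates = [root_path, *sorted(p for p in root_path.rglob("*") if p.is_dir())]
    for candidate in candidates:
        if (candidate / MARKER).is_dir():
            return candidate
    return None
```

For example, `find_dataset_path("llama-2-pile")` should return `llama-2-pile` itself when the bucket contents were copied there as shown above.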

Replicating Experiments

We provide a script for each experiment in our paper in the scripts folder. After specifying the experiment name and directory, choose the desired context length and divide 0.5 million tokens by it to obtain the appropriate batch size (the number of sequences per batch).
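The batch-size arithmetic can be sketched as follows. This assumes "0.5 million" means 2^19 = 524,288 tokens, so that power-of-two context lengths divide it evenly; check the values in the provided scripts before relying on this.

```python
TOKENS_PER_BATCH = 2 ** 19  # assumption: "0.5 million" read as 524,288 tokens


def batch_size_for(context_length: int, tokens_per_batch: int = TOKENS_PER_BATCH) -> int:
    """Number of sequences per batch so that each batch holds ~0.5M tokens."""
    if context_length <= 0 or tokens_per_batch % context_length != 0:
        raise ValueError(
            f"context length {context_length} does not evenly divide {tokens_per_batch} tokens"
        )
    return tokens_per_batch // context_length


for ctx in (2048, 8192, 32768):
    print(f"context length {ctx:>6} -> batch size {batch_size_for(ctx)}")
```

This prints batch sizes of 256, 64, and 16 for context lengths of 2048, 8192, and 32768, respectively.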

Depending on the model size, you may need to modify the mesh_dim to introduce model sharding. See the model docs for additional information on the training configuration.
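As an illustration of how a mesh specification might map onto devices, the sketch below resolves a comma-separated mesh string into a concrete shape. It assumes, without confirmation from this README, that mesh_dim takes the form "dp,fsdp,mp" with at most one -1 entry meaning "use all remaining devices"; `resolve_mesh_shape` is a hypothetical helper, so consult the model docs for the actual format.

```python
import math


def resolve_mesh_shape(mesh_dim: str, num_devices: int) -> tuple:
    """Turn a string like "1,-1,1" into a concrete mesh shape for num_devices.

    Hypothetical helper: assumes at most one -1 entry, which absorbs
    whatever devices are left after the explicit axis sizes.
    """
    dims = [int(part) for part in mesh_dim.split(",")]
    if dims.count(-1) > 1:
        raise ValueError("at most one mesh axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split by {known}")
        dims[dims.index(-1)] = num_devices // known
    if math.prod(dims) != num_devices:
        raise ValueError(f"mesh {dims} does not use exactly {num_devices} devices")
    return tuple(dims)
```

Under these assumptions, "1,-1,1" on 8 devices yields (1, 8, 1), while "1,-1,4" shards the model four ways and yields (1, 2, 4). A shape like this could then be passed to `jax.experimental.mesh_utils.create_device_mesh` and wrapped in a `jax.sharding.Mesh` with axis names of your choosing.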

Credits
