RobinKa/discrete-flow-matching-pytorch

Discrete Flow Matching implemented in PyTorch

An implementation of Discrete Flow Matching [1][2], a flow-matching generative model for discrete data such as text. The code is written in PyTorch.

[Images: step 0 of 128 (input), step 64 of 128, step 128 of 128 (output)]

How to run

Environment setup

  1. Install uv for package management, e.g. pip install uv
  2. Make sure Python 3.12 is installed: uv python install 3.12
  3. Install the dependencies: uv sync --group jupyter

Run python -m discrete_flow_matching_pytorch.train --config configs/conv-8.yaml to start training a text generation model, with metrics logged to wandb.

The sample notebook demonstrates the sampling process.

Note: Instead of using uv, you can also install the dependencies listed in pyproject.toml with pip.

Summary of discrete flow matching compared to continuous flow matching

  • During training, we mask out text tokens according to the timestep
  • The model is trained to predict the original unmasked tokens with cross entropy loss
  • During sampling, we gradually unmask the text, replacing masked positions with sampled tokens
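
The three points above can be sketched in a few lines of plain Python. This is a minimal illustration and not the repository's code. It assumes a linear schedule (a token stays unmasked with probability t, with t=0 fully masked and t=1 clean), and it uses hypothetical names throughout: MASK, corrupt, masked_cross_entropy, sample, and a predict callable that stands in for the trained model and returns one token-to-probability dict per position.

```python
import math
import random

MASK = "<mask>"  # hypothetical mask token, not taken from the repository


def corrupt(tokens, t, rng):
    """Training-time corruption: keep each token with probability t, else mask it.

    Assumed convention: t=0 is fully masked, t=1 is the clean text.
    """
    return [tok if rng.random() < t else MASK for tok in tokens]


def masked_cross_entropy(probs, targets, noisy):
    """Mean cross entropy of the original tokens, taken over masked positions.

    probs[i] is a dict mapping token -> predicted probability at position i.
    """
    losses = [
        -math.log(p[target])
        for p, target, x in zip(probs, targets, noisy)
        if x == MASK
    ]
    return sum(losses) / max(len(losses), 1)


def sample(predict, length, steps, rng):
    """Start fully masked and unmask gradually over `steps` steps.

    Under the assumed linear schedule, a position that is still masked at
    step k is unmasked with probability 1 / (steps - k), so the final step
    unmasks everything that is left.
    """
    x = [MASK] * length
    for k in range(steps):
        probs = predict(x)  # one token -> probability dict per position
        p_unmask = 1.0 / (steps - k)
        for i in range(length):
            if x[i] == MASK and rng.random() < p_unmask:
                tokens, weights = zip(*probs[i].items())
                x[i] = rng.choices(tokens, weights=weights)[0]
    return x


# Usage with a stand-in "model" that always predicts the target text.
rng = random.Random(0)
target = list("hello")


def oracle(x):
    return [{ch: 1.0} for ch in target]


noisy = corrupt(target, t=0.5, rng=rng)
loss = masked_cross_entropy(oracle(noisy), target, noisy)
generated = sample(oracle, length=len(target), steps=8, rng=rng)
```

In a real setup, predict would be a neural network that outputs logits over the vocabulary, and the loss would be computed on batched tensors. The per-step unmasking probability shown here follows from the linear schedule assumption, and other schedules would change it.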

References