Merge pull request #69 from Lightricks/feature/staging-0.1.2
Update: Version 0.1.2
yoavhacohen authored Dec 19, 2024
2 parents 8965f34 + 6c9805b commit ea6d5d6
Showing 12 changed files with 778 additions and 149 deletions.
75 changes: 70 additions & 5 deletions README.md
@@ -7,7 +7,7 @@ This is the official repository for LTX-Video.
[Website](https://www.lightricks.com/ltxv) |
[Model](https://huggingface.co/Lightricks/LTX-Video) |
[Demo](https://fal.ai/models/fal-ai/ltx-video) |
[Paper (Soon)](https://github.com/Lightricks/LTX-Video)

</div>

@@ -20,7 +20,11 @@ This is the official repository for LTX-Video.
- [Installation](#installation)
- [Inference](#inference)
- [ComfyUI Integration](#comfyui-integration)
- [Diffusers Integration](#diffusers-integration)
- [Model User Guide](#model-user-guide)
- [Community Contribution](#community-contribution)
- [Training](#training)
- [Join Us!](#join-us)
- [Acknowledgement](#acknowledgement)

# Introduction
@@ -60,13 +64,13 @@ source env/bin/activate
python -m pip install -e .\[inference-script\]
```

Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/LTX-Video)

```python
-from huggingface_hub import snapshot_download
+from huggingface_hub import hf_hub_download

model_path = 'PATH' # The local directory to save downloaded checkpoint
-snapshot_download("Lightricks/LTX-Video", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
+hf_hub_download(repo_id="Lightricks/LTX-Video", filename="ltx-video-2b-v0.9.1.safetensors", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
```
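
The downloaded checkpoint is a single safetensors file that holds all model parts, which the code in this commit loads directly via `from_pretrained`. Below is a minimal sketch of that loading path; the model import paths are assumed from the `ltx_video` package layout, and `model_path` is the directory used in the snippet above:

```python
from pathlib import Path

# Import paths assumed from the ltx_video package layout in this repository.
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.schedulers.rf import RectifiedFlowScheduler

# Load all model parts from the single-file checkpoint downloaded above,
# as inference.py does in this commit.
ckpt_path = Path(model_path) / "ltx-video-2b-v0.9.1.safetensors"
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
transformer = Transformer3DModel.from_pretrained(ckpt_path)
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
```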

### Inference
@@ -113,7 +117,68 @@ When writing prompts, focus on detailed, chronological descriptions of actions a
* Guidance Scale: 3-3.5 is the recommended range
* Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed (see the sketch below)
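
Putting these values together, here is a minimal sketch of a pipeline call with the recommended settings. It assumes `pipeline` is an `LTXVideoPipeline` assembled as in `inference.py`; the prompt and dimensions are illustrative placeholders:

```python
# Minimal sketch: assumes `pipeline` is an assembled LTXVideoPipeline
# (see inference.py). Prompt and dimensions are placeholders.
images = pipeline(
    prompt="A woman with long brown hair walks through a sunlit forest, the camera tracking her from behind...",
    height=480,
    width=704,
    num_frames=121,
    num_inference_steps=40,  # 40+ favors quality, 20-30 favors speed
    guidance_scale=3.0,      # recommended range: 3-3.5
    output_type="pt",
).images
```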

-## More to come...
+## Community Contribution

### ComfyUI-LTXTricks 🛠️

A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.

- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
- **Features:**
- 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
- ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
- 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
- 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
- **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
- 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).


### LTX-VideoQ8 🎱

**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA ADA GPUs.

- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
- **Features:**
- 🚀 Up to 3X speed-up with no accuracy loss
- 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
- 🛠️ Fine-tune 2B transformer models with precalculated latents
- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)

### Your Contribution

...is welcome! If you have a project or tool that integrates with LTX-Video,
please let us know by opening an issue or pull request.

# Training

## Diffusers

Diffusers implemented [LoRA support](https://github.com/huggingface/diffusers/pull/10228)
for LTX-Video. More information and a training script for fine-tuning are available in
[finetrainers](https://github.com/a-r-r-o-w/finetrainers?tab=readme-ov-file#training).
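
At inference time, a LoRA produced this way can be loaded through Diffusers. A rough sketch, assuming a Diffusers release that ships `LTXPipeline` and using a placeholder LoRA path:

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Assumes a Diffusers release that includes LTXPipeline (added in late 2024).
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Placeholder path: a LoRA checkpoint produced by the finetrainers script.
pipe.load_lora_weights("path/to/ltxv-lora")

frames = pipe(
    prompt="A detailed, chronological description of actions and scene",
    width=704,
    height=480,
    num_frames=121,
    num_inference_steps=40,
).frames[0]
export_to_video(frames, "lora_sample.mp4", fps=24)
```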

## Diffusion-Pipe

An experimental training framework with pipeline parallelism, enabling fine-tuning of large models like **LTX-Video** across multiple GPUs.

- **Repository:** [Diffusion-Pipe](https://github.com/tdrussell/diffusion-pipe)
- **Features:**
- 🛠️ Full fine-tune support for LTX-Video using LoRA
- 📊 Useful metrics logged to Tensorboard
- 🔄 Training state checkpointing and resumption
- ⚡ Efficient pre-caching of latents and text embeddings for multi-GPU setups


# Join Us 🚀

Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?

At **Lightricks**, an AI-first company, we’re revolutionizing how visual content is created.

If you are passionate about AI, computer vision, and video generation, we would love to hear from you!

Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.

# Acknowledgement

139 changes: 85 additions & 54 deletions inference.py
@@ -1,5 +1,4 @@
 import argparse
-import json
 import os
 import random
 from datetime import datetime
@@ -8,7 +7,6 @@

 import imageio
 import numpy as np
-from safetensors import safe_open
 import torch
 import torch.nn.functional as F
 from PIL import Image
@@ -22,41 +20,18 @@
 from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
 from ltx_video.schedulers.rf import RectifiedFlowScheduler
 from ltx_video.utils.conditioning_method import ConditioningMethod
+from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

 MAX_HEIGHT = 720
 MAX_WIDTH = 1280
 MAX_NUM_FRAMES = 257


-def load_vae(vae_config, ckpt):
-    vae = CausalVideoAutoencoder.from_config(vae_config)
-    vae_state_dict = {
-        key.replace("vae.", ""): value
-        for key, value in ckpt.items()
-        if key.startswith("vae.")
-    }
-    vae.load_state_dict(vae_state_dict)
-    if torch.cuda.is_available():
-        vae = vae.cuda()
-    return vae.to(torch.bfloat16)
-
-
-def load_transformer(transformer_config, ckpt):
-    transformer = Transformer3DModel.from_config(transformer_config)
-    transformer_state_dict = {
-        key.replace("model.diffusion_model.", ""): value
-        for key, value in ckpt.items()
-        if key.startswith("model.diffusion_model.")
-    }
-    transformer.load_state_dict(transformer_state_dict, strict=True)
-    if torch.cuda.is_available():
-        transformer = transformer.cuda()
-    return transformer
-
-
-def load_scheduler(scheduler_config):
-    return RectifiedFlowScheduler.from_config(scheduler_config)
+def get_total_gpu_memory():
+    if torch.cuda.is_available():
+        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+        return total_memory
+    return None


 def load_image_to_tensor_with_resize_and_crop(
@@ -204,6 +179,30 @@ def main():
         default=3,
         help="Guidance scale for the pipeline",
     )
+    parser.add_argument(
+        "--stg_scale",
+        type=float,
+        default=1,
+        help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.",
+    )
+    parser.add_argument(
+        "--stg_rescale",
+        type=float,
+        default=0.7,
+        help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.",
+    )
+    parser.add_argument(
+        "--stg_mode",
+        type=str,
+        default="stg_a",
+        help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.",
+    )
+    parser.add_argument(
+        "--stg_skip_layers",
+        type=str,
+        default="19",
+        help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.",
+    )
     parser.add_argument(
         "--image_cond_noise_scale",
         type=float,
@@ -233,9 +232,24 @@
     )

     parser.add_argument(
-        "--bfloat16",
-        action="store_true",
-        help="Denoise in bfloat16",
+        "--precision",
+        choices=["bfloat16", "mixed_precision"],
+        default="bfloat16",
+        help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.",
     )

+    # VAE noise augmentation
+    parser.add_argument(
+        "--decode_timestep",
+        type=float,
+        default=0.05,
+        help="Timestep for decoding noise",
+    )
+    parser.add_argument(
+        "--decode_noise_scale",
+        type=float,
+        default=0.025,
+        help="Noise level for decoding noise",
+    )
+
     # Prompts
@@ -251,6 +265,12 @@
         help="Negative prompt for undesired features",
     )

+    parser.add_argument(
+        "--offload_to_cpu",
+        action="store_true",
+        help="Offloading unnecessary computations to CPU.",
+    )
+
     logger = logging.get_logger(__name__)

     args = parser.parse_args()
@@ -259,6 +279,8 @@

     seed_everething(args.seed)

+    offload_to_cpu = False if not args.offload_to_cpu else get_total_gpu_memory() < 30
+
     output_dir = (
         Path(args.output_path)
         if args.output_path
@@ -301,35 +323,36 @@
     else:
         media_items = None

-    # Paths for the separate mode directories
     ckpt_path = Path(args.ckpt_path)
-    ckpt = {}
-    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
-        metadata = f.metadata()
-        for k in f.keys():
-            ckpt[k] = f.get_tensor(k)
-
-    configs = json.loads(metadata["config"])
-    vae_config = configs["vae"]
-    transformer_config = configs["transformer"]
-    scheduler_config = configs["scheduler"]
-
-    # Load models
-    vae = load_vae(vae_config, ckpt)
-    transformer = load_transformer(transformer_config, ckpt)
-    scheduler = load_scheduler(scheduler_config)
-    patchifier = SymmetricPatchifier(patch_size=1)
+    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+    transformer = Transformer3DModel.from_pretrained(ckpt_path)
+    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

     text_encoder = T5EncoderModel.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
     )
-    if torch.cuda.is_available():
-        text_encoder = text_encoder.to("cuda")
+    patchifier = SymmetricPatchifier(patch_size=1)
     tokenizer = T5Tokenizer.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
     )

-    if args.bfloat16 and transformer.dtype != torch.bfloat16:
+    if torch.cuda.is_available():
+        transformer = transformer.cuda()
+        vae = vae.cuda()
+        text_encoder = text_encoder.cuda()
+
+    vae = vae.to(torch.bfloat16)
+    if args.precision == "bfloat16" and transformer.dtype != torch.bfloat16:
         transformer = transformer.to(torch.bfloat16)
+    text_encoder = text_encoder.to(torch.bfloat16)
+
+    # Set spatiotemporal guidance
+    skip_block_list = [int(x.strip()) for x in args.stg_skip_layers.split(",")]
+    skip_layer_strategy = (
+        SkipLayerStrategy.Attention
+        if args.stg_mode.lower() == "stg_a"
+        else SkipLayerStrategy.Residual
+    )

     # Use submodels for the pipeline
     submodel_dict = {
@@ -362,6 +385,11 @@ def main():
         num_inference_steps=args.num_inference_steps,
         num_images_per_prompt=args.num_images_per_prompt,
         guidance_scale=args.guidance_scale,
+        skip_layer_strategy=skip_layer_strategy,
+        skip_block_list=skip_block_list,
+        stg_scale=args.stg_scale,
+        do_rescaling=args.stg_rescale != 1,
+        rescaling_scale=args.stg_rescale,
         generator=generator,
         output_type="pt",
         callback_on_step_end=None,
@@ -378,7 +406,10 @@
             else ConditioningMethod.UNCONDITIONAL
         ),
         image_cond_noise_scale=args.image_cond_noise_scale,
-        mixed_precision=not args.bfloat16,
+        decode_timestep=args.decode_timestep,
+        decode_noise_scale=args.decode_noise_scale,
+        mixed_precision=(args.precision == "mixed_precision"),
+        offload_to_cpu=offload_to_cpu,
     ).images

     # Crop the padded images to the desired resolution and number of frames
