diff --git a/README.md b/README.md index 5ac37e3..a1f1ecd 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This is the official repository for LTX-Video. [Website](https://www.lightricks.com/ltxv) | [Model](https://huggingface.co/Lightricks/LTX-Video) | [Demo](https://fal.ai/models/fal-ai/ltx-video) | -[Paper (Soon)](https://github.com/Lightricks/LTX-Video) +[Paper (Soon)](https://github.com/Lightricks/LTX-Video) @@ -20,7 +20,11 @@ This is the official repository for LTX-Video. - [Installation](#installation) - [Inference](#inference) - [ComfyUI Integration](#comfyui-integration) + - [Diffusers Integration](#diffusers-integration) - [Model User Guide](#model-user-guide) +- [Community Contribution](#community-contribution) +- [Training](#training) +- [Join Us!](#join-us) - [Acknowledgement](#acknowledgement) # Introduction @@ -60,13 +64,13 @@ source env/bin/activate python -m pip install -e .\[inference-script\] ``` -Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/LTX-Video) +Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/LTX-Video) ```python -from huggingface_hub import snapshot_download +from huggingface_hub import hf_hub_download model_path = 'PATH' # The local directory to save downloaded checkpoint -snapshot_download("Lightricks/LTX-Video", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model') +hf_hub_download(repo_id="Lightricks/LTX-Video", filename="ltx-video-2b-v0.9.1.safetensors", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model') ``` ### Inference @@ -113,7 +117,68 @@ When writing prompts, focus on detailed, chronological descriptions of actions a * Guidance Scale: 3-3.5 are the recommended values * Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed -## More to come... +## Community Contribution + +### ComfyUI-LTXTricks 🛠️ + +A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings. + +- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks) +- **Features:** + - 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json). + - ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json). + - 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json). + - 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json). + - ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json). + - 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame.
[Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json). + + +### LTX-VideoQ8 🎱 + +**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA ADA GPUs. + +- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video) +- **Features:** + - 🚀 Up to 3X speed-up with no accuracy loss + - 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM) + - 🛠️ Fine-tune 2B transformer models with precalculated latents +- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/) + +### Your Contribution + +...is welcome! If you have a project or tool that integrates with LTX-Video, +please let us know by opening an issue or pull request. + +# Training + +## Diffusers + +Diffusers implemented [LoRA support](https://github.com/huggingface/diffusers/pull/10228), +with a training script for fine-tuning. +More information and the training script are available in +[finetrainers](https://github.com/a-r-r-o-w/finetrainers?tab=readme-ov-file#training). + +## Diffusion-Pipe + +An experimental training framework with pipeline parallelism, enabling fine-tuning of large models like **LTX-Video** across multiple GPUs. + +- **Repository:** [Diffusion-Pipe](https://github.com/tdrussell/diffusion-pipe) +- **Features:** + - 🛠️ Full fine-tune support for LTX-Video using LoRA + - 📊 Useful metrics logged to TensorBoard + - 🔄 Training state checkpointing and resumption + - ⚡ Efficient pre-caching of latents and text embeddings for multi-GPU setups + + +# Join Us 🚀 + +Want to work on cutting-edge AI research and make a real impact on millions of users worldwide? + +At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created. + +If you are passionate about AI, computer vision, and video generation, we would love to hear from you! + +Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.
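To tie the README changes above to the code changes below: after downloading the single-file checkpoint shown in the Installation snippet, the refactored `inference.py` builds each component directly from that `.safetensors` file. The following is a minimal sketch of that loading flow, assuming the checkpoint filename from the README (`ltx-video-2b-v0.9.1.safetensors`) and a placeholder local directory; it merely condenses the `from_pretrained` calls introduced in this PR rather than describing an additional API.

```python
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.schedulers.rf import RectifiedFlowScheduler

model_dir = "PATH"  # placeholder: local directory for the downloaded checkpoint
ckpt_path = Path(
    hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename="ltx-video-2b-v0.9.1.safetensors",
        local_dir=model_dir,
        repo_type="model",
    )
)

# Each from_pretrained call reads the component's config from the safetensors
# metadata, matching the single-file branches added in this PR.
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to(torch.bfloat16)
transformer = Transformer3DModel.from_pretrained(ckpt_path)
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
```

The text encoder and tokenizer are still pulled from `PixArt-alpha/PixArt-XL-2-1024-MS`, as in the updated `main()` below.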
# Acknowledgement diff --git a/inference.py b/inference.py index 5e111cd..704ba6a 100644 --- a/inference.py +++ b/inference.py @@ -1,5 +1,4 @@ import argparse -import json import os import random from datetime import datetime @@ -8,7 +7,6 @@ import imageio import numpy as np -from safetensors import safe_open import torch import torch.nn.functional as F from PIL import Image @@ -22,41 +20,18 @@ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline from ltx_video.schedulers.rf import RectifiedFlowScheduler from ltx_video.utils.conditioning_method import ConditioningMethod - +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy MAX_HEIGHT = 720 MAX_WIDTH = 1280 MAX_NUM_FRAMES = 257 -def load_vae(vae_config, ckpt): - vae = CausalVideoAutoencoder.from_config(vae_config) - vae_state_dict = { - key.replace("vae.", ""): value - for key, value in ckpt.items() - if key.startswith("vae.") - } - vae.load_state_dict(vae_state_dict) - if torch.cuda.is_available(): - vae = vae.cuda() - return vae.to(torch.bfloat16) - - -def load_transformer(transformer_config, ckpt): - transformer = Transformer3DModel.from_config(transformer_config) - transformer_state_dict = { - key.replace("model.diffusion_model.", ""): value - for key, value in ckpt.items() - if key.startswith("model.diffusion_model.") - } - transformer.load_state_dict(transformer_state_dict, strict=True) +def get_total_gpu_memory(): if torch.cuda.is_available(): - transformer = transformer.cuda() - return transformer - - -def load_scheduler(scheduler_config): - return RectifiedFlowScheduler.from_config(scheduler_config) + total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + return total_memory + return None def load_image_to_tensor_with_resize_and_crop( @@ -204,6 +179,30 @@ def main(): default=3, help="Guidance scale for the pipeline", ) + parser.add_argument( + "--stg_scale", + type=float, + default=1, + help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.", + ) + parser.add_argument( + "--stg_rescale", + type=float, + default=0.7, + help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.", + ) + parser.add_argument( + "--stg_mode", + type=str, + default="stg_a", + help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.", + ) + parser.add_argument( + "--stg_skip_layers", + type=str, + default="19", + help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.", + ) parser.add_argument( "--image_cond_noise_scale", type=float, @@ -233,9 +232,24 @@ def main(): ) parser.add_argument( - "--bfloat16", - action="store_true", - help="Denoise in bfloat16", + "--precision", + choices=["bfloat16", "mixed_precision"], + default="bfloat16", + help="Sets the precision for the transformer and tokenizer. Default is bfloat16. 
If 'mixed_precision' is enabled, it moves to mixed-precision.", + ) + + # VAE noise augmentation + parser.add_argument( + "--decode_timestep", + type=float, + default=0.05, + help="Timestep for decoding noise", + ) + parser.add_argument( + "--decode_noise_scale", + type=float, + default=0.025, + help="Noise level for decoding noise", ) # Prompts @@ -251,6 +265,12 @@ def main(): help="Negative prompt for undesired features", ) + parser.add_argument( + "--offload_to_cpu", + action="store_true", + help="Offloading unnecessary computations to CPU.", + ) + logger = logging.get_logger(__name__) args = parser.parse_args() @@ -259,6 +279,8 @@ def main(): seed_everething(args.seed) + offload_to_cpu = False if not args.offload_to_cpu else get_total_gpu_memory() < 30 + output_dir = ( Path(args.output_path) if args.output_path @@ -301,35 +323,36 @@ def main(): else: media_items = None - # Paths for the separate mode directories ckpt_path = Path(args.ckpt_path) - ckpt = {} - with safe_open(ckpt_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - for k in f.keys(): - ckpt[k] = f.get_tensor(k) - - configs = json.loads(metadata["config"]) - vae_config = configs["vae"] - transformer_config = configs["transformer"] - scheduler_config = configs["scheduler"] - - # Load models - vae = load_vae(vae_config, ckpt) - transformer = load_transformer(transformer_config, ckpt) - scheduler = load_scheduler(scheduler_config) - patchifier = SymmetricPatchifier(patch_size=1) + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + transformer = Transformer3DModel.from_pretrained(ckpt_path) + scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path) + text_encoder = T5EncoderModel.from_pretrained( "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder" ) - if torch.cuda.is_available(): - text_encoder = text_encoder.to("cuda") + patchifier = SymmetricPatchifier(patch_size=1) tokenizer = T5Tokenizer.from_pretrained( "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer" ) - if args.bfloat16 and transformer.dtype != torch.bfloat16: + if torch.cuda.is_available(): + transformer = transformer.cuda() + vae = vae.cuda() + text_encoder = text_encoder.cuda() + + vae = vae.to(torch.bfloat16) + if args.precision == "bfloat16" and transformer.dtype != torch.bfloat16: transformer = transformer.to(torch.bfloat16) + text_encoder = text_encoder.to(torch.bfloat16) + + # Set spatiotemporal guidance + skip_block_list = [int(x.strip()) for x in args.stg_skip_layers.split(",")] + skip_layer_strategy = ( + SkipLayerStrategy.Attention + if args.stg_mode.lower() == "stg_a" + else SkipLayerStrategy.Residual + ) # Use submodels for the pipeline submodel_dict = { @@ -362,6 +385,11 @@ def main(): num_inference_steps=args.num_inference_steps, num_images_per_prompt=args.num_images_per_prompt, guidance_scale=args.guidance_scale, + skip_layer_strategy=skip_layer_strategy, + skip_block_list=skip_block_list, + stg_scale=args.stg_scale, + do_rescaling=args.stg_rescale != 1, + rescaling_scale=args.stg_rescale, generator=generator, output_type="pt", callback_on_step_end=None, @@ -378,7 +406,10 @@ def main(): else ConditioningMethod.UNCONDITIONAL ), image_cond_noise_scale=args.image_cond_noise_scale, - mixed_precision=not args.bfloat16, + decode_timestep=args.decode_timestep, + decode_noise_scale=args.decode_noise_scale, + mixed_precision=(args.precision == "mixed_precision"), + offload_to_cpu=offload_to_cpu, ).images # Crop the padded images to the desired resolution and number of frames diff --git 
a/ltx_video/models/autoencoders/causal_video_autoencoder.py b/ltx_video/models/autoencoders/causal_video_autoencoder.py index 80c5f2e..c059d4a 100644 --- a/ltx_video/models/autoencoders/causal_video_autoencoder.py +++ b/ltx_video/models/autoencoders/causal_video_autoencoder.py @@ -3,6 +3,7 @@ from functools import partial from types import SimpleNamespace from typing import Any, Mapping, Optional, Tuple, Union, List +from pathlib import Path import torch import numpy as np @@ -11,13 +12,20 @@ from diffusers.utils import logging import torch.nn.functional as F from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from safetensors import safe_open from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd from ltx_video.models.autoencoders.pixel_norm import PixelNorm from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper from ltx_video.models.transformers.attention import Attention +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + VAE_KEYS_RENAME_DICT, +) +PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -29,34 +37,85 @@ def from_pretrained( *args, **kwargs, ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - video_vae = cls.from_config(config) - video_vae.to(kwargs["torch_dtype"]) + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if ( + pretrained_model_name_or_path.is_dir() + and (pretrained_model_name_or_path / "autoencoder.pth").exists() + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( + std_of_means + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( + mean_of_means + ) - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) - video_vae.load_state_dict(ckpt_state_dict) + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) - video_vae.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", 
torch.zeros_like(data_dict["std-of-means"]) - ), + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." ) + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = ( + pretrained_model_name_or_path + / "vae" + / "diffusion_pytorch_model.safetensors" + ) + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str( + pretrained_model_name_or_path + ).endswith(".safetensors"): + state_dict = {} + with safe_open( + pretrained_model_name_or_path, framework="pt", device="cpu" + ) as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) return video_vae @staticmethod @@ -165,11 +224,16 @@ def to_json_string(self) -> str: return json.dumps(self.config.__dict__) def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - per_channel_statistics_prefix = "per_channel_statistics." + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = { + key.replace("vae.", ""): value + for key, value in state_dict.items() + if key.startswith("vae.") + } ckpt_state_dict = { key: value for key, value in state_dict.items() - if not key.startswith(per_channel_statistics_prefix) + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) } model_keys = set(name for name, _ in self.named_parameters()) @@ -195,9 +259,9 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): super().load_state_dict(converted_state_dict, strict=strict) data_dict = { - key.removeprefix(per_channel_statistics_prefix): value + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value for key, value in state_dict.items() - if key.startswith(per_channel_statistics_prefix) + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) } if len(data_dict) > 0: self.register_buffer("std_of_means", data_dict["std-of-means"]) @@ -577,7 +641,7 @@ def forward( self, sample: torch.FloatTensor, target_shape, - timesteps: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" assert target_shape is not None, "target_shape must be provided" @@ -597,14 +661,14 @@ def forward( if self.timestep_conditioning: assert ( - timesteps is not None - ), "should pass timesteps with timestep_conditioning=True" - scaled_timesteps = timesteps * self.timestep_scale_multiplier + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier for up_block in self.up_blocks: if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timesteps=scaled_timesteps + sample, causal=self.causal, timestep=scaled_timestep ) else: sample = checkpoint_fn(up_block)(sample, causal=self.causal) @@ -612,25 +676,25 @@ def forward( sample = 
self.conv_norm_out(sample) if self.timestep_conditioning: - embedded_timesteps = self.last_time_embedder( - timestep=scaled_timesteps.flatten(), + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), resolution=None, aspect_ratio=None, batch_size=sample.shape[0], hidden_dtype=sample.dtype, ) - embedded_timesteps = embedded_timesteps.view( - batch_size, embedded_timesteps.shape[-1], 1, 1, 1 + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 ) ada_values = self.last_scale_shift_table[ None, ..., None, None, None - ] + embedded_timesteps.reshape( + ] + embedded_timestep.reshape( batch_size, 2, -1, - embedded_timesteps.shape[-3], - embedded_timesteps.shape[-2], - embedded_timesteps.shape[-1], + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], ) shift, scale = ada_values.unbind(dim=1) sample = sample * (1 + scale) + shift @@ -737,16 +801,16 @@ def forward( self, hidden_states: torch.FloatTensor, causal: bool = True, - timesteps: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: timestep_embed = None if self.timestep_conditioning: assert ( - timesteps is not None - ), "should pass timesteps with timestep_conditioning=True" + timestep is not None + ), "should pass timestep with timestep_conditioning=True" batch_size = hidden_states.shape[0] timestep_embed = self.time_embedder( - timestep=timesteps.flatten(), + timestep=timestep.flatten(), resolution=None, aspect_ratio=None, batch_size=batch_size, @@ -759,7 +823,7 @@ def forward( if self.attention_blocks: for resnet, attention in zip(self.res_blocks, self.attention_blocks): hidden_states = resnet( - hidden_states, causal=causal, timesteps=timestep_embed + hidden_states, causal=causal, timestep=timestep_embed ) # Reshape the hidden states to be (batch_size, frames * height * width, channel) @@ -806,7 +870,7 @@ def forward( else: for resnet in self.res_blocks: hidden_states = resnet( - hidden_states, causal=causal, timesteps=timestep_embed + hidden_states, causal=causal, timestep=timestep_embed ) return hidden_states @@ -991,7 +1055,7 @@ def forward( self, input_tensor: torch.FloatTensor, causal: bool = True, - timesteps: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: hidden_states = input_tensor batch_size = hidden_states.shape[0] @@ -999,17 +1063,17 @@ def forward( hidden_states = self.norm1(hidden_states) if self.timestep_conditioning: assert ( - timesteps is not None - ), "should pass timesteps with timestep_conditioning=True" + timestep is not None + ), "should pass timestep with timestep_conditioning=True" ada_values = self.scale_shift_table[ None, ..., None, None, None - ] + timesteps.reshape( + ] + timestep.reshape( batch_size, 4, -1, - timesteps.shape[-3], - timesteps.shape[-2], - timesteps.shape[-1], + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], ) shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) @@ -1164,9 +1228,9 @@ def demo_video_autoencoder_forward_backward(): print(f"input shape={input_videos.shape}") print(f"latent shape={latent.shape}") - timesteps = torch.ones(input_videos.shape[0]) * 0.1 + timestep = torch.ones(input_videos.shape[0]) * 0.1 reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape, timesteps=timesteps + latent, target_shape=input_videos.shape, timestep=timestep ).sample print(f"reconstructed shape={reconstructed_videos.shape}") @@ -1175,7 
+1239,7 @@ def demo_video_autoencoder_forward_backward(): input_image = input_videos[:, :, :1, :, :] image_latent = video_autoencoder.encode(input_image).latent_dist.mode() _ = video_autoencoder.decode( - image_latent, target_shape=image_latent.shape, timesteps=timesteps + image_latent, target_shape=image_latent.shape, timestep=timestep ).sample # first_frame_latent = latent[:, :, :1, :, :] diff --git a/ltx_video/models/autoencoders/vae.py b/ltx_video/models/autoencoders/vae.py index a67ac7a..0bcab52 100644 --- a/ltx_video/models/autoencoders/vae.py +++ b/ltx_video/models/autoencoders/vae.py @@ -257,11 +257,11 @@ def _decode( self, z: torch.FloatTensor, target_shape=None, - timesteps: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, ) -> Union[DecoderOutput, torch.FloatTensor]: z = self.post_quant_conv(z) - if "timesteps" in self.decoder_params: - dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) else: dec = self.decoder(z, target_shape=target_shape) return dec @@ -271,7 +271,7 @@ def decode( z: torch.FloatTensor, return_dict: bool = True, target_shape=None, - timesteps: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, ) -> Union[DecoderOutput, torch.FloatTensor]: assert target_shape is not None, "target_shape must be provided for decoding" if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: @@ -304,7 +304,7 @@ def decode( decoded = ( self._hw_tiled_decode(z, target_shape) if self.use_hw_tiling - else self._decode(z, target_shape=target_shape, timesteps=timesteps) + else self._decode(z, target_shape=target_shape, timestep=timestep) ) if not return_dict: diff --git a/ltx_video/models/autoencoders/vae_encode.py b/ltx_video/models/autoencoders/vae_encode.py index 3b8b15b..f584ec0 100644 --- a/ltx_video/models/autoencoders/vae_encode.py +++ b/ltx_video/models/autoencoders/vae_encode.py @@ -96,6 +96,7 @@ def vae_decode( is_video: bool = True, split_size: int = 1, vae_per_channel_normalize=False, + timestep=None, ) -> Tensor: is_video_shaped = latents.dim() == 5 batch_size = latents.shape[0] @@ -111,12 +112,16 @@ def vae_decode( ) encode_bs = len(latents) // split_size image_batch = [ - _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize) + _run_decoder( + latent_batch, vae, is_video, vae_per_channel_normalize, timestep + ) for latent_batch in latents.split(encode_bs) ] images = torch.cat(image_batch, dim=0) else: - images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize) + images = _run_decoder( + latents, vae, is_video, vae_per_channel_normalize, timestep + ) if is_video_shaped and not isinstance( vae, (VideoAutoencoder, CausalVideoAutoencoder) @@ -126,12 +131,19 @@ def vae_decode( def _run_decoder( - latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False + latents: Tensor, + vae: AutoencoderKL, + is_video: bool, + vae_per_channel_normalize=False, + timestep=None, ) -> Tensor: if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): *_, fl, hl, wl = latents.shape temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep image = vae.decode( un_normalize_latents(latents, vae, vae_per_channel_normalize), return_dict=False, @@ -142,6 +154,7 @@ def _run_decoder( hl * spatial_scale, wl * spatial_scale, ), + 
**vae_decode_kwargs, )[0] else: image = vae.decode( diff --git a/ltx_video/models/transformers/attention.py b/ltx_video/models/transformers/attention.py index e31e8b9..96ed251 100644 --- a/ltx_video/models/transformers/attention.py +++ b/ltx_video/models/transformers/attention.py @@ -20,6 +20,8 @@ from einops import rearrange from torch import nn +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + try: from torch_xla.experimental.custom_kernel import flash_attention except ImportError: @@ -204,6 +206,8 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, ) -> torch.FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: @@ -253,6 +257,8 @@ def forward( encoder_hidden_states if self.only_cross_attention else None ), attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, **cross_attention_kwargs, ) if gate_msa is not None: @@ -647,6 +653,8 @@ def forward( freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, **cross_attention_kwargs, ) -> torch.Tensor: r""" @@ -659,6 +667,10 @@ def forward( The hidden states of the encoder. attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. **cross_attention_kwargs: Additional keyword arguments to pass along to the cross attention. 
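The new `skip_layer_mask` / `skip_layer_strategy` arguments documented above drive a per-sample blend inside the attention processor (next hunk). Below is a standalone sketch using toy tensors rather than the real `Attention` module, to make the `SkipLayerStrategy.Attention` case concrete; the mask shape follows the `reshape(batch_size, 1, 1)` used in the processor.

```python
import torch

from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy  # added in this PR


def blend_attention_output(attn_out, attn_in, mask, strategy):
    """Toy restatement of the skip-layer blend performed by the processor.

    attn_out: attention output, shape (batch, tokens, dim)
    attn_in:  hidden states that entered the attention call, same shape
    mask:     per-sample mask of shape (batch, 1, 1); 0 marks the perturbed
              samples used for spatiotemporal guidance, 1 everything else
    """
    if strategy == SkipLayerStrategy.Attention:
        # Perturbed samples bypass the attention result and keep their input.
        return attn_out * mask + attn_in * (1.0 - mask)
    # With SkipLayerStrategy.Residual the processor instead zeroes the residual
    # connection for the perturbed samples (see the residual hunk below).
    return attn_out


batch = 3  # e.g. [uncond, cond, cond with skipped layers]
mask = torch.tensor([1.0, 1.0, 0.0]).reshape(batch, 1, 1)
out = blend_attention_output(
    torch.randn(batch, 16, 64), torch.randn(batch, 16, 64), mask, SkipLayerStrategy.Attention
)
```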
@@ -690,6 +702,8 @@ def forward( freqs_cis=freqs_cis, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, **cross_attention_kwargs, ) @@ -924,6 +938,8 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, *args, **kwargs, ) -> torch.FloatTensor: @@ -949,6 +965,9 @@ def __call__( else encoder_hidden_states.shape ) + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): attention_mask = attn.prepare_attention_mask( attention_mask, sequence_length, batch_size @@ -1015,7 +1034,7 @@ def __call__( ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" # run the TPU kernel implemented in jax with pallas - hidden_states = flash_attention( + hidden_states_a = flash_attention( q=query, k=key, v=value, @@ -1024,7 +1043,7 @@ def __call__( sm_scale=attn.scale, ) else: - hidden_states = F.scaled_dot_product_attention( + hidden_states_a = F.scaled_dot_product_attention( query, key, value, @@ -1033,10 +1052,20 @@ def __call__( is_causal=False, ) - hidden_states = hidden_states.transpose(1, 2).reshape( + hidden_states_a = hidden_states_a.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) - hidden_states = hidden_states.to(query.dtype) + hidden_states_a = hidden_states_a.to(query.dtype) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Attention + ): + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -1047,9 +1076,20 @@ def __call__( hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) if attn.residual_connection: - hidden_states = hidden_states + residual + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor diff --git a/ltx_video/models/transformers/transformer3d.py b/ltx_video/models/transformers/transformer3d.py index dfab43d..afc98e7 100644 --- a/ltx_video/models/transformers/transformer3d.py +++ b/ltx_video/models/transformers/transformer3d.py @@ -1,7 +1,11 @@ # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Literal +from typing import Any, Dict, List, Optional, Literal, Union +import os +import json +import glob +from pathlib import Path import torch from diffusers.configuration_utils import ConfigMixin, register_to_config @@ -11,9 +15,19 @@ from diffusers.utils import BaseOutput, is_torch_version from diffusers.utils import logging from torch import nn +from safetensors import safe_open + from ltx_video.models.transformers.attention import BasicTransformerBlock 
from ltx_video.models.transformers.embeddings import get_3d_sincos_pos_embed +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + logger = logging.get_logger(__name__) @@ -166,6 +180,21 @@ def set_use_tpu_flash_attention(self): for block in self.transformer_blocks: block.set_use_tpu_flash_attention() + def create_skip_layer_mask( + self, + skip_block_list: List[int], + batch_size: int, + num_conds: int, + ptb_index: int, + ): + num_layers = len(self.transformer_blocks) + mask = torch.ones( + (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype + ) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + def initialize(self, embedding_std: float, mode: Literal["ltx_video", "legacy"]): def _basic_init(module): if isinstance(module, nn.Linear): @@ -305,6 +334,75 @@ def precompute_freqs_cis(self, indices_grid, spacing="exp"): sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_dir(): + config_path = pretrained_model_path / "transformer" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for transformer is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + state_dict = {} + ckpt_paths = ( + pretrained_model_path + / "transformer" + / "diffusion_pytorch_model*.safetensors" + ) + dict_list = glob.glob(str(ckpt_paths)) + for dict_path in dict_list: + part_dict = {} + with safe_open(dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + part_dict[k] = f.get_tensor(k) + state_dict.update(part_dict) + + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + state_dict[new_key] = state_dict.pop(key) + + transformer = cls.from_config(config) + transformer.load_state_dict(state_dict, strict=True) + elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + transformer_config = configs["transformer"] + transformer = Transformer3DModel.from_config(transformer_config) + transformer.load_state_dict(comfy_single_file_state_dict) + return transformer + def forward( self, hidden_states: torch.Tensor, @@ -315,6 +413,8 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, return_dict: bool = True, ): """ @@ -348,6 +448,11 @@ def forward( If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -413,6 +518,11 @@ def forward( batch_size, -1, embedded_timestep.shape[-1] ) + if skip_layer_mask is None: + skip_layer_mask = torch.ones( + len(self.transformer_blocks), batch_size, device=hidden_states.device + ) + # 2. Blocks if self.caption_projection is not None: batch_size = hidden_states.shape[0] @@ -421,7 +531,7 @@ def forward( batch_size, -1, hidden_states.shape[-1] ) - for block in self.transformer_blocks: + for block_idx, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -446,6 +556,8 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, + skip_layer_mask[block_idx], + skip_layer_strategy, **ckpt_kwargs, ) else: @@ -458,6 +570,8 @@ def custom_forward(*inputs): timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, + skip_layer_mask=skip_layer_mask[block_idx], + skip_layer_strategy=skip_layer_strategy, ) # 3. 
Output diff --git a/ltx_video/pipelines/pipeline_ltx_video.py b/ltx_video/pipelines/pipeline_ltx_video.py index f0e7832..19f191d 100644 --- a/ltx_video/pipelines/pipeline_ltx_video.py +++ b/ltx_video/pipelines/pipeline_ltx_video.py @@ -37,6 +37,7 @@ ) from ltx_video.schedulers.rf import TimestepShifter from ltx_video.utils.conditioning_method import ConditioningMethod +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -298,7 +299,7 @@ def encode_prompt( # See Section 3.1. of the paper. # FIXME: to be configured in config not hardecoded. Fix in separate PR with rest of config max_length = 128 # TPU supports only lengths multiple of 128 - + text_enc_device = next(self.text_encoder.parameters()).device if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( @@ -326,10 +327,11 @@ def encode_prompt( ) prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(text_enc_device) prompt_attention_mask = prompt_attention_mask.to(device) prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=prompt_attention_mask + text_input_ids.to(text_enc_device), attention_mask=prompt_attention_mask ) prompt_embeds = prompt_embeds[0] @@ -370,10 +372,12 @@ def encode_prompt( return_tensors="pt", ) negative_prompt_attention_mask = uncond_input.attention_mask - negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + text_enc_device + ) negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_enc_device), attention_mask=negative_prompt_attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] @@ -701,10 +705,12 @@ def prepare_latents( if latents is None: latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype + shape, generator=generator, device=generator.device, dtype=dtype ) elif latents_mask is not None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor( + shape, generator=generator, device=generator.device, dtype=dtype + ) latents = latents * latents_mask[..., None] + noise * ( 1 - latents_mask[..., None] ) @@ -768,6 +774,11 @@ def __call__( num_inference_steps: int = 20, timesteps: List[int] = None, guidance_scale: float = 4.5, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: List[int] = None, + stg_scale: float = 1.0, + do_rescaling: bool = True, + rescaling_scale: float = 0.7, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -781,7 +792,10 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, clean_caption: bool = True, media_items: Optional[torch.FloatTensor] = None, + decode_timestep: Union[List[float], float] = 0.0, + decode_noise_scale: Optional[List[float]] = None, mixed_precision: bool = False, + offload_to_cpu: bool = False, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -888,8 +902,23 @@ def __call__( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
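# The additions below introduce a second guidance term (STG). When classifier-free
# guidance and spatiotemporal guidance are both enabled, the prompt embeddings are
# stacked in the order
#   [unconditional | conditional | conditional with skipped transformer layers],
# and the latents and conditioning mask are repeated to match, so num_conds == 3;
# for a single prompt the perturbed branch is the last entry, which is the
# ptb_index of 2 passed to create_skip_layer_mask. After the transformer call the
# prediction is combined roughly as
#   noise_pred = uncond + guidance_scale * (text - uncond) + stg_scale * (text - text_perturb)
# and, when do_rescaling is set, rescaled so its standard deviation tracks the
# conditional branch:
#   factor = rescaling_scale * (text.std() / noise_pred.std()) + (1 - rescaling_scale)
#   noise_pred = noise_pred * factor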
do_classifier_free_guidance = guidance_scale > 1.0 + do_spatio_temporal_guidance = stg_scale > 0.0 + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + skip_layer_mask = None + if do_spatio_temporal_guidance: + skip_layer_mask = self.transformer.create_skip_layer_mask( + skip_block_list, batch_size, num_conds, 2 + ) # 3. Encode input prompt + self.text_encoder = self.text_encoder.to(self._execution_device) + ( prompt_embeds, prompt_attention_mask, @@ -907,11 +936,30 @@ def __call__( negative_prompt_attention_mask=negative_prompt_attention_mask, clean_caption=clean_caption, ) + + if offload_to_cpu: + self.text_encoder = self.text_encoder.cpu() + + self.transformer = self.transformer.to(self._execution_device) + + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat( + prompt_embeds_batch = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + prompt_attention_mask_batch = torch.cat( [negative_prompt_attention_mask, prompt_attention_mask], dim=0 ) + if do_spatio_temporal_guidance: + prompt_embeds_batch = torch.cat([prompt_embeds_batch, prompt_embeds], dim=0) + prompt_attention_mask_batch = torch.cat( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + dim=0, + ) # 3b. Encode and prepare conditioning data self.video_scale_factor = self.video_scale_factor if is_video else 1 @@ -939,7 +987,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_latent_channels=self.transformer.config.in_channels, num_patches=num_latent_patches, - dtype=prompt_embeds.dtype, + dtype=prompt_embeds_batch.dtype, device=device, generator=generator, latents=init_latents, @@ -949,8 +997,8 @@ def __call__( if conditioning_mask is not None and is_video: assert num_images_per_prompt == 1 conditioning_mask = ( - torch.cat([conditioning_mask] * 2) - if do_classifier_free_guidance + torch.cat([conditioning_mask] * num_conds) + if num_conds > 1 else conditioning_mask ) @@ -985,8 +1033,9 @@ def __call__( orig_conditiong_mask, generator, ) + latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents + torch.cat([latents] * num_conds) if num_conds > 1 else latents ) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t @@ -1057,20 +1106,34 @@ def __call__( noise_pred = self.transformer( latent_model_input.to(self.transformer.dtype), indices_grid, - encoder_hidden_states=prompt_embeds.to(self.transformer.dtype), - encoder_attention_mask=prompt_attention_mask, + encoder_hidden_states=prompt_embeds_batch.to( + self.transformer.dtype + ), + encoder_attention_mask=prompt_attention_mask_batch, timestep=current_timestep, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, return_dict=False, )[0] # perform guidance + if do_spatio_temporal_guidance: + noise_pred_text_perturb = noise_pred[-1:] if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = noise_pred[:2].chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) - current_timestep, _ = current_timestep.chunk(2) + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale * ( + noise_pred_text - noise_pred_text_perturb + ) + if do_rescaling: + factor = noise_pred_text.std() / 
noise_pred.std() + factor = rescaling_scale * factor + (1 - rescaling_scale) + noise_pred = noise_pred * factor + current_timestep = current_timestep[:1] # learned sigma if ( self.transformer.config.out_channels // 2 @@ -1096,6 +1159,11 @@ def __call__( if callback_on_step_end is not None: callback_on_step_end(self, i, t, {}) + if offload_to_cpu: + self.transformer = self.transformer.cpu() + if self._execution_device == "cuda": + torch.cuda.empty_cache() + latents = self.patchifier.unpatchify( latents=latents, output_height=latent_height, @@ -1105,11 +1173,30 @@ def __call__( // math.prod(self.patchifier.patch_size), ) if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor(decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to( + latents.device + )[:, None, None, None, None] + latents = ( + latents * (1 - decode_noise_scale) + noise * decode_noise_scale + ) + else: + decode_timestep = None image = vae_decode( latents, self.vae, is_video, vae_per_channel_normalize=kwargs["vae_per_channel_normalize"], + timestep=decode_timestep, ) image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/ltx_video/schedulers/rf.py b/ltx_video/schedulers/rf.py index 68929b0..892266d 100644 --- a/ltx_video/schedulers/rf.py +++ b/ltx_video/schedulers/rf.py @@ -2,15 +2,25 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Callable, Optional, Tuple, Union +import json +import os +from pathlib import Path import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput from torch import Tensor +from safetensors import safe_open + from ltx_video.utils.torch_utils import append_dims +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, +) + def simple_diffusion_resolution_dependent_timestep_shift( samples: Tensor, @@ -197,6 +207,31 @@ def set_timesteps( self.num_inference_steps = num_inference_steps self.sigmas = self.timesteps + @staticmethod + def from_pretrained(pretrained_model_path: Union[str, os.PathLike]): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file(): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["scheduler"] + del comfy_single_file_state_dict + + elif pretrained_model_path.is_dir(): + diffusers_noise_scheduler_config_path = ( + pretrained_model_path / "scheduler" / "scheduler_config.json" + ) + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + hashable_config = make_hashable_key(scheduler_config) + if hashable_config in diffusers_and_ours_config_mapping: + config = diffusers_and_ours_config_mapping[hashable_config] + return RectifiedFlowScheduler.from_config(config) + def scale_model_input( self, sample: torch.FloatTensor, timestep: 
Optional[int] = None ) -> torch.FloatTensor: diff --git a/ltx_video/utils/diffusers_config_mapping.py b/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 0000000..53c0082 --- /dev/null +++ b/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + 
["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/ltx_video/utils/skip_layer_strategy.py b/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 0000000..fd4c133 --- /dev/null +++ b/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,6 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + Attention = auto() + Residual = auto() diff --git a/pyproject.toml b/pyproject.toml index 30da2d1..589ec9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ltx-video" -version = "0.1.0" +version = "0.1.2" description = "A package for LTX-Video model" authors = [ { name = "Sapir Weissbuch", email = "sapir@lightricks.com" } @@ -18,7 +18,7 @@ classifiers = [ dependencies = [ "torch>=2.1.0", "diffusers>=0.28.2", - "transformers~=4.44.2", + "transformers>=4.44.2", "sentencepiece>=0.1.96", "huggingface-hub~=0.25.2", "einops"