From 9d0df65a83a575c4b8e93a3562839753f4cc6f1f Mon Sep 17 00:00:00 2001
From: Yoav HaCohen
Date: Sun, 16 Feb 2025 16:19:32 +0200
Subject: [PATCH] STG: Add a new STG strategy: AttentionValues (default)

This new strategy returns V(x) from the skipped attention blocks.
---
 inference.py                                | 10 ++++++----
 ltx_video/models/transformers/attention.py  | 10 +++++++++-
 ltx_video/utils/skip_layer_strategy.py      |  3 ++-
 tests/test_inference.py                     |  4 ++--
 4 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/inference.py b/inference.py
index d0dc50b..a595bfe 100644
--- a/inference.py
+++ b/inference.py
@@ -221,9 +221,9 @@ def main():
     parser.add_argument(
         "--stg_mode",
         type=str,
-        default="attention",
+        default="attention_values",
         help="Spatiotemporal guidance mode. "
-        "It can be one of 'attention' (default), 'residual', or 'transformer_block'.",
+        "It can be one of 'attention_values' (default), 'attention_skip', 'residual', or 'transformer_block'.",
     )
     parser.add_argument(
         "--stg_skip_layers",
@@ -464,8 +464,10 @@ def infer(
 
     # Set spatiotemporal guidance
     skip_block_list = [int(x.strip()) for x in stg_skip_layers.split(",")]
-    if stg_mode.lower() == "stg_a" or stg_mode.lower() == "attention":
-        skip_layer_strategy = SkipLayerStrategy.Attention
+    if stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
+        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+    elif stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
+        skip_layer_strategy = SkipLayerStrategy.AttentionValues
     elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
         skip_layer_strategy = SkipLayerStrategy.Residual
     elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
diff --git a/ltx_video/models/transformers/attention.py b/ltx_video/models/transformers/attention.py
index 6eeac94..6ea0ed1 100644
--- a/ltx_video/models/transformers/attention.py
+++ b/ltx_video/models/transformers/attention.py
@@ -1013,6 +1013,7 @@ def __call__(
             query = attn.apply_rotary_emb(query, freqs_cis)
 
         value = attn.to_v(encoder_hidden_states)
+        value_for_stg = value
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1070,11 +1071,18 @@
 
         if (
             skip_layer_mask is not None
-            and skip_layer_strategy == SkipLayerStrategy.Attention
+            and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
         ):
             hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
                 1.0 - skip_layer_mask
             )
+        elif (
+            skip_layer_mask is not None
+            and skip_layer_strategy == SkipLayerStrategy.AttentionValues
+        ):
+            hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
+                1.0 - skip_layer_mask
+            )
         else:
             hidden_states = hidden_states_a
 
diff --git a/ltx_video/utils/skip_layer_strategy.py b/ltx_video/utils/skip_layer_strategy.py
index 39638b5..30f9016 100644
--- a/ltx_video/utils/skip_layer_strategy.py
+++ b/ltx_video/utils/skip_layer_strategy.py
@@ -2,6 +2,7 @@
 
 
 class SkipLayerStrategy(Enum):
-    Attention = auto()
+    AttentionSkip = auto()
+    AttentionValues = auto()
     Residual = auto()
     TransformerBlock = auto()
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 6ac9fbb..09b41b0 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -26,7 +26,7 @@ def test_infer_runs_on_real_path(test_paths, do_txt_to_image):
         "guidance_scale": 2.5,
         "stg_scale": 1,
         "stg_rescale": 0.7,
-        "stg_mode": "stg_a",
+        "stg_mode": "attention_values",
         "stg_skip_layers": "1,2,3",
         "image_cond_noise_scale": 0.15,
         "height": 480,
@@ -79,7 +79,7 @@ def test_pipeline_on_batch(test_paths):
         "do_rescaling": True,
         "stg_scale": 1,
         "rescaling_scale": 0.7,
-        "skip_layer_strategy": SkipLayerStrategy.Attention,
+        "skip_layer_strategy": SkipLayerStrategy.AttentionValues,
         "skip_block_list": [1, 2],
         "image_cond_noise_scale": 0.15,
         "height": 480,