Merge pull request #115 from Lightricks/public/feature/stg-attention-values

STG: Add a new STG strategy: AttentionValues (default)
yoavhacohen authored Feb 16, 2025
2 parents f5a27c9 + 9d0df65 commit 84e36db
Showing 4 changed files with 19 additions and 8 deletions.
10 changes: 6 additions & 4 deletions inference.py
@@ -221,9 +221,9 @@ def main():
     parser.add_argument(
         "--stg_mode",
         type=str,
-        default="attention",
+        default="attention_values",
         help="Spatiotemporal guidance mode. "
-        "It can be one of 'attention' (default), 'residual', or 'transformer_block'.",
+        "It can be one of 'attention_values' (default), 'attention_skip', 'residual', or 'transformer_block'.",
     )
     parser.add_argument(
         "--stg_skip_layers",
@@ -464,8 +464,10 @@ def infer(

     # Set spatiotemporal guidance
     skip_block_list = [int(x.strip()) for x in stg_skip_layers.split(",")]
-    if stg_mode.lower() == "stg_a" or stg_mode.lower() == "attention":
-        skip_layer_strategy = SkipLayerStrategy.Attention
+    if stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
+        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+    elif stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
+        skip_layer_strategy = SkipLayerStrategy.AttentionValues
     elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
         skip_layer_strategy = SkipLayerStrategy.Residual
     elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
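In short, this commit renames the previous "attention" mode to "attention_skip" (SkipLayerStrategy.AttentionSkip) and adds "attention_values" (SkipLayerStrategy.AttentionValues), which becomes the default for --stg_mode. The dispatch above amounts to a lookup from the mode string (or its short alias) to an enum member; the sketch below is illustrative only, and the helper name resolve_stg_mode and its error handling are not part of the commit:

from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

# Illustrative mapping of --stg_mode values (and their short aliases) to strategies.
_STG_MODES = {
    "stg_as": SkipLayerStrategy.AttentionSkip,
    "attention_skip": SkipLayerStrategy.AttentionSkip,
    "stg_av": SkipLayerStrategy.AttentionValues,
    "attention_values": SkipLayerStrategy.AttentionValues,  # new default
    "stg_r": SkipLayerStrategy.Residual,
    "residual": SkipLayerStrategy.Residual,
    "stg_t": SkipLayerStrategy.TransformerBlock,
    "transformer_block": SkipLayerStrategy.TransformerBlock,
}

def resolve_stg_mode(stg_mode: str) -> SkipLayerStrategy:
    # Hypothetical helper mirroring the if/elif chain in infer().
    try:
        return _STG_MODES[stg_mode.lower()]
    except KeyError:
        raise ValueError(f"Unknown spatiotemporal guidance mode: {stg_mode}")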
10 changes: 9 additions & 1 deletion ltx_video/models/transformers/attention.py
@@ -1013,6 +1013,7 @@ def __call__(
         query = attn.apply_rotary_emb(query, freqs_cis)

         value = attn.to_v(encoder_hidden_states)
+        value_for_stg = value

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1070,11 +1071,18 @@ def __call__(

         if (
             skip_layer_mask is not None
-            and skip_layer_strategy == SkipLayerStrategy.Attention
+            and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
         ):
             hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
                 1.0 - skip_layer_mask
             )
+        elif (
+            skip_layer_mask is not None
+            and skip_layer_strategy == SkipLayerStrategy.AttentionValues
+        ):
+            hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
+                1.0 - skip_layer_mask
+            )
         else:
             hidden_states = hidden_states_a

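The branch above implements both attention-level strategies in the attention processor: where skip_layer_mask is 1 the output is the normal attention result hidden_states_a, and where it is 0 AttentionSkip falls back to the incoming hidden_states while the new AttentionValues strategy falls back to the value projections saved in value_for_stg. A condensed, standalone restatement of that blend follows; the function name and signature are assumptions, not the processor's actual interface:

import torch

from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

def blend_attention_output(
    hidden_states_a: torch.Tensor,  # attention output
    hidden_states: torch.Tensor,    # pre-attention hidden states
    value_for_stg: torch.Tensor,    # value projections saved before attention
    skip_layer_mask: torch.Tensor,
    strategy: SkipLayerStrategy,
) -> torch.Tensor:
    # Where the mask is 1 the layer behaves normally; where it is 0,
    # AttentionSkip keeps the pre-attention hidden states and
    # AttentionValues keeps the value projections instead.
    if strategy == SkipLayerStrategy.AttentionSkip:
        fallback = hidden_states
    elif strategy == SkipLayerStrategy.AttentionValues:
        fallback = value_for_stg
    else:
        return hidden_states_a
    return hidden_states_a * skip_layer_mask + fallback * (1.0 - skip_layer_mask)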
3 changes: 2 additions & 1 deletion ltx_video/utils/skip_layer_strategy.py
@@ -2,6 +2,7 @@


 class SkipLayerStrategy(Enum):
-    Attention = auto()
+    AttentionSkip = auto()
+    AttentionValues = auto()
     Residual = auto()
     TransformerBlock = auto()
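After this change the module contains just the extended enum; a minimal reconstruction for reference (the import line is inferred from the use of Enum and auto and is not shown in the hunk):

from enum import Enum, auto

class SkipLayerStrategy(Enum):
    AttentionSkip = auto()    # previously named Attention
    AttentionValues = auto()  # new strategy, now the default
    Residual = auto()
    TransformerBlock = auto()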
4 changes: 2 additions & 2 deletions tests/test_inference.py
@@ -26,7 +26,7 @@ def test_infer_runs_on_real_path(test_paths, do_txt_to_image):
"guidance_scale": 2.5,
"stg_scale": 1,
"stg_rescale": 0.7,
"stg_mode": "stg_a",
"stg_mode": "attention_values",
"stg_skip_layers": "1,2,3",
"image_cond_noise_scale": 0.15,
"height": 480,
@@ -79,7 +79,7 @@ def test_pipeline_on_batch(test_paths):
"do_rescaling": True,
"stg_scale": 1,
"rescaling_scale": 0.7,
"skip_layer_strategy": SkipLayerStrategy.Attention,
"skip_layer_strategy": SkipLayerStrategy.AttentionValues,
"skip_block_list": [1, 2],
"image_cond_noise_scale": 0.15,
"height": 480,
