Skip to content

Commit

Permalink
Merge PR #455 from Kosinkadink/develop - Per-Block Effect/Scale control
Browse files Browse the repository at this point in the history
Per-Block Effect and Scale control
  • Loading branch information
Kosinkadink authored Aug 16, 2024
2 parents 2b1de6c + 8c76a9c commit fb904a3
Show file tree
Hide file tree
Showing 10 changed files with 742 additions and 111 deletions.
27 changes: 19 additions & 8 deletions animatediff/model_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight
from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding
from .context import ContextOptions, ContextOptions, ContextOptionsGroup
from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention,
from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks,
has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len)
from .logger import logger
from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA,
Expand Down Expand Up @@ -752,8 +752,9 @@ def __init__(self, *args, **kwargs):
self.timestep_range: tuple[float, float] = None
self.keyframes: ADKeyframeGroup = ADKeyframeGroup()

self.scale_multival = None
self.effect_multival = None
self.scale_multival: Union[float, Tensor, None] = None
self.effect_multival: Union[float, Tensor, None] = None
self.per_block_list: Union[list[PerBlock], None] = None

# AnimateLCM-I2V
self.orig_ref_drift: float = None
Expand Down Expand Up @@ -789,6 +790,7 @@ def __init__(self, *args, **kwargs):
self.current_pia_input: InputPIA = None
self.combined_scale: Union[float, Tensor] = None
self.combined_effect: Union[float, Tensor] = None
self.combined_per_block_list: Union[float, Tensor] = None
self.combined_cameractrl_effect: Union[float, Tensor] = None
self.combined_pia_mask: Union[float, Tensor] = None
self.combined_pia_effect: Union[float, Tensor] = None
Expand Down Expand Up @@ -819,8 +821,8 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc

def pre_run(self, model: ModelPatcherAndInjector):
self.cleanup()
self.model.set_scale(self.scale_multival)
self.model.set_effect(self.effect_multival)
self.model.set_scale(self.scale_multival, self.per_block_list)
self.model.set_effect(self.effect_multival, self.per_block_list)
self.model.set_cameractrl_effect(self.cameractrl_multival)
if self.model.img_encoder is not None:
self.model.img_encoder.set_ref_drift(self.orig_ref_drift)
Expand Down Expand Up @@ -883,8 +885,8 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor):
self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x)
self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input)
# apply scale and effect
self.model.set_scale(self.combined_scale)
self.model.set_effect(self.combined_effect)
self.model.set_scale(self.combined_scale, self.per_block_list)
self.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list
self.model.set_cameractrl_effect(self.combined_cameractrl_effect)
# apply effect - if not within range, set effect to 0, effectively turning model off
if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]:
Expand All @@ -893,7 +895,7 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor):
else:
# if was not in range last step, apply effect to toggle AD status
if not self.was_within_range:
self.model.set_effect(self.combined_effect)
self.model.set_effect(self.combined_effect, self.per_block_list)
self.was_within_range = True
# update steps current keyframe is used
self.current_used_steps += 1
Expand Down Expand Up @@ -1048,6 +1050,7 @@ def cleanup(self):
self.current_effect = None
self.combined_scale = None
self.combined_effect = None
self.combined_per_block_list = None
self.was_within_range = False
self.prev_sub_idxs = None
self.prev_batched_number = None
Expand Down Expand Up @@ -1375,6 +1378,14 @@ def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionM
+ f"but the provided model is type {model_sd_type}.")


def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_blocks: AllPerBlocks):
if all_per_blocks is None or all_per_blocks.sd_type is None:
return
mm_info = motion_model.model.mm_info
if all_per_blocks.sd_type != mm_info.sd_type:
raise Exception(f"Per-Block provided is meant for {all_per_blocks.sd_type}, but provided motion module is for {mm_info.sd_type}.")


def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
pe_shape = model_dict[key].shape
temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
Expand Down
Loading

0 comments on commit fb904a3

Please sign in to comment.