diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index aa048b2..0b2ed14 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -19,7 +19,7 @@ from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding from .context import ContextOptions, ContextOptions, ContextOptionsGroup -from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, +from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) from .logger import logger from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, @@ -752,8 +752,9 @@ def __init__(self, *args, **kwargs): self.timestep_range: tuple[float, float] = None self.keyframes: ADKeyframeGroup = ADKeyframeGroup() - self.scale_multival = None - self.effect_multival = None + self.scale_multival: Union[float, Tensor, None] = None + self.effect_multival: Union[float, Tensor, None] = None + self.per_block_list: Union[list[PerBlock], None] = None # AnimateLCM-I2V self.orig_ref_drift: float = None @@ -789,6 +790,7 @@ def __init__(self, *args, **kwargs): self.current_pia_input: InputPIA = None self.combined_scale: Union[float, Tensor] = None self.combined_effect: Union[float, Tensor] = None + self.combined_per_block_list: Union[float, Tensor] = None self.combined_cameractrl_effect: Union[float, Tensor] = None self.combined_pia_mask: Union[float, Tensor] = None self.combined_pia_effect: Union[float, Tensor] = None @@ -819,8 +821,8 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc def pre_run(self, model: ModelPatcherAndInjector): self.cleanup() - self.model.set_scale(self.scale_multival) - self.model.set_effect(self.effect_multival) + self.model.set_scale(self.scale_multival, self.per_block_list) + self.model.set_effect(self.effect_multival, self.per_block_list) self.model.set_cameractrl_effect(self.cameractrl_multival) if self.model.img_encoder is not None: self.model.img_encoder.set_ref_drift(self.orig_ref_drift) @@ -883,8 +885,8 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x) self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input) # apply scale and effect - self.model.set_scale(self.combined_scale) - self.model.set_effect(self.combined_effect) + self.model.set_scale(self.combined_scale, self.per_block_list) + self.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list self.model.set_cameractrl_effect(self.combined_cameractrl_effect) # apply effect - if not within range, set effect to 0, effectively turning model off if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]: @@ -893,7 +895,7 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): else: # if was not in range last step, apply effect to toggle AD status if not self.was_within_range: - self.model.set_effect(self.combined_effect) + self.model.set_effect(self.combined_effect, self.per_block_list) self.was_within_range = True # update steps current keyframe is used self.current_used_steps += 1 @@ -1048,6 +1050,7 @@ def cleanup(self): self.current_effect = None self.combined_scale = None self.combined_effect = None + self.combined_per_block_list = None self.was_within_range = False self.prev_sub_idxs = None self.prev_batched_number = None @@ -1375,6 +1378,14 @@ def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionM + f"but the provided model is type {model_sd_type}.") +def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_blocks: AllPerBlocks): + if all_per_blocks is None or all_per_blocks.sd_type is None: + return + mm_info = motion_model.model.mm_info + if all_per_blocks.sd_type != mm_info.sd_type: + raise Exception(f"Per-Block provided is meant for {all_per_blocks.sd_type}, but provided motion module is for {mm_info.sd_type}.") + + def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int): pe_shape = model_dict[key].shape temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 5435ad3..1830266 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1,6 +1,8 @@ import math from typing import Iterable, Tuple, Union, TYPE_CHECKING import re +from dataclasses import dataclass +from collections.abc import Iterable as IterColl import torch from einops import rearrange, repeat @@ -20,7 +22,8 @@ from .adapter_animatelcm_i2v import AdapterEmbed if TYPE_CHECKING: # avoids circular import from .adapter_cameractrl import CameraPoseEncoder -from .utils_motion import CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, prepare_mask_batch +from .utils_motion import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, extend_list_to_batch_size, + prepare_mask_batch, get_combined_multival) from .utils_model import BetaSchedules, ModelTypeSD from .logger import logger @@ -60,6 +63,61 @@ def get_string(self): return f"{self.mm_name}:{self.mm_version}:{self.mm_format}:{self.sd_type}" +####################### +# Facilitate Per-Block Effect and Scale Control +class PerAttn: + def __init__(self, attn_idx: Union[int, None], scale: Union[float, Tensor, None]): + self.attn_idx = attn_idx + self.scale = scale + + def matches(self, id: int): + if self.attn_idx is None: + return True + return self.attn_idx == id + + +class PerBlockId: + def __init__(self, block_type: str, block_idx: Union[int, None]=None, module_idx: Union[int, None]=None): + self.block_type = block_type + self.block_idx = block_idx + self.module_idx = module_idx + + def matches(self, other: 'PerBlockId') -> bool: + # block_type + if other.block_type != self.block_type: + return False + # block_idx + if other.block_idx is None: + return True + elif other.block_idx != self.block_idx: + return False + # module_idx + if other.module_idx is None: + return True + return other.module_idx == self.module_idx + + def __str__(self): + return f"PerBlockId({self.block_type},{self.block_idx},{self.module_idx})" + + +class PerBlock: + def __init__(self, id: PerBlockId, effect: Union[float, Tensor, None]=None, + scales: Union[list[Union[float, Tensor, None]], None]=None): + self.id = id + self.effect = effect + self.scales = scales + + def matches(self, id: PerBlockId): + return self.id.matches(id) + + +@dataclass +class AllPerBlocks: + per_block_list: list[PerBlock] + sd_type: Union[str, None] = None +#---------------------- +####################### + def is_hotshotxl(mm_state_dict: dict[str, Tensor]) -> bool: # use pos_encoder naming to determine if hotshotxl model for key in mm_state_dict.keys(): @@ -233,6 +291,7 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops) self.AD_video_length: int = 24 self.effect_model = 1.0 + self.effect_per_block_list = None # AnimateLCM-I2V stuff - create AdapterEmbed if keys present for it self.img_encoder: AdapterEmbed = None if has_img_encoder(mm_state_dict): @@ -296,7 +355,7 @@ def get_best_beta_schedule(self, log=False) -> str: def cleanup(self): self._reset_sub_idxs() - self._reset_scale_multiplier() + self._reset_scale() self._reset_temp_vars() if self.img_encoder is not None: self.img_encoder.cleanup() @@ -416,31 +475,32 @@ def set_video_length(self, video_length: int, full_length: int): if self.mid_block is not None: self.mid_block.set_video_length(video_length, full_length) - def set_scale(self, multival: Union[float, Tensor]): - if multival is None: - multival = 1.0 - if type(multival) == Tensor: - self._set_scale_multiplier(1.0) - self._set_scale_mask(multival) - else: - self._set_scale_multiplier(multival) - self._set_scale_mask(None) + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + if self.down_blocks is not None: + for block in self.down_blocks: + block.set_scale(scale, per_block_list) + if self.up_blocks is not None: + for block in self.up_blocks: + block.set_scale(scale, per_block_list) + if self.mid_block is not None: + self.mid_block.set_scale(scale, per_block_list) - def set_effect(self, multival: Union[float, Tensor]): + def set_effect(self, multival: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): # keep track of if model is in effect if multival is None: self.effect_model = 1.0 else: self.effect_model = multival + self.effect_per_block_list = per_block_list # pass down effect multival to all blocks if self.down_blocks is not None: for block in self.down_blocks: - block.set_effect(multival) + block.set_effect(multival, per_block_list) if self.up_blocks is not None: for block in self.up_blocks: - block.set_effect(multival) + block.set_effect(multival, per_block_list) if self.mid_block is not None: - self.mid_block.set_effect(multival) + self.mid_block.set_effect(multival, per_block_list) def is_in_effect(self): if type(self.effect_model) == Tensor: @@ -490,26 +550,6 @@ def set_camera_features(self, camera_features: list[Tensor]): if self.up_blocks is not None: for block in self.up_blocks: block.set_camera_features(camera_features=list(reversed(camera_features))) - - def _set_scale_multiplier(self, multiplier: Union[float, None]): - if self.down_blocks is not None: - for block in self.down_blocks: - block.set_scale_multiplier(multiplier) - if self.up_blocks is not None: - for block in self.up_blocks: - block.set_scale_multiplier(multiplier) - if self.mid_block is not None: - self.mid_block.set_scale_multiplier(multiplier) - - def _set_scale_mask(self, mask: Tensor): - if self.down_blocks is not None: - for block in self.down_blocks: - block.set_scale_mask(mask) - if self.up_blocks is not None: - for block in self.up_blocks: - block.set_scale_mask(mask) - if self.mid_block is not None: - self.mid_block.set_scale_mask(mask) def _reset_temp_vars(self): if self.down_blocks is not None: @@ -521,8 +561,8 @@ def _reset_temp_vars(self): if self.mid_block is not None: self.mid_block.reset_temp_vars() - def _reset_scale_multiplier(self): - self._set_scale_multiplier(None) + def _reset_scale(self): + self.set_scale(None) def _reset_sub_idxs(self): self.set_sub_idxs(None) @@ -557,17 +597,13 @@ def set_video_length(self, video_length: int, full_length: int): for motion_module in self.motion_modules: motion_module.set_video_length(video_length, full_length) - def set_scale_multiplier(self, multiplier: Union[float, None]): - for motion_module in self.motion_modules: - motion_module.set_scale_multiplier(multiplier) - - def set_scale_mask(self, mask: Tensor): + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): for motion_module in self.motion_modules: - motion_module.set_scale_mask(mask) - - def set_effect(self, multival: Union[float, Tensor]): + motion_module.set_scale(scale, per_block_list) + + def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[PerBlock], None]=None): for motion_module in self.motion_modules: - motion_module.set_effect(multival) + motion_module.set_effect(multival, per_block_list) def set_cameractrl_effect(self, multival: Union[float, Tensor]): for motion_module in self.motion_modules: @@ -628,6 +664,7 @@ def __init__( self.block_type = block_type self.block_idx = block_idx self.module_idx = module_idx + self.id = PerBlockId(block_type=block_type, block_idx=block_idx, module_idx=module_idx) # effect vars self.effect = None self.temp_effect_mask: Tensor = None @@ -649,6 +686,7 @@ def __init__( cross_frame_attention_mode=cross_frame_attention_mode, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, + block_id=self.id, ops=ops ) @@ -661,14 +699,17 @@ def set_video_length(self, video_length: int, full_length: int): self.video_length = video_length self.full_length = full_length self.temporal_transformer.set_video_length(video_length, full_length) - - def set_scale_multiplier(self, multiplier: Union[float, None]): - self.temporal_transformer.set_scale_multiplier(multiplier) - def set_scale_mask(self, mask: Tensor): - self.temporal_transformer.set_scale_mask(mask) + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + self.temporal_transformer.set_scale(scale, per_block_list) - def set_effect(self, multival: Union[float, Tensor]): + def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[PerBlock], None]=None): + if per_block_list is not None: + for per_block in per_block_list: + if self.id.matches(per_block.id) and per_block.effect is not None: + multival = get_combined_multival(multival, per_block.effect) + #logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") + break if type(multival) == Tensor: self.effect = multival elif multival is not None and math.isclose(multival, 1.0): @@ -677,7 +718,7 @@ def set_effect(self, multival: Union[float, Tensor]): self.effect = multival self.temp_effect_mask = None - def set_cameractrl_effect(self, multival: Union[float, Tensor]): + def set_cameractrl_effect(self, multival: Union[float, Tensor, None]): if type(multival) == Tensor: pass elif multival is None: @@ -741,6 +782,7 @@ def should_handle_camera_features(self): return self.camera_features is not None and self.block_type != BlockType.MID# and self.module_idx == 0 def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None): + #logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") mm_kwargs = None if self.should_handle_camera_features(): mm_kwargs = {"camera_feature": self.camera_features[self.block_idx]} @@ -786,13 +828,13 @@ def __init__( cross_frame_attention_mode=None, temporal_pe=False, temporal_pe_max_len=24, + block_id: PerBlockId=None, ops=comfy.ops.disable_weight_init, ): super().__init__() + self.id = block_id self.video_length = 16 self.full_length = 16 - self.raw_scale_mask: Union[Tensor, None] = None - self.temp_scale_mask: Union[Tensor, None] = None self.sub_idxs: Union[list[int], None] = None self.prev_hidden_states_batch = 0 @@ -831,30 +873,59 @@ def __init__( ) self.proj_out = ops.Linear(inner_dim, in_channels) + self.raw_scale_masks: Union[list[Tensor], None] = [None] * self.get_attention_count() + self.temp_scale_masks: Union[list[Tensor], None] = [None] * self.get_attention_count() + + def get_attention_count(self): + if len(self.transformer_blocks) > 0: + return len(self.transformer_blocks[0].attention_blocks) + return 0 + def set_video_length(self, video_length: int, full_length: int): self.video_length = video_length self.full_length = full_length - - def set_scale_multiplier(self, multiplier: Union[float, None]): - for block in self.transformer_blocks: - block.set_scale_multiplier(multiplier) - def set_scale_mask(self, mask: Tensor): - self.raw_scale_mask = mask - self.temp_scale_mask = None + def set_scale_multiplier(self, idx: int, multiplier: Union[float, list[float], None]): + for block in self.transformer_blocks: + block.set_scale_multiplier(idx, multiplier) + + def set_scale_mask(self, idx: int, mask: Tensor): + self.raw_scale_masks[idx] = mask + self.temp_scale_masks[idx] = None + + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + if per_block_list is not None: + for per_block in per_block_list: + if self.id.matches(per_block.id) and len(per_block.scales) > 0: + scales = [] + for sub_scale in per_block.scales: + scales.append(get_combined_multival(scale, sub_scale)) + #logger.info(f"scale - block_type: {self.id.block_type}, block_idx: {self.id.block_idx}, module_idx: {self.id.module_idx}") + scale = scales + break + + if type(scale) == Tensor or not isinstance(scale, IterColl): + scale = [scale] + scale = extend_list_to_batch_size(scale, self.get_attention_count()) + for idx, sub_scale in enumerate(scale): + if type(sub_scale) == Tensor: + self.set_scale_mask(idx, sub_scale) + self.set_scale_multiplier(idx, None) + else: + self.set_scale_mask(idx, None) + self.set_scale_multiplier(idx, sub_scale) def set_cameractrl_effect(self, multival: Union[float, Tensor]): self.raw_cameractrl_effect = multival self.temp_cameractrl_effect = None def set_sub_idxs(self, sub_idxs: list[int]): - self.sub_idxs = sub_idxs for block in self.transformer_blocks: block.set_sub_idxs(sub_idxs) def reset_temp_vars(self): - del self.temp_scale_mask - self.temp_scale_mask = None + del self.temp_scale_masks + self.temp_scale_masks = [None] * self.get_attention_count() self.prev_hidden_states_batch = 0 del self.temp_cameractrl_effect self.temp_cameractrl_effect = None @@ -862,25 +933,36 @@ def reset_temp_vars(self): for block in self.transformer_blocks: block.reset_temp_vars() - def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: + def get_scale_masks(self, hidden_states: Tensor) -> Union[Tensor, None]: + masks = [] + prev_mask = None + prev_idx = 0 + for idx in range(len(self.raw_scale_masks)): + if prev_mask is self.raw_scale_masks[idx]: + masks.append(self.temp_scale_masks[prev_idx]) + else: + masks.append(self.get_scale_mask(idx=idx, hidden_states=hidden_states)) + prev_idx = idx + return masks + + def get_scale_mask(self, idx: int, hidden_states: Tensor) -> Union[Tensor, None]: # if no raw mask, return None - if self.raw_scale_mask is None: + if self.raw_scale_masks[idx] is None: return None shape = hidden_states.shape batch, channel, height, width = shape # if temp mask already calculated, return it - if self.temp_scale_mask != None: + if self.temp_scale_masks[idx] != None: # check if hidden_states batch matches if batch == self.prev_hidden_states_batch: if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask + return self.temp_scale_masks[idx][:, self.sub_idxs, :] + return self.temp_scale_masks[idx] # if does not match, reset cached temp_scale_mask and recalculate it - del self.temp_scale_mask - self.temp_scale_mask = None + self.temp_scale_masks[idx] = None # otherwise, calculate temp mask self.prev_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = prepare_mask_batch(self.raw_scale_masks[idx], shape=(self.full_length, 1, height, width)) mask = repeat_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: @@ -897,13 +979,13 @@ def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: if batched_number > 1: mask = torch.cat([mask] * batched_number, dim=0) # cache mask and set to proper device - self.temp_scale_mask = mask + self.temp_scale_masks[idx] = mask # move temp_scale_mask to proper dtype + device - self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) + self.temp_scale_masks[idx] = self.temp_scale_masks[idx].to(dtype=hidden_states.dtype, device=hidden_states.device) # return subset of masks, if needed if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask + return self.temp_scale_masks[idx][:, self.sub_idxs, :] + return self.temp_scale_masks[idx] def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, None]: # if no raw camera_Ctrl, return None @@ -926,7 +1008,7 @@ def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, N self.temp_cameractrl_effect = None # otherwise, calculate temp_cameractrl self.prev_cameractrl_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) mask = repeat_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: @@ -954,7 +1036,7 @@ def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, N def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None, mm_kwargs: dict[str]=None): batch, channel, height, width = hidden_states.shape residual = hidden_states - scale_mask = self.get_scale_mask(hidden_states) + scale_masks = self.get_scale_masks(hidden_states) cameractrl_effect = self.get_cameractrl_effect(hidden_states) # add some casts for fp8 purposes - does not affect speed otherwise hidden_states = self.norm(hidden_states).to(hidden_states.dtype) @@ -971,7 +1053,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, video_length=self.video_length, - scale_mask=scale_mask, + scale_masks=scale_masks, cameractrl_effect=cameractrl_effect, view_options=view_options, mm_kwargs=mm_kwargs @@ -1044,9 +1126,8 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) - def set_scale_multiplier(self, multiplier: Union[float, None]): - for block in self.attention_blocks: - block.set_scale_multiplier(multiplier) + def set_scale_multiplier(self, idx: int, multiplier: Union[float, None]): + self.attention_blocks[idx].set_scale_multiplier(multiplier) def set_sub_idxs(self, sub_idxs: list[int]): for block in self.attention_blocks: @@ -1062,11 +1143,13 @@ def forward( encoder_hidden_states: Tensor=None, attention_mask: Tensor=None, video_length: int=None, - scale_mask: Tensor=None, + scale_masks: list[Tensor]=None, cameractrl_effect: Union[float, Tensor] = None, - view_options: ContextOptions=None, + view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, ): + if scale_masks is None: + scale_masks = [None] * len(self.attention_blocks) # make view_options None if context_length > video_length, or if equal and equal not allowed if view_options: if view_options.context_length > video_length: @@ -1074,7 +1157,7 @@ def forward( elif view_options.context_length == video_length and not view_options.use_on_equal_length: view_options = None if not view_options: - for attention_block, norm in zip(self.attention_blocks, self.norms): + for attention_block, norm, scale_mask in zip(self.attention_blocks, self.norms, scale_masks): norm_hidden_states = norm(hidden_states).to(hidden_states.dtype) hidden_states = ( attention_block( @@ -1109,7 +1192,7 @@ def forward( sub_hidden_states = rearrange(hidden_states[:, sub_idxs], "b f d c -> (b f) d c") if has_camera_feature: mm_kwargs["camera_feature"] = orig_camera_feature[:, sub_idxs, :] - for attention_block, norm in zip(self.attention_blocks, self.norms): + for attention_block, norm, scale_mask in zip(self.attention_blocks, self.norms, scale_masks): norm_hidden_states = norm(sub_hidden_states).to(sub_hidden_states.dtype) sub_hidden_states = ( attention_block( diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 43e718b..a429789 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -39,6 +39,9 @@ WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) from .nodes_scheduling import (PromptSchedulingNode, PromptSchedulingLatentsNode, ValueSchedulingNode, ValueSchedulingLatentsNode, AddValuesReplaceNode, FloatToFloatsNode) +from .nodes_per_block import (ADBlockComboNode, ADBlockIndivNode, PerBlockHighLevelNode, + PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SD15_FromFloatsNode, + PerBlock_SDXL_LowLevelNode, PerBlock_SDXL_MidLevelNode, PerBlock_SDXL_FromFloatsNode) from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival from .nodes_deprecated import (AnimateDiffLoader_Deprecated, AnimateDiffLoaderAdvanced_Deprecated, AnimateDiffCombine_Deprecated, AnimateDiffModelSettings, AnimateDiffModelSettingsSimple, AnimateDiffModelSettingsAdvanced, AnimateDiffModelSettingsAdvancedAttnStrengths) @@ -162,6 +165,16 @@ ValueSchedulingLatentsNode.NodeID: ValueSchedulingLatentsNode, AddValuesReplaceNode.NodeID: AddValuesReplaceNode, FloatToFloatsNode.NodeID: FloatToFloatsNode, + # Per-Block + ADBlockComboNode.NodeID: ADBlockComboNode, + ADBlockIndivNode.NodeID: ADBlockIndivNode, + PerBlockHighLevelNode.NodeID: PerBlockHighLevelNode, + PerBlock_SD15_MidLevelNode.NodeID: PerBlock_SD15_MidLevelNode, + PerBlock_SD15_LowLevelNode.NodeID: PerBlock_SD15_LowLevelNode, + PerBlock_SD15_FromFloatsNode.NodeID: PerBlock_SD15_FromFloatsNode, + PerBlock_SDXL_MidLevelNode.NodeID: PerBlock_SDXL_MidLevelNode, + PerBlock_SDXL_LowLevelNode.NodeID: PerBlock_SDXL_LowLevelNode, + PerBlock_SDXL_FromFloatsNode.NodeID: PerBlock_SDXL_FromFloatsNode, # Extras Nodes "ADE_AnimateDiffUnload": AnimateDiffUnload, "ADE_EmptyLatentImageLarge": EmptyLatentImageLarge, @@ -319,6 +332,16 @@ ValueSchedulingLatentsNode.NodeID: ValueSchedulingLatentsNode.NodeName, AddValuesReplaceNode.NodeID: AddValuesReplaceNode.NodeName, FloatToFloatsNode.NodeID:FloatToFloatsNode.NodeName, + # Per-Block + ADBlockComboNode.NodeID: ADBlockComboNode.NodeName, + ADBlockIndivNode.NodeID: ADBlockIndivNode.NodeName, + PerBlockHighLevelNode.NodeID: PerBlockHighLevelNode.NodeName, + PerBlock_SD15_MidLevelNode.NodeID: PerBlock_SD15_MidLevelNode.NodeName, + PerBlock_SD15_LowLevelNode.NodeID: PerBlock_SD15_LowLevelNode.NodeName, + PerBlock_SD15_FromFloatsNode.NodeID: PerBlock_SD15_FromFloatsNode.NodeName, + PerBlock_SDXL_MidLevelNode.NodeID: PerBlock_SDXL_MidLevelNode.NodeName, + PerBlock_SDXL_LowLevelNode.NodeID: PerBlock_SDXL_LowLevelNode.NodeName, + PerBlock_SDXL_FromFloatsNode.NodeID: PerBlock_SDXL_FromFloatsNode.NodeName, # Extras Nodes "ADE_AnimateDiffUnload": "AnimateDiff Unload 🎭🅐🅓", "ADE_EmptyLatentImageLarge": "Empty Latent Image (Big Batch) 🎭🅐🅓", diff --git a/animatediff/nodes_animatelcmi2v.py b/animatediff/nodes_animatelcmi2v.py index 370f5f0..9d63e12 100644 --- a/animatediff/nodes_animatelcmi2v.py +++ b/animatediff/nodes_animatelcmi2v.py @@ -34,7 +34,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -44,11 +45,12 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, ref_latent: dict, ref_drift: float=0.0, apply_ref_when_disabled=False, start_percent: float=0.0, end_percent: float=1.0, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, + scale_multival=None, effect_multival=None, per_block=None, prev_m_models: MotionModelGroup=None,): new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, motion_lora=motion_lora, ad_keyframes=ad_keyframes, - scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) + scale_multival=scale_multival, effect_multival=effect_multival, per_block=per_block, + prev_m_models=prev_m_models) # most recent added model will always be first in list; curr_model = new_m_models[0].models[0] # confirm that model contains img_encoder diff --git a/animatediff/nodes_cameractrl.py b/animatediff/nodes_cameractrl.py index 9352bae..6ba94e7 100644 --- a/animatediff/nodes_cameractrl.py +++ b/animatediff/nodes_cameractrl.py @@ -209,6 +209,7 @@ def INPUT_TYPES(s): "cameractrl_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), + "per_block": ("PER_BLOCK",), } } @@ -218,10 +219,10 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, cameractrl_poses: list[list[float]], start_percent: float=0.0, end_percent: float=1.0, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, cameractrl_multival=None, + scale_multival=None, effect_multival=None, cameractrl_multival=None, per_block=None, prev_m_models: MotionModelGroup=None,): new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, - motion_lora=motion_lora, ad_keyframes=ad_keyframes, + motion_lora=motion_lora, ad_keyframes=ad_keyframes, per_block=per_block, scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) # most recent added model will always be first in list; curr_model = new_m_models[0].models[0] diff --git a/animatediff/nodes_gen1.py b/animatediff/nodes_gen1.py index ad93440..574c735 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -10,7 +10,10 @@ from .utils_model import BetaSchedules, get_available_motion_loras, get_available_motion_models, get_motion_lora_path from .utils_motion import ADKeyframeGroup, get_combined_multival from .motion_lora import MotionLoraInfo, MotionLoraList -from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2 +from .motion_module_ad import AllPerBlocks +from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, + load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2, + validate_per_block_compatibility) from .sample_settings import SampleSettings, SeedNoiseGeneration from .sampling import motion_sample_factory @@ -33,6 +36,7 @@ def INPUT_TYPES(s): "sample_settings": ("SAMPLE_SETTINGS",), "scale_multival": ("MULTIVAL",), "effect_multival": ("MULTIVAL",), + "per_block": ("PER_BLOCK",), } } @@ -45,6 +49,7 @@ def load_mm_and_inject_params(self, model_name: str, beta_schedule: str,# apply_mm_groupnorm_hack: bool, context_options: ContextOptionsGroup=None, motion_lora: MotionLoraList=None, ad_settings: AnimateDiffSettings=None, sample_settings: SampleSettings=None, scale_multival=None, effect_multival=None, ad_keyframes: ADKeyframeGroup=None, + per_block: AllPerBlocks=None, ): # load motion module and motion settings, if included motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings) @@ -56,6 +61,9 @@ def load_mm_and_inject_params(self, load_motion_lora_as_patches(motion_model, lora) motion_model.scale_multival = scale_multival motion_model.effect_multival = effect_multival + if per_block is not None: + validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) + motion_model.per_block_list = per_block.per_block_list motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() # create injection params diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index 8302df2..979c10f 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -9,8 +9,9 @@ from .utils_model import BIGMAX, BetaSchedules, get_available_motion_models from .utils_motion import ADKeyframeGroup, ADKeyframe, InputPIA from .motion_lora import MotionLoraList +from .motion_module_ad import AllPerBlocks from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, - load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2) + load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings @@ -94,7 +95,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), - "autosize": ("ADEAUTOSIZE", {"padding": 80}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -104,7 +106,7 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: float=0.0, end_percent: float=1.0, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, + scale_multival=None, effect_multival=None, per_block: AllPerBlocks=None, prev_m_models: MotionModelGroup=None,): # set up motion models list if prev_m_models is None: @@ -122,6 +124,9 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: fl load_motion_lora_as_patches(motion_model, lora) motion_model.scale_multival = scale_multival motion_model.effect_multival = effect_multival + if per_block is not None: + validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) + motion_model.per_block_list = per_block.per_block_list motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() motion_model.timestep_percent_range = (start_percent, end_percent) # add to beginning, so that after injection, it will be the earliest of prev_m_models to be run @@ -141,7 +146,8 @@ def INPUT_TYPES(s): "scale_multival": ("MULTIVAL",), "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), - "autosize": ("ADEAUTOSIZE", {"padding": 40}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -151,11 +157,12 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, motion_lora: MotionLoraList=None, - scale_multival=None, effect_multival=None, ad_keyframes=None): + scale_multival=None, effect_multival=None, ad_keyframes=None, + per_block: AllPerBlocks=None): # just a subset of normal ApplyAnimateDiffModelNode inputs return ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, motion_lora=motion_lora, scale_multival=scale_multival, effect_multival=effect_multival, - ad_keyframes=ad_keyframes) + ad_keyframes=ad_keyframes, per_block=per_block) class LoadAnimateDiffModelNode: diff --git a/animatediff/nodes_per_block.py b/animatediff/nodes_per_block.py new file mode 100644 index 0000000..b06621f --- /dev/null +++ b/animatediff/nodes_per_block.py @@ -0,0 +1,494 @@ +from typing import Union +from torch import Tensor + +from .documentation import short_desc, register_description, coll, DocHelper +from .motion_module_ad import PerBlock, PerBlockId, BlockType, AllPerBlocks +from .utils_model import ModelTypeSD +from .utils_motion import extend_list_to_batch_size + + +class ADBlockHolder: + def __init__(self, effect: Union[float, Tensor, None]=None, + scales: Union[list[float, Tensor], None]=list()): + self.effect = effect + self.scales = scales + + def has_effect(self): + return self.effect is not None + + def has_scale(self): + for scale in self.scales: + if scale is not None: + return True + return False + + def is_empty(self): + has_anything = self.has_effect() or self.has_scale() + return not has_anything + + +class ADBlockComboNode: + NodeID = 'ADE_ADBlockCombo' + NodeName = 'AD Block 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect": ("MULTIVAL",), + "scale": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("AD_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "block_control" + + def block_control(self, effect: Union[float, Tensor, None]=None, scale: Union[float, Tensor, None]=None): + scales = [scale, scale] + block = ADBlockHolder(effect=effect, scales=scales) + if block.is_empty(): + block = None + return (block,) + + +class ADBlockIndivNode: + NodeID = 'ADE_ADBlockIndiv' + NodeName = 'AD Block+ 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect": ("MULTIVAL",), + "scale_0": ("MULTIVAL",), + "scale_1": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("AD_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "block_control" + + def block_control(self, effect: Union[float, Tensor, None]=None, + scale_0: Union[float, Tensor, None]=None, scale_1: Union[float, Tensor, None]=None): + scales = [scale_0, scale_1] + block = ADBlockHolder(effect=effect, scales=scales) + if block.is_empty(): + block = None + return (block,) + + +class PerBlockHighLevelNode: + NodeID = 'ADE_PerBlockHighLevel' + NodeName = 'AD Per Block 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN): down, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP): up, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks),) + + +class PerBlock_SD15_MidLevelNode: + NodeID = 'ADE_PerBlock_SD15_MidLevel' + NodeName = 'AD Per Block+ (SD1.5) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0": ("AD_BLOCK",), + "down_1": ("AD_BLOCK",), + "down_2": ("AD_BLOCK",), + "down_3": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0": ("AD_BLOCK",), + "up_1": ("AD_BLOCK",), + "up_2": ("AD_BLOCK",), + "up_3": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0: Union[ADBlockHolder, None]=None, + down_1: Union[ADBlockHolder, None]=None, + down_2: Union[ADBlockHolder, None]=None, + down_3: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0: Union[ADBlockHolder, None]=None, + up_1: Union[ADBlockHolder, None]=None, + up_2: Union[ADBlockHolder, None]=None, + up_3: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0): down_0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1): down_1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2): down_2, + PerBlockId(block_type=BlockType.DOWN, block_idx=3): down_3, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0): up_0, + PerBlockId(block_type=BlockType.UP, block_idx=1): up_1, + PerBlockId(block_type=BlockType.UP, block_idx=2): up_2, + PerBlockId(block_type=BlockType.UP, block_idx=3): up_3, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) + + +class PerBlock_SD15_LowLevelNode: + NodeID = 'ADE_PerBlock_SD15_LowLevel' + NodeName = 'AD Per Block++ (SD1.5) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0__0": ("AD_BLOCK",), + "down_0__1": ("AD_BLOCK",), + "down_1__0": ("AD_BLOCK",), + "down_1__1": ("AD_BLOCK",), + "down_2__0": ("AD_BLOCK",), + "down_2__1": ("AD_BLOCK",), + "down_3__0": ("AD_BLOCK",), + "down_3__1": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0__0": ("AD_BLOCK",), + "up_0__1": ("AD_BLOCK",), + "up_0__2": ("AD_BLOCK",), + "up_1__0": ("AD_BLOCK",), + "up_1__1": ("AD_BLOCK",), + "up_1__2": ("AD_BLOCK",), + "up_2__0": ("AD_BLOCK",), + "up_2__1": ("AD_BLOCK",), + "up_2__2": ("AD_BLOCK",), + "up_3__0": ("AD_BLOCK",), + "up_3__1": ("AD_BLOCK",), + "up_3__2": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0__0: Union[ADBlockHolder, None]=None, + down_0__1: Union[ADBlockHolder, None]=None, + down_1__0: Union[ADBlockHolder, None]=None, + down_1__1: Union[ADBlockHolder, None]=None, + down_2__0: Union[ADBlockHolder, None]=None, + down_2__1: Union[ADBlockHolder, None]=None, + down_3__0: Union[ADBlockHolder, None]=None, + down_3__1: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0__0: Union[ADBlockHolder, None]=None, + up_0__1: Union[ADBlockHolder, None]=None, + up_0__2: Union[ADBlockHolder, None]=None, + up_1__0: Union[ADBlockHolder, None]=None, + up_1__1: Union[ADBlockHolder, None]=None, + up_1__2: Union[ADBlockHolder, None]=None, + up_2__0: Union[ADBlockHolder, None]=None, + up_2__1: Union[ADBlockHolder, None]=None, + up_2__2: Union[ADBlockHolder, None]=None, + up_3__0: Union[ADBlockHolder, None]=None, + up_3__1: Union[ADBlockHolder, None]=None, + up_3__2: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=0): down_0__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=1): down_0__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=0): down_1__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=1): down_1__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=0): down_2__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=1): down_2__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=3, module_idx=0): down_3__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=3, module_idx=1): down_3__1, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=0): up_0__0, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=1): up_0__1, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=2): up_0__2, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=0): up_1__0, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=1): up_1__1, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=2): up_1__2, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=0): up_2__0, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=1): up_2__1, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=2): up_2__2, + PerBlockId(block_type=BlockType.UP, block_idx=3, module_idx=0): up_3__0, + PerBlockId(block_type=BlockType.UP, block_idx=3, module_idx=1): up_3__1, + PerBlockId(block_type=BlockType.UP, block_idx=3, module_idx=2): up_3__2, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) + + +class PerBlock_SD15_FromFloatsNode: + NodeID = 'ADE_PerBlock_SD15_FromFloats' + NodeName = 'AD Per Block Floats (SD1.5) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect_21_floats": ("FLOATS",), + "scale_21_floats": ("FLOATS",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + Desc = [ + short_desc('Use Floats from Value Schedules to select SD1.5 effect/scale values for blocks.'), + 'SD1.5 Motion Modules contain 21 blocks:', + 'idx 0 - start of down blocks (down_0__0)', + 'idx 7 - end of down blocks (down_3__1)', + 'idx 8 - mid block (mid)', + 'idx 9 - start of up blocks (up_0__0)', + 'idx 20 - end of up blocks (up_3__2)', + ] + register_description(NodeID, Desc) + + def create_per_block(self, + effect_21_floats: Union[list[float], None]=None, + scale_21_floats: Union[list[float], None]=None): + if effect_21_floats is None and scale_21_floats is None: + return (None,) + # SD1.5 has 21 blocks + block_total = 21 + holders = [ADBlockHolder() for _ in range(block_total)] + if effect_21_floats is not None: + effect_21_floats = extend_list_to_batch_size(effect_21_floats, block_total) + for effect, holder in zip(effect_21_floats, holders): + holder.effect = effect + if scale_21_floats is not None: + scale_21_floats = extend_list_to_batch_size(scale_21_floats, block_total) + for scale, holder in zip(scale_21_floats, holders): + holder.scales = [scale, scale] + return PerBlock_SD15_LowLevelNode.create_per_block(self, *holders) + + +class PerBlock_SDXL_MidLevelNode: + NodeID = 'ADE_PerBlock_SDXL_MidLevel' + NodeName = 'AD Per Block+ (SDXL) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0": ("AD_BLOCK",), + "down_1": ("AD_BLOCK",), + "down_2": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0": ("AD_BLOCK",), + "up_1": ("AD_BLOCK",), + "up_2": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0: Union[ADBlockHolder, None]=None, + down_1: Union[ADBlockHolder, None]=None, + down_2: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0: Union[ADBlockHolder, None]=None, + up_1: Union[ADBlockHolder, None]=None, + up_2: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0): down_0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1): down_1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2): down_2, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0): up_0, + PerBlockId(block_type=BlockType.UP, block_idx=1): up_1, + PerBlockId(block_type=BlockType.UP, block_idx=2): up_2, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) + + +class PerBlock_SDXL_LowLevelNode: + NodeID = 'ADE_PerBlock_SDXL_LowLevel' + NodeName = 'AD Per Block++ (SDXL) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0__0": ("AD_BLOCK",), + "down_0__1": ("AD_BLOCK",), + "down_1__0": ("AD_BLOCK",), + "down_1__1": ("AD_BLOCK",), + "down_2__0": ("AD_BLOCK",), + "down_2__1": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0__0": ("AD_BLOCK",), + "up_0__1": ("AD_BLOCK",), + "up_0__2": ("AD_BLOCK",), + "up_1__0": ("AD_BLOCK",), + "up_1__1": ("AD_BLOCK",), + "up_1__2": ("AD_BLOCK",), + "up_2__0": ("AD_BLOCK",), + "up_2__1": ("AD_BLOCK",), + "up_2__2": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0__0: Union[ADBlockHolder, None]=None, + down_0__1: Union[ADBlockHolder, None]=None, + down_1__0: Union[ADBlockHolder, None]=None, + down_1__1: Union[ADBlockHolder, None]=None, + down_2__0: Union[ADBlockHolder, None]=None, + down_2__1: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0__0: Union[ADBlockHolder, None]=None, + up_0__1: Union[ADBlockHolder, None]=None, + up_0__2: Union[ADBlockHolder, None]=None, + up_1__0: Union[ADBlockHolder, None]=None, + up_1__1: Union[ADBlockHolder, None]=None, + up_1__2: Union[ADBlockHolder, None]=None, + up_2__0: Union[ADBlockHolder, None]=None, + up_2__1: Union[ADBlockHolder, None]=None, + up_2__2: Union[ADBlockHolder, None]=None,): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=0): down_0__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=1): down_0__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=0): down_1__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=1): down_1__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=0): down_2__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=1): down_2__1, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=0): up_0__0, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=1): up_0__1, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=2): up_0__2, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=0): up_1__0, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=1): up_1__1, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=2): up_1__2, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=0): up_2__0, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=1): up_2__1, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=2): up_2__2, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) + + +class PerBlock_SDXL_FromFloatsNode: + NodeID = 'ADE_PerBlock_SDXL_FromFloats' + NodeName = 'AD Per Block Floats (SDXL) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect_16_floats": ("FLOATS",), + "scale_16_floats": ("FLOATS",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + Desc = [ + short_desc('Use Floats from Value Schedules to select SDXL effect/scale values for blocks.'), + 'SDXL Motion Modules contain 16 blocks:', + 'idx 0 - start of down blocks (down_0__0)', + 'idx 5 - end of down blocks (down_2__1)', + 'idx 6 - mid block (mid)', + 'idx 7 - start of up blocks (up_0__0)', + 'idx 15 - end of up blocks (up_2__2)', + ] + register_description(NodeID, Desc) + + def create_per_block(self, + effect_16_floats: Union[list[float], None]=None, + scale_16_floats: Union[list[float], None]=None): + if effect_16_floats is None and scale_16_floats is None: + return (None,) + # SDXL has 16 blocks + block_total = 16 + holders = [ADBlockHolder() for _ in range(block_total)] + if effect_16_floats is not None: + effect_16_floats = extend_list_to_batch_size(effect_16_floats, block_total) + for effect, holder in zip(effect_16_floats, holders): + holder.effect = effect + if scale_16_floats is not None: + scale_16_floats = extend_list_to_batch_size(scale_16_floats, block_total) + for scale, holder in zip(scale_16_floats, holders): + holder.scales = [scale, scale] + return PerBlock_SDXL_LowLevelNode.create_per_block(self, *holders) diff --git a/animatediff/nodes_pia.py b/animatediff/nodes_pia.py index e65386a..5944485 100644 --- a/animatediff/nodes_pia.py +++ b/animatediff/nodes_pia.py @@ -125,7 +125,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -136,11 +137,12 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, vae: VAE, start_percent: float=0.0, end_percent: float=1.0, pia_input: InputPIA=None, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, ref_multival=None, + scale_multival=None, effect_multival=None, ref_multival=None, per_block=None, prev_m_models: MotionModelGroup=None,): new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, motion_lora=motion_lora, ad_keyframes=ad_keyframes, - scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) + scale_multival=scale_multival, effect_multival=effect_multival, per_block=per_block, + prev_m_models=prev_m_models) # most recent added model will always be first in list; curr_model = new_m_models[0].models[0] # confirm that model is PIA diff --git a/pyproject.toml b/pyproject.toml index 9029be3..35964e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.1.2" +version = "1.1.3" license = { file = "LICENSE" } dependencies = []