From e34c410cebab33605a41555475ffeb5d73979c1f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 15 Aug 2024 22:36:31 -0500 Subject: [PATCH 01/11] Preparing for Per-Block Effect/Scale, fixed cameractrl_effect using raw_scale_mask instead of raw_cameractrl_effect --- animatediff/motion_module_ad.py | 87 ++++++++++++++++++++++++++++++--- 1 file changed, 80 insertions(+), 7 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 5435ad3..c5938a5 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -60,6 +60,55 @@ def get_string(self): return f"{self.mm_name}:{self.mm_version}:{self.mm_format}:{self.sd_type}" +####################### +# Facilitate Per-Block Effect and Scale Control +class PerAttn: + def __init__(self, attn_idx: Union[int, None], scale: Union[float, Tensor, None]): + self.attn_idx = attn_idx + self.scale = scale + + def matches(self, id: int): + if self.attn_idx is None: + return True + return self.attn_idx == id + + +class PerBlockId: + def __init__(self, block_type: str, block_idx: Union[int, None], module_idx: Union[int, None]): + self.block_type = block_type + self.block_idx = block_idx + self.module_idx = module_idx + + def matches(self, other: 'PerBlockId') -> bool: + # block_type + if other.block_type != self.block_type: + return False + # block_idx + if other.block_idx is None: + return True + elif other.block_idx != self.block_idx: + return False + # module_idx + if other.module_idx is None: + return True + return other.module_idx == self.module_idx + + def __str__(self): + return f"PerBlockId({self.block_type},{self.block_idx},{self.module_idx})" + + +class PerBlock: + def __init__(self, id: PerBlockId, effect: Union[float, Tensor, None]=None, + per_attn_scale: Union[list[PerAttn], None]=None): + self.id = id + self.effect = effect + self.per_attn_scale = per_attn_scale + + def matches(self, id: PerBlockId): + return self.id.matches(id) 
+#---------------------- +####################### + def is_hotshotxl(mm_state_dict: dict[str, Tensor]) -> bool: # use pos_encoder naming to determine if hotshotxl model for key in mm_state_dict.keys(): @@ -628,6 +677,7 @@ def __init__( self.block_type = block_type self.block_idx = block_idx self.module_idx = module_idx + self.id = PerBlockId(block_type=block_type, block_idx=block_idx, module_idx=module_idx) # effect vars self.effect = None self.temp_effect_mask: Tensor = None @@ -668,7 +718,15 @@ def set_scale_multiplier(self, multiplier: Union[float, None]): def set_scale_mask(self, mask: Tensor): self.temporal_transformer.set_scale_mask(mask) - def set_effect(self, multival: Union[float, Tensor]): + def set_scale(self, multiplier: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + self.temporal_transformer.set_scale(multiplier) + + def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[PerBlock], None]=None): + if per_block_list is not None: + for per_block in per_block_list: + if self.id.matches(per_block.id) and per_block.effect is not None: + multival = per_block.effect + break if type(multival) == Tensor: self.effect = multival elif multival is not None and math.isclose(multival, 1.0): @@ -677,7 +735,7 @@ def set_effect(self, multival: Union[float, Tensor]): self.effect = multival self.temp_effect_mask = None - def set_cameractrl_effect(self, multival: Union[float, Tensor]): + def set_cameractrl_effect(self, multival: Union[float, Tensor, None]): if type(multival) == Tensor: pass elif multival is None: @@ -741,6 +799,7 @@ def should_handle_camera_features(self): return self.camera_features is not None and self.block_type != BlockType.MID# and self.module_idx == 0 def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None): + #logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") mm_kwargs = None if 
self.should_handle_camera_features(): mm_kwargs = {"camera_feature": self.camera_features[self.block_idx]} @@ -843,6 +902,14 @@ def set_scale_mask(self, mask: Tensor): self.raw_scale_mask = mask self.temp_scale_mask = None + def set_scale(self, scale: Union[float, Tensor, None]): + if type(scale) == Tensor: + self.set_scale_mask(scale) + self.set_scale_multiplier(None) + else: + self.set_scale_mask(None) + self.set_scale_multiplier(scale) + def set_cameractrl_effect(self, multival: Union[float, Tensor]): self.raw_cameractrl_effect = multival self.temp_cameractrl_effect = None @@ -926,7 +993,7 @@ def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, N self.temp_cameractrl_effect = None # otherwise, calculate temp_cameractrl self.prev_cameractrl_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) mask = repeat_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: @@ -1044,9 +1111,15 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) - def set_scale_multiplier(self, multiplier: Union[float, None]): - for block in self.attention_blocks: - block.set_scale_multiplier(multiplier) + def set_scale_multiplier(self, multiplier: Union[float, None], per_attn_list: Union[list[PerAttn], None]=None): + for idx, block in enumerate(self.attention_blocks): + mult = multiplier + if per_attn_list is not None: + for per_attn in per_attn_list: + if per_attn.attn_idx == idx: + mult = per_attn.scale + break + block.set_scale_multiplier(mult) def set_sub_idxs(self, sub_idxs: list[int]): for block in self.attention_blocks: @@ -1064,7 +1137,7 @@ def forward( video_length: int=None, scale_mask: Tensor=None, cameractrl_effect: 
Union[float, Tensor] = None, - view_options: ContextOptions=None, + view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, ): # make view_options None if context_length > video_length, or if equal and equal not allowed From 2a520db00c9a7a933fed412b6eeaf9635dafb544 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 15 Aug 2024 23:01:56 -0500 Subject: [PATCH 02/11] Moved scale_mask and cameractrl_effect to TemporalTransformerBlock --- animatediff/motion_module_ad.py | 259 ++++++++++++++++---------------- 1 file changed, 132 insertions(+), 127 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index c5938a5..89f9a3d 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -848,17 +848,6 @@ def __init__( ops=comfy.ops.disable_weight_init, ): super().__init__() - self.video_length = 16 - self.full_length = 16 - self.raw_scale_mask: Union[Tensor, None] = None - self.temp_scale_mask: Union[Tensor, None] = None - self.sub_idxs: Union[list[int], None] = None - self.prev_hidden_states_batch = 0 - - # cameractrl stuff - self.raw_cameractrl_effect: Union[float, Tensor] = None - self.temp_cameractrl_effect: Union[float, Tensor] = None - self.prev_cameractrl_hidden_states_batch = 0 inner_dim = num_attention_heads * attention_head_dim @@ -891,16 +880,16 @@ def __init__( self.proj_out = ops.Linear(inner_dim, in_channels) def set_video_length(self, video_length: int, full_length: int): - self.video_length = video_length - self.full_length = full_length + for block in self.transformer_blocks: + block.set_video_length(video_length, full_length) def set_scale_multiplier(self, multiplier: Union[float, None]): for block in self.transformer_blocks: block.set_scale_multiplier(multiplier) def set_scale_mask(self, mask: Tensor): - self.raw_scale_mask = mask - self.temp_scale_mask = None + for block in self.transformer_blocks: + block.set_scale_mask(mask) def set_scale(self, scale: Union[float, 
Tensor, None]): if type(scale) == Tensor: @@ -911,118 +900,20 @@ def set_scale(self, scale: Union[float, Tensor, None]): self.set_scale_multiplier(scale) def set_cameractrl_effect(self, multival: Union[float, Tensor]): - self.raw_cameractrl_effect = multival - self.temp_cameractrl_effect = None + for block in self.transformer_blocks: + block.set_cameractrl_effect(multival) def set_sub_idxs(self, sub_idxs: list[int]): - self.sub_idxs = sub_idxs for block in self.transformer_blocks: block.set_sub_idxs(sub_idxs) def reset_temp_vars(self): - del self.temp_scale_mask - self.temp_scale_mask = None - self.prev_hidden_states_batch = 0 - del self.temp_cameractrl_effect - self.temp_cameractrl_effect = None - self.prev_cameractrl_hidden_states_batch = 0 for block in self.transformer_blocks: block.reset_temp_vars() - def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: - # if no raw mask, return None - if self.raw_scale_mask is None: - return None - shape = hidden_states.shape - batch, channel, height, width = shape - # if temp mask already calculated, return it - if self.temp_scale_mask != None: - # check if hidden_states batch matches - if batch == self.prev_hidden_states_batch: - if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask - # if does not match, reset cached temp_scale_mask and recalculate it - del self.temp_scale_mask - self.temp_scale_mask = None - # otherwise, calculate temp mask - self.prev_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) - mask = repeat_to_batch_size(mask, self.full_length) - # if mask not the same amount length as full length, make it match - if self.full_length != mask.shape[0]: - mask = broadcast_image_to(mask, self.full_length, 1) - # reshape mask to attention K shape (h*w, latent_count, 1) - batch, channel, height, width = mask.shape - # first, perform same operations as on hidden_states, - # 
turning (b, c, h, w) -> (b, h*w, c) - mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) - # then, make it the same shape as attention's k, (h*w, b, c) - mask = mask.permute(1, 0, 2) - # make masks match the expected length of h*w - batched_number = shape[0] // self.video_length - if batched_number > 1: - mask = torch.cat([mask] * batched_number, dim=0) - # cache mask and set to proper device - self.temp_scale_mask = mask - # move temp_scale_mask to proper dtype + device - self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) - # return subset of masks, if needed - if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask - - def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, None]: - # if no raw camera_Ctrl, return None - if self.raw_cameractrl_effect is None: - return 1.0 - # if raw_cameractrl is not a Tensor, return it (should be a float) - if type(self.raw_cameractrl_effect) != Tensor: - return self.raw_cameractrl_effect - shape = hidden_states.shape - batch, channel, height, width = shape - # if temp_cameractrl already calculated, return it - if self.temp_cameractrl_effect != None: - # check if hidden_states batch matches - if batch == self.prev_cameractrl_hidden_states_batch: - if self.sub_idxs is not None: - return self.temp_cameractrl_effect[:, self.sub_idxs, :] - return self.temp_cameractrl_effect - # if does not match, reset cached temp_cameractrl and recalculate it - del self.temp_cameractrl_effect - self.temp_cameractrl_effect = None - # otherwise, calculate temp_cameractrl - self.prev_cameractrl_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) - mask = repeat_to_batch_size(mask, self.full_length) - # if mask not the same amount length as full length, make it match - if self.full_length != mask.shape[0]: - mask = 
broadcast_image_to(mask, self.full_length, 1) - # reshape mask to attention K shape (h*w, latent_count, 1) - batch, channel, height, width = mask.shape - # first, perform same operations as on hidden_states, - # turning (b, c, h, w) -> (b, h*w, c) - mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) - # then, make it the same shape as attention's k, (h*w, b, c) - mask = mask.permute(1, 0, 2) - # make masks match the expected length of h*w - batched_number = shape[0] // self.video_length - if batched_number > 1: - mask = torch.cat([mask] * batched_number, dim=0) - # cache mask and set to proper device - self.temp_cameractrl_effect = mask - # move temp_cameractrl to proper dtype + device - self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device) - # return subset of masks, if needed - if self.sub_idxs is not None: - return self.temp_cameractrl_effect[:, self.sub_idxs, :] - return self.temp_cameractrl_effect - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None, mm_kwargs: dict[str]=None): batch, channel, height, width = hidden_states.shape residual = hidden_states - scale_mask = self.get_scale_mask(hidden_states) - cameractrl_effect = self.get_cameractrl_effect(hidden_states) # add some casts for fp8 purposes - does not affect speed otherwise hidden_states = self.norm(hidden_states).to(hidden_states.dtype) inner_dim = hidden_states.shape[1] @@ -1037,9 +928,6 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, - video_length=self.video_length, - scale_mask=scale_mask, - cameractrl_effect=cameractrl_effect, view_options=view_options, mm_kwargs=mm_kwargs ) @@ -1079,6 +967,17 @@ def __init__( ops=comfy.ops.disable_weight_init, ): super().__init__() + self.video_length = 16 + self.full_length = 16 + 
self.raw_scale_mask: Union[Tensor, None] = None + self.temp_scale_mask: Union[Tensor, None] = None + self.sub_idxs: Union[list[int], None] = None + self.prev_hidden_states_batch = 0 + + # cameractrl stuff + self.raw_cameractrl_effect: Union[float, Tensor] = None + self.temp_cameractrl_effect: Union[float, Tensor] = None + self.prev_cameractrl_hidden_states_batch = 0 attention_blocks: Iterable[VersatileAttention] = [] norms = [] @@ -1111,6 +1010,10 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) + def set_video_length(self, video_length: int, full_length: int): + self.video_length = video_length + self.full_length = full_length + def set_scale_multiplier(self, multiplier: Union[float, None], per_attn_list: Union[list[PerAttn], None]=None): for idx, block in enumerate(self.attention_blocks): mult = multiplier @@ -1121,30 +1024,132 @@ def set_scale_multiplier(self, multiplier: Union[float, None], per_attn_list: Un break block.set_scale_multiplier(mult) + def set_scale_mask(self, mask: Tensor): + self.raw_scale_mask = mask + self.temp_scale_mask = None + + def set_cameractrl_effect(self, multival: Union[float, Tensor]): + self.raw_cameractrl_effect = multival + self.temp_cameractrl_effect = None + def set_sub_idxs(self, sub_idxs: list[int]): for block in self.attention_blocks: block.set_sub_idxs(sub_idxs) def reset_temp_vars(self): + del self.temp_scale_mask + self.temp_scale_mask = None + self.prev_hidden_states_batch = 0 + del self.temp_cameractrl_effect + self.temp_cameractrl_effect = None + self.prev_cameractrl_hidden_states_batch = 0 for block in self.attention_blocks: block.reset_temp_vars() + + def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: + # if no raw mask, return None + if self.raw_scale_mask is None: + return None + shape = hidden_states.shape + batch, channel, height, width = shape + # if temp mask already calculated, return it + if 
self.temp_scale_mask != None: + # check if hidden_states batch matches + if batch == self.prev_hidden_states_batch: + if self.sub_idxs is not None: + return self.temp_scale_mask[:, self.sub_idxs, :] + return self.temp_scale_mask + # if does not match, reset cached temp_scale_mask and recalculate it + del self.temp_scale_mask + self.temp_scale_mask = None + # otherwise, calculate temp mask + self.prev_hidden_states_batch = batch + mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = repeat_to_batch_size(mask, self.full_length) + # if mask not the same amount length as full length, make it match + if self.full_length != mask.shape[0]: + mask = broadcast_image_to(mask, self.full_length, 1) + # reshape mask to attention K shape (h*w, latent_count, 1) + batch, channel, height, width = mask.shape + # first, perform same operations as on hidden_states, + # turning (b, c, h, w) -> (b, h*w, c) + mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) + # then, make it the same shape as attention's k, (h*w, b, c) + mask = mask.permute(1, 0, 2) + # make masks match the expected length of h*w + batched_number = shape[0] // self.video_length + if batched_number > 1: + mask = torch.cat([mask] * batched_number, dim=0) + # cache mask and set to proper device + self.temp_scale_mask = mask + # move temp_scale_mask to proper dtype + device + self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) + # return subset of masks, if needed + if self.sub_idxs is not None: + return self.temp_scale_mask[:, self.sub_idxs, :] + return self.temp_scale_mask + + def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, None]: + # if no raw camera_Ctrl, return None + if self.raw_cameractrl_effect is None: + return 1.0 + # if raw_cameractrl is not a Tensor, return it (should be a float) + if type(self.raw_cameractrl_effect) != Tensor: + return self.raw_cameractrl_effect + 
shape = hidden_states.shape + batch, channel, height, width = shape + # if temp_cameractrl already calculated, return it + if self.temp_cameractrl_effect != None: + # check if hidden_states batch matches + if batch == self.prev_cameractrl_hidden_states_batch: + if self.sub_idxs is not None: + return self.temp_cameractrl_effect[:, self.sub_idxs, :] + return self.temp_cameractrl_effect + # if does not match, reset cached temp_cameractrl and recalculate it + del self.temp_cameractrl_effect + self.temp_cameractrl_effect = None + # otherwise, calculate temp_cameractrl + self.prev_cameractrl_hidden_states_batch = batch + mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) + mask = repeat_to_batch_size(mask, self.full_length) + # if mask not the same amount length as full length, make it match + if self.full_length != mask.shape[0]: + mask = broadcast_image_to(mask, self.full_length, 1) + # reshape mask to attention K shape (h*w, latent_count, 1) + batch, channel, height, width = mask.shape + # first, perform same operations as on hidden_states, + # turning (b, c, h, w) -> (b, h*w, c) + mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) + # then, make it the same shape as attention's k, (h*w, b, c) + mask = mask.permute(1, 0, 2) + # make masks match the expected length of h*w + batched_number = shape[0] // self.video_length + if batched_number > 1: + mask = torch.cat([mask] * batched_number, dim=0) + # cache mask and set to proper device + self.temp_cameractrl_effect = mask + # move temp_cameractrl to proper dtype + device + self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device) + # return subset of masks, if needed + if self.sub_idxs is not None: + return self.temp_cameractrl_effect[:, self.sub_idxs, :] + return self.temp_cameractrl_effect def forward( self, hidden_states: Tensor, encoder_hidden_states: Tensor=None, attention_mask: Tensor=None, - 
video_length: int=None, - scale_mask: Tensor=None, - cameractrl_effect: Union[float, Tensor] = None, view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, ): + scale_mask = self.get_scale_mask(hidden_states) + cameractrl_effect = self.get_cameractrl_effect(hidden_states) # make view_options None if context_length > video_length, or if equal and equal not allowed if view_options: - if view_options.context_length > video_length: + if view_options.context_length > self.video_length: view_options = None - elif view_options.context_length == video_length and not view_options.use_on_equal_length: + elif view_options.context_length == self.video_length and not view_options.use_on_equal_length: view_options = None if not view_options: for attention_block, norm in zip(self.attention_blocks, self.norms): @@ -1156,7 +1161,7 @@ def forward( if attention_block.is_cross_attention else None, attention_mask=attention_mask, - video_length=video_length, + video_length=self.video_length, scale_mask=scale_mask, cameractrl_effect=cameractrl_effect, mm_kwargs=mm_kwargs @@ -1166,12 +1171,12 @@ def forward( # views idea gotten from diffusers AnimateDiff FreeNoise implementation: # https://github.com/arthur-qiu/FreeNoise-AnimateDiff/blob/main/animatediff/models/motion_module.py # apply sliding context windows (views) - views = get_context_windows(num_frames=video_length, opts=view_options) - hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length) + views = get_context_windows(num_frames=self.video_length, opts=view_options) + hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=self.video_length) value_final = torch.zeros_like(hidden_states) count_final = torch.zeros_like(hidden_states) # bias_final = [0.0] * video_length - batched_conds = hidden_states.size(1) // video_length + batched_conds = hidden_states.size(1) // self.video_length # store original camera_feature, if present has_camera_feature = False if mm_kwargs is not 
None: From f2376ebe9b6a570f3f4020ae20f5aaa39be35657 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 15 Aug 2024 23:44:12 -0500 Subject: [PATCH 03/11] Rollback moving of scale_mask and cameractrl_effect, refactor so that set_scale replaces set_scale_multiplier and set_scale_mask where applicable --- animatediff/motion_module_ad.py | 341 ++++++++++++++------------------ 1 file changed, 153 insertions(+), 188 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 89f9a3d..2379479 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -345,7 +345,7 @@ def get_best_beta_schedule(self, log=False) -> str: def cleanup(self): self._reset_sub_idxs() - self._reset_scale_multiplier() + self._reset_scale() self._reset_temp_vars() if self.img_encoder is not None: self.img_encoder.cleanup() @@ -465,15 +465,15 @@ def set_video_length(self, video_length: int, full_length: int): if self.mid_block is not None: self.mid_block.set_video_length(video_length, full_length) - def set_scale(self, multival: Union[float, Tensor]): - if multival is None: - multival = 1.0 - if type(multival) == Tensor: - self._set_scale_multiplier(1.0) - self._set_scale_mask(multival) - else: - self._set_scale_multiplier(multival) - self._set_scale_mask(None) + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + if self.down_blocks is not None: + for block in self.down_blocks: + block.set_scale(scale, per_block_list) + if self.up_blocks is not None: + for block in self.up_blocks: + block.set_scale(scale, per_block_list) + if self.mid_block is not None: + self.mid_block.set_scale(scale, per_block_list) def set_effect(self, multival: Union[float, Tensor]): # keep track of if model is in effect @@ -539,26 +539,6 @@ def set_camera_features(self, camera_features: list[Tensor]): if self.up_blocks is not None: for block in self.up_blocks: 
block.set_camera_features(camera_features=list(reversed(camera_features))) - - def _set_scale_multiplier(self, multiplier: Union[float, None]): - if self.down_blocks is not None: - for block in self.down_blocks: - block.set_scale_multiplier(multiplier) - if self.up_blocks is not None: - for block in self.up_blocks: - block.set_scale_multiplier(multiplier) - if self.mid_block is not None: - self.mid_block.set_scale_multiplier(multiplier) - - def _set_scale_mask(self, mask: Tensor): - if self.down_blocks is not None: - for block in self.down_blocks: - block.set_scale_mask(mask) - if self.up_blocks is not None: - for block in self.up_blocks: - block.set_scale_mask(mask) - if self.mid_block is not None: - self.mid_block.set_scale_mask(mask) def _reset_temp_vars(self): if self.down_blocks is not None: @@ -570,8 +550,8 @@ def _reset_temp_vars(self): if self.mid_block is not None: self.mid_block.reset_temp_vars() - def _reset_scale_multiplier(self): - self._set_scale_multiplier(None) + def _reset_scale(self): + self.set_scale(None) def _reset_sub_idxs(self): self.set_sub_idxs(None) @@ -606,14 +586,10 @@ def set_video_length(self, video_length: int, full_length: int): for motion_module in self.motion_modules: motion_module.set_video_length(video_length, full_length) - def set_scale_multiplier(self, multiplier: Union[float, None]): + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): for motion_module in self.motion_modules: - motion_module.set_scale_multiplier(multiplier) - - def set_scale_mask(self, mask: Tensor): - for motion_module in self.motion_modules: - motion_module.set_scale_mask(mask) - + motion_module.set_scale(scale, per_block_list) + def set_effect(self, multival: Union[float, Tensor]): for motion_module in self.motion_modules: motion_module.set_effect(multival) @@ -711,15 +687,9 @@ def set_video_length(self, video_length: int, full_length: int): self.video_length = video_length self.full_length = 
full_length self.temporal_transformer.set_video_length(video_length, full_length) - - def set_scale_multiplier(self, multiplier: Union[float, None]): - self.temporal_transformer.set_scale_multiplier(multiplier) - - def set_scale_mask(self, mask: Tensor): - self.temporal_transformer.set_scale_mask(mask) - def set_scale(self, multiplier: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): - self.temporal_transformer.set_scale(multiplier) + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + self.temporal_transformer.set_scale(scale) def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[PerBlock], None]=None): if per_block_list is not None: @@ -848,6 +818,17 @@ def __init__( ops=comfy.ops.disable_weight_init, ): super().__init__() + self.video_length = 16 + self.full_length = 16 + self.raw_scale_mask: Union[Tensor, None] = None + self.temp_scale_mask: Union[Tensor, None] = None + self.sub_idxs: Union[list[int], None] = None + self.prev_hidden_states_batch = 0 + + # cameractrl stuff + self.raw_cameractrl_effect: Union[float, Tensor] = None + self.temp_cameractrl_effect: Union[float, Tensor] = None + self.prev_cameractrl_hidden_states_batch = 0 inner_dim = num_attention_heads * attention_head_dim @@ -880,18 +861,24 @@ def __init__( self.proj_out = ops.Linear(inner_dim, in_channels) def set_video_length(self, video_length: int, full_length: int): - for block in self.transformer_blocks: - block.set_video_length(video_length, full_length) - + self.video_length = video_length + self.full_length = full_length + def set_scale_multiplier(self, multiplier: Union[float, None]): for block in self.transformer_blocks: - block.set_scale_multiplier(multiplier) + mult = multiplier + block.set_scale_multiplier(mult) def set_scale_mask(self, mask: Tensor): - for block in self.transformer_blocks: - block.set_scale_mask(mask) + self.raw_scale_mask = mask + self.temp_scale_mask = 
None - def set_scale(self, scale: Union[float, Tensor, None]): + def set_scale(self, scale: Union[float, Tensor, None], per_attn_list: Union[list[PerAttn], None]=None): + # if per_attn_list is not None: + # for per_attn in per_attn_list: + # if per_attn.attn_idx == idx: + # mult = per_attn.scale + # break if type(scale) == Tensor: self.set_scale_mask(scale) self.set_scale_multiplier(None) @@ -900,20 +887,117 @@ def set_scale(self, scale: Union[float, Tensor, None]): self.set_scale_multiplier(scale) def set_cameractrl_effect(self, multival: Union[float, Tensor]): - for block in self.transformer_blocks: - block.set_cameractrl_effect(multival) + self.raw_cameractrl_effect = multival + self.temp_cameractrl_effect = None def set_sub_idxs(self, sub_idxs: list[int]): for block in self.transformer_blocks: block.set_sub_idxs(sub_idxs) def reset_temp_vars(self): + del self.temp_scale_mask + self.temp_scale_mask = None + self.prev_hidden_states_batch = 0 + del self.temp_cameractrl_effect + self.temp_cameractrl_effect = None + self.prev_cameractrl_hidden_states_batch = 0 for block in self.transformer_blocks: block.reset_temp_vars() + def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: + # if no raw mask, return None + if self.raw_scale_mask is None: + return None + shape = hidden_states.shape + batch, channel, height, width = shape + # if temp mask already calculated, return it + if self.temp_scale_mask != None: + # check if hidden_states batch matches + if batch == self.prev_hidden_states_batch: + if self.sub_idxs is not None: + return self.temp_scale_mask[:, self.sub_idxs, :] + return self.temp_scale_mask + # if does not match, reset cached temp_scale_mask and recalculate it + del self.temp_scale_mask + self.temp_scale_mask = None + # otherwise, calculate temp mask + self.prev_hidden_states_batch = batch + mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = repeat_to_batch_size(mask, self.full_length) + # if 
mask not the same amount length as full length, make it match + if self.full_length != mask.shape[0]: + mask = broadcast_image_to(mask, self.full_length, 1) + # reshape mask to attention K shape (h*w, latent_count, 1) + batch, channel, height, width = mask.shape + # first, perform same operations as on hidden_states, + # turning (b, c, h, w) -> (b, h*w, c) + mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) + # then, make it the same shape as attention's k, (h*w, b, c) + mask = mask.permute(1, 0, 2) + # make masks match the expected length of h*w + batched_number = shape[0] // self.video_length + if batched_number > 1: + mask = torch.cat([mask] * batched_number, dim=0) + # cache mask and set to proper device + self.temp_scale_mask = mask + # move temp_scale_mask to proper dtype + device + self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) + # return subset of masks, if needed + if self.sub_idxs is not None: + return self.temp_scale_mask[:, self.sub_idxs, :] + return self.temp_scale_mask + + def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, None]: + # if no raw camera_Ctrl, return None + if self.raw_cameractrl_effect is None: + return 1.0 + # if raw_cameractrl is not a Tensor, return it (should be a float) + if type(self.raw_cameractrl_effect) != Tensor: + return self.raw_cameractrl_effect + shape = hidden_states.shape + batch, channel, height, width = shape + # if temp_cameractrl already calculated, return it + if self.temp_cameractrl_effect != None: + # check if hidden_states batch matches + if batch == self.prev_cameractrl_hidden_states_batch: + if self.sub_idxs is not None: + return self.temp_cameractrl_effect[:, self.sub_idxs, :] + return self.temp_cameractrl_effect + # if does not match, reset cached temp_cameractrl and recalculate it + del self.temp_cameractrl_effect + self.temp_cameractrl_effect = None + # otherwise, calculate temp_cameractrl + 
self.prev_cameractrl_hidden_states_batch = batch + mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) + mask = repeat_to_batch_size(mask, self.full_length) + # if mask not the same amount length as full length, make it match + if self.full_length != mask.shape[0]: + mask = broadcast_image_to(mask, self.full_length, 1) + # reshape mask to attention K shape (h*w, latent_count, 1) + batch, channel, height, width = mask.shape + # first, perform same operations as on hidden_states, + # turning (b, c, h, w) -> (b, h*w, c) + mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) + # then, make it the same shape as attention's k, (h*w, b, c) + mask = mask.permute(1, 0, 2) + # make masks match the expected length of h*w + batched_number = shape[0] // self.video_length + if batched_number > 1: + mask = torch.cat([mask] * batched_number, dim=0) + # cache mask and set to proper device + self.temp_cameractrl_effect = mask + # move temp_cameractrl to proper dtype + device + self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device) + # return subset of masks, if needed + if self.sub_idxs is not None: + return self.temp_cameractrl_effect[:, self.sub_idxs, :] + return self.temp_cameractrl_effect + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None, mm_kwargs: dict[str]=None): batch, channel, height, width = hidden_states.shape residual = hidden_states + scale_mask = self.get_scale_mask(hidden_states) + cameractrl_effect = self.get_cameractrl_effect(hidden_states) # add some casts for fp8 purposes - does not affect speed otherwise hidden_states = self.norm(hidden_states).to(hidden_states.dtype) inner_dim = hidden_states.shape[1] @@ -928,6 +1012,9 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None hidden_states, encoder_hidden_states=encoder_hidden_states, 
attention_mask=attention_mask, + video_length=self.video_length, + scale_mask=scale_mask, + cameractrl_effect=cameractrl_effect, view_options=view_options, mm_kwargs=mm_kwargs ) @@ -967,17 +1054,6 @@ def __init__( ops=comfy.ops.disable_weight_init, ): super().__init__() - self.video_length = 16 - self.full_length = 16 - self.raw_scale_mask: Union[Tensor, None] = None - self.temp_scale_mask: Union[Tensor, None] = None - self.sub_idxs: Union[list[int], None] = None - self.prev_hidden_states_batch = 0 - - # cameractrl stuff - self.raw_cameractrl_effect: Union[float, Tensor] = None - self.temp_cameractrl_effect: Union[float, Tensor] = None - self.prev_cameractrl_hidden_states_batch = 0 attention_blocks: Iterable[VersatileAttention] = [] norms = [] @@ -1010,146 +1086,35 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) - def set_video_length(self, video_length: int, full_length: int): - self.video_length = video_length - self.full_length = full_length - - def set_scale_multiplier(self, multiplier: Union[float, None], per_attn_list: Union[list[PerAttn], None]=None): + def set_scale_multiplier(self, multiplier: Union[float, None]): for idx, block in enumerate(self.attention_blocks): mult = multiplier - if per_attn_list is not None: - for per_attn in per_attn_list: - if per_attn.attn_idx == idx: - mult = per_attn.scale - break block.set_scale_multiplier(mult) - def set_scale_mask(self, mask: Tensor): - self.raw_scale_mask = mask - self.temp_scale_mask = None - - def set_cameractrl_effect(self, multival: Union[float, Tensor]): - self.raw_cameractrl_effect = multival - self.temp_cameractrl_effect = None - def set_sub_idxs(self, sub_idxs: list[int]): for block in self.attention_blocks: block.set_sub_idxs(sub_idxs) def reset_temp_vars(self): - del self.temp_scale_mask - self.temp_scale_mask = None - self.prev_hidden_states_batch = 0 - del self.temp_cameractrl_effect - 
self.temp_cameractrl_effect = None - self.prev_cameractrl_hidden_states_batch = 0 for block in self.attention_blocks: block.reset_temp_vars() - - def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: - # if no raw mask, return None - if self.raw_scale_mask is None: - return None - shape = hidden_states.shape - batch, channel, height, width = shape - # if temp mask already calculated, return it - if self.temp_scale_mask != None: - # check if hidden_states batch matches - if batch == self.prev_hidden_states_batch: - if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask - # if does not match, reset cached temp_scale_mask and recalculate it - del self.temp_scale_mask - self.temp_scale_mask = None - # otherwise, calculate temp mask - self.prev_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) - mask = repeat_to_batch_size(mask, self.full_length) - # if mask not the same amount length as full length, make it match - if self.full_length != mask.shape[0]: - mask = broadcast_image_to(mask, self.full_length, 1) - # reshape mask to attention K shape (h*w, latent_count, 1) - batch, channel, height, width = mask.shape - # first, perform same operations as on hidden_states, - # turning (b, c, h, w) -> (b, h*w, c) - mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) - # then, make it the same shape as attention's k, (h*w, b, c) - mask = mask.permute(1, 0, 2) - # make masks match the expected length of h*w - batched_number = shape[0] // self.video_length - if batched_number > 1: - mask = torch.cat([mask] * batched_number, dim=0) - # cache mask and set to proper device - self.temp_scale_mask = mask - # move temp_scale_mask to proper dtype + device - self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) - # return subset of masks, if needed - if self.sub_idxs is not None: 
- return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask - - def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, None]: - # if no raw camera_Ctrl, return None - if self.raw_cameractrl_effect is None: - return 1.0 - # if raw_cameractrl is not a Tensor, return it (should be a float) - if type(self.raw_cameractrl_effect) != Tensor: - return self.raw_cameractrl_effect - shape = hidden_states.shape - batch, channel, height, width = shape - # if temp_cameractrl already calculated, return it - if self.temp_cameractrl_effect != None: - # check if hidden_states batch matches - if batch == self.prev_cameractrl_hidden_states_batch: - if self.sub_idxs is not None: - return self.temp_cameractrl_effect[:, self.sub_idxs, :] - return self.temp_cameractrl_effect - # if does not match, reset cached temp_cameractrl and recalculate it - del self.temp_cameractrl_effect - self.temp_cameractrl_effect = None - # otherwise, calculate temp_cameractrl - self.prev_cameractrl_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) - mask = repeat_to_batch_size(mask, self.full_length) - # if mask not the same amount length as full length, make it match - if self.full_length != mask.shape[0]: - mask = broadcast_image_to(mask, self.full_length, 1) - # reshape mask to attention K shape (h*w, latent_count, 1) - batch, channel, height, width = mask.shape - # first, perform same operations as on hidden_states, - # turning (b, c, h, w) -> (b, h*w, c) - mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) - # then, make it the same shape as attention's k, (h*w, b, c) - mask = mask.permute(1, 0, 2) - # make masks match the expected length of h*w - batched_number = shape[0] // self.video_length - if batched_number > 1: - mask = torch.cat([mask] * batched_number, dim=0) - # cache mask and set to proper device - self.temp_cameractrl_effect = mask - # move 
temp_cameractrl to proper dtype + device - self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device) - # return subset of masks, if needed - if self.sub_idxs is not None: - return self.temp_cameractrl_effect[:, self.sub_idxs, :] - return self.temp_cameractrl_effect def forward( self, hidden_states: Tensor, encoder_hidden_states: Tensor=None, attention_mask: Tensor=None, + video_length: int=None, + scale_mask: Tensor=None, + cameractrl_effect: Union[float, Tensor] = None, view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, ): - scale_mask = self.get_scale_mask(hidden_states) - cameractrl_effect = self.get_cameractrl_effect(hidden_states) # make view_options None if context_length > video_length, or if equal and equal not allowed if view_options: - if view_options.context_length > self.video_length: + if view_options.context_length > video_length: view_options = None - elif view_options.context_length == self.video_length and not view_options.use_on_equal_length: + elif view_options.context_length == video_length and not view_options.use_on_equal_length: view_options = None if not view_options: for attention_block, norm in zip(self.attention_blocks, self.norms): @@ -1161,7 +1126,7 @@ def forward( if attention_block.is_cross_attention else None, attention_mask=attention_mask, - video_length=self.video_length, + video_length=video_length, scale_mask=scale_mask, cameractrl_effect=cameractrl_effect, mm_kwargs=mm_kwargs @@ -1171,12 +1136,12 @@ def forward( # views idea gotten from diffusers AnimateDiff FreeNoise implementation: # https://github.com/arthur-qiu/FreeNoise-AnimateDiff/blob/main/animatediff/models/motion_module.py # apply sliding context windows (views) - views = get_context_windows(num_frames=self.video_length, opts=view_options) - hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=self.video_length) + views = get_context_windows(num_frames=video_length, 
opts=view_options) + hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length) value_final = torch.zeros_like(hidden_states) count_final = torch.zeros_like(hidden_states) # bias_final = [0.0] * video_length - batched_conds = hidden_states.size(1) // self.video_length + batched_conds = hidden_states.size(1) // video_length # store original camera_feature, if present has_camera_feature = False if mm_kwargs is not None: From 1a87cefb18dda6cbaeed4a37903acbd9c16206d8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 03:40:33 -0500 Subject: [PATCH 04/11] Initial Per-Block implementation - only effect works currently, still need to implement scale, fixed some Apply node sizes due to .js changes with documentation --- animatediff/model_injection.py | 23 +- animatediff/motion_module_ad.py | 33 ++- animatediff/nodes.py | 18 ++ animatediff/nodes_animatelcmi2v.py | 8 +- animatediff/nodes_gen1.py | 10 +- animatediff/nodes_gen2.py | 19 +- animatediff/nodes_per_block.py | 394 +++++++++++++++++++++++++++++ animatediff/nodes_pia.py | 8 +- 8 files changed, 483 insertions(+), 30 deletions(-) create mode 100644 animatediff/nodes_per_block.py diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index aa048b2..df5047c 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -19,7 +19,7 @@ from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding from .context import ContextOptions, ContextOptions, ContextOptionsGroup -from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, +from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) from .logger import logger from .utils_motion import 
(ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, @@ -752,8 +752,9 @@ def __init__(self, *args, **kwargs): self.timestep_range: tuple[float, float] = None self.keyframes: ADKeyframeGroup = ADKeyframeGroup() - self.scale_multival = None - self.effect_multival = None + self.scale_multival: Union[float, Tensor, None] = None + self.effect_multival: Union[float, Tensor, None] = None + self.per_block_list: Union[list[PerBlock], None] = None # AnimateLCM-I2V self.orig_ref_drift: float = None @@ -789,6 +790,7 @@ def __init__(self, *args, **kwargs): self.current_pia_input: InputPIA = None self.combined_scale: Union[float, Tensor] = None self.combined_effect: Union[float, Tensor] = None + self.combined_per_block_list: Union[float, Tensor] = None self.combined_cameractrl_effect: Union[float, Tensor] = None self.combined_pia_mask: Union[float, Tensor] = None self.combined_pia_effect: Union[float, Tensor] = None @@ -820,7 +822,7 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc def pre_run(self, model: ModelPatcherAndInjector): self.cleanup() self.model.set_scale(self.scale_multival) - self.model.set_effect(self.effect_multival) + self.model.set_effect(self.effect_multival, self.per_block_list) self.model.set_cameractrl_effect(self.cameractrl_multival) if self.model.img_encoder is not None: self.model.img_encoder.set_ref_drift(self.orig_ref_drift) @@ -884,7 +886,7 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input) # apply scale and effect self.model.set_scale(self.combined_scale) - self.model.set_effect(self.combined_effect) + self.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list self.model.set_cameractrl_effect(self.combined_cameractrl_effect) # apply effect - if not within range, set effect to 0, effectively turning model off if curr_t > self.timestep_range[0] or curr_t 
< self.timestep_range[1]: @@ -893,7 +895,7 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): else: # if was not in range last step, apply effect to toggle AD status if not self.was_within_range: - self.model.set_effect(self.combined_effect) + self.model.set_effect(self.combined_effect, self.per_block_list) self.was_within_range = True # update steps current keyframe is used self.current_used_steps += 1 @@ -1048,6 +1050,7 @@ def cleanup(self): self.current_effect = None self.combined_scale = None self.combined_effect = None + self.combined_per_block_list = None self.was_within_range = False self.prev_sub_idxs = None self.prev_batched_number = None @@ -1375,6 +1378,14 @@ def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionM + f"but the provided model is type {model_sd_type}.") +def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_blocks: AllPerBlocks): + if all_per_blocks is None: + return + mm_info = motion_model.model.mm_info + if all_per_blocks.sd_type != mm_info.sd_type: + raise Exception(f"Per-Block provided is meant for {all_per_blocks.sd_type}, but provided motion module is for {mm_info.sd_type}.") + + def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int): pe_shape = model_dict[key].shape temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 2379479..5754408 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1,6 +1,7 @@ import math from typing import Iterable, Tuple, Union, TYPE_CHECKING import re +from dataclasses import dataclass import torch from einops import rearrange, repeat @@ -20,7 +21,8 @@ from .adapter_animatelcm_i2v import AdapterEmbed if TYPE_CHECKING: # avoids circular import from .adapter_cameractrl import CameraPoseEncoder -from .utils_motion import CrossAttentionMM, MotionCompatibilityError, DummyNNModule, 
extend_to_batch_size, prepare_mask_batch +from .utils_motion import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, prepare_mask_batch, + get_combined_multival) from .utils_model import BetaSchedules, ModelTypeSD from .logger import logger @@ -74,7 +76,7 @@ def matches(self, id: int): class PerBlockId: - def __init__(self, block_type: str, block_idx: Union[int, None], module_idx: Union[int, None]): + def __init__(self, block_type: str, block_idx: Union[int, None]=None, module_idx: Union[int, None]=None): self.block_type = block_type self.block_idx = block_idx self.module_idx = module_idx @@ -99,13 +101,19 @@ def __str__(self): class PerBlock: def __init__(self, id: PerBlockId, effect: Union[float, Tensor, None]=None, - per_attn_scale: Union[list[PerAttn], None]=None): + scales: Union[list[Union[float, Tensor, None]], None]=None): self.id = id self.effect = effect - self.per_attn_scale = per_attn_scale + self.scales = scales def matches(self, id: PerBlockId): return self.id.matches(id) + + +@dataclass +class AllPerBlocks: + per_block_list: list[PerBlock] + sd_type: Union[str, None] = None #---------------------- ####################### @@ -282,6 +290,7 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops) self.AD_video_length: int = 24 self.effect_model = 1.0 + self.effect_per_block_list = None # AnimateLCM-I2V stuff - create AdapterEmbed if keys present for it self.img_encoder: AdapterEmbed = None if has_img_encoder(mm_state_dict): @@ -475,21 +484,22 @@ def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[lis if self.mid_block is not None: self.mid_block.set_scale(scale, per_block_list) - def set_effect(self, multival: Union[float, Tensor]): + def set_effect(self, multival: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): # keep track of if model is in effect if multival is 
None: self.effect_model = 1.0 else: self.effect_model = multival + self.effect_per_block_list = per_block_list # pass down effect multival to all blocks if self.down_blocks is not None: for block in self.down_blocks: - block.set_effect(multival) + block.set_effect(multival, per_block_list) if self.up_blocks is not None: for block in self.up_blocks: - block.set_effect(multival) + block.set_effect(multival, per_block_list) if self.mid_block is not None: - self.mid_block.set_effect(multival) + self.mid_block.set_effect(multival, per_block_list) def is_in_effect(self): if type(self.effect_model) == Tensor: @@ -590,9 +600,9 @@ def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[lis for motion_module in self.motion_modules: motion_module.set_scale(scale, per_block_list) - def set_effect(self, multival: Union[float, Tensor]): + def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[PerBlock], None]=None): for motion_module in self.motion_modules: - motion_module.set_effect(multival) + motion_module.set_effect(multival, per_block_list) def set_cameractrl_effect(self, multival: Union[float, Tensor]): for motion_module in self.motion_modules: @@ -695,7 +705,8 @@ def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[ if per_block_list is not None: for per_block in per_block_list: if self.id.matches(per_block.id) and per_block.effect is not None: - multival = per_block.effect + multival = get_combined_multival(multival, per_block.effect) + logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") break if type(multival) == Tensor: self.effect = multival diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 43e718b..9463e23 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -39,6 +39,8 @@ WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) from .nodes_scheduling import (PromptSchedulingNode, PromptSchedulingLatentsNode, 
ValueSchedulingNode, ValueSchedulingLatentsNode, AddValuesReplaceNode, FloatToFloatsNode) +from .nodes_per_block import (ADBlockComboNode, ADBlockIndivNode, PerBlockHighLevelNode, + PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SDXL_LowLevelNode, PerBlock_SDXL_MidLevelNode) from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival from .nodes_deprecated import (AnimateDiffLoader_Deprecated, AnimateDiffLoaderAdvanced_Deprecated, AnimateDiffCombine_Deprecated, AnimateDiffModelSettings, AnimateDiffModelSettingsSimple, AnimateDiffModelSettingsAdvanced, AnimateDiffModelSettingsAdvancedAttnStrengths) @@ -162,6 +164,14 @@ ValueSchedulingLatentsNode.NodeID: ValueSchedulingLatentsNode, AddValuesReplaceNode.NodeID: AddValuesReplaceNode, FloatToFloatsNode.NodeID: FloatToFloatsNode, + # Per-Block + ADBlockComboNode.NodeID: ADBlockComboNode, + ADBlockIndivNode.NodeID: ADBlockIndivNode, + PerBlockHighLevelNode.NodeID: PerBlockHighLevelNode, + PerBlock_SD15_MidLevelNode.NodeID: PerBlock_SD15_MidLevelNode, + PerBlock_SD15_LowLevelNode.NodeID: PerBlock_SD15_LowLevelNode, + PerBlock_SDXL_MidLevelNode.NodeID: PerBlock_SDXL_MidLevelNode, + PerBlock_SDXL_LowLevelNode.NodeID: PerBlock_SDXL_LowLevelNode, # Extras Nodes "ADE_AnimateDiffUnload": AnimateDiffUnload, "ADE_EmptyLatentImageLarge": EmptyLatentImageLarge, @@ -319,6 +329,14 @@ ValueSchedulingLatentsNode.NodeID: ValueSchedulingLatentsNode.NodeName, AddValuesReplaceNode.NodeID: AddValuesReplaceNode.NodeName, FloatToFloatsNode.NodeID:FloatToFloatsNode.NodeName, + # Per-Block + ADBlockComboNode.NodeID: ADBlockComboNode.NodeName, + ADBlockIndivNode.NodeID: ADBlockIndivNode.NodeName, + PerBlockHighLevelNode.NodeID: PerBlockHighLevelNode.NodeName, + PerBlock_SD15_MidLevelNode.NodeID: PerBlock_SD15_MidLevelNode.NodeName, + PerBlock_SD15_LowLevelNode.NodeID: PerBlock_SD15_LowLevelNode.NodeName, + 
PerBlock_SDXL_MidLevelNode.NodeID: PerBlock_SDXL_MidLevelNode.NodeName, + PerBlock_SDXL_LowLevelNode.NodeID: PerBlock_SDXL_LowLevelNode.NodeName, # Extras Nodes "ADE_AnimateDiffUnload": "AnimateDiff Unload 🎭🅐🅓", "ADE_EmptyLatentImageLarge": "Empty Latent Image (Big Batch) 🎭🅐🅓", diff --git a/animatediff/nodes_animatelcmi2v.py b/animatediff/nodes_animatelcmi2v.py index 370f5f0..9d63e12 100644 --- a/animatediff/nodes_animatelcmi2v.py +++ b/animatediff/nodes_animatelcmi2v.py @@ -34,7 +34,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -44,11 +45,12 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, ref_latent: dict, ref_drift: float=0.0, apply_ref_when_disabled=False, start_percent: float=0.0, end_percent: float=1.0, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, + scale_multival=None, effect_multival=None, per_block=None, prev_m_models: MotionModelGroup=None,): new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, motion_lora=motion_lora, ad_keyframes=ad_keyframes, - scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) + scale_multival=scale_multival, effect_multival=effect_multival, per_block=per_block, + prev_m_models=prev_m_models) # most recent added model will always be first in list; curr_model = new_m_models[0].models[0] # confirm that model contains img_encoder diff --git a/animatediff/nodes_gen1.py b/animatediff/nodes_gen1.py index ad93440..574c735 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -10,7 +10,10 @@ from .utils_model import BetaSchedules, get_available_motion_loras, get_available_motion_models, 
get_motion_lora_path from .utils_motion import ADKeyframeGroup, get_combined_multival from .motion_lora import MotionLoraInfo, MotionLoraList -from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2 +from .motion_module_ad import AllPerBlocks +from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, + load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2, + validate_per_block_compatibility) from .sample_settings import SampleSettings, SeedNoiseGeneration from .sampling import motion_sample_factory @@ -33,6 +36,7 @@ def INPUT_TYPES(s): "sample_settings": ("SAMPLE_SETTINGS",), "scale_multival": ("MULTIVAL",), "effect_multival": ("MULTIVAL",), + "per_block": ("PER_BLOCK",), } } @@ -45,6 +49,7 @@ def load_mm_and_inject_params(self, model_name: str, beta_schedule: str,# apply_mm_groupnorm_hack: bool, context_options: ContextOptionsGroup=None, motion_lora: MotionLoraList=None, ad_settings: AnimateDiffSettings=None, sample_settings: SampleSettings=None, scale_multival=None, effect_multival=None, ad_keyframes: ADKeyframeGroup=None, + per_block: AllPerBlocks=None, ): # load motion module and motion settings, if included motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings) @@ -56,6 +61,9 @@ def load_mm_and_inject_params(self, load_motion_lora_as_patches(motion_model, lora) motion_model.scale_multival = scale_multival motion_model.effect_multival = effect_multival + if per_block is not None: + validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) + motion_model.per_block_list = per_block.per_block_list motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() # create injection params diff --git a/animatediff/nodes_gen2.py 
b/animatediff/nodes_gen2.py index 8302df2..979c10f 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -9,8 +9,9 @@ from .utils_model import BIGMAX, BetaSchedules, get_available_motion_models from .utils_motion import ADKeyframeGroup, ADKeyframe, InputPIA from .motion_lora import MotionLoraList +from .motion_module_ad import AllPerBlocks from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, - load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2) + load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings @@ -94,7 +95,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), - "autosize": ("ADEAUTOSIZE", {"padding": 80}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -104,7 +106,7 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: float=0.0, end_percent: float=1.0, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, + scale_multival=None, effect_multival=None, per_block: AllPerBlocks=None, prev_m_models: MotionModelGroup=None,): # set up motion models list if prev_m_models is None: @@ -122,6 +124,9 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: fl load_motion_lora_as_patches(motion_model, lora) motion_model.scale_multival = scale_multival motion_model.effect_multival = effect_multival + if per_block is not None: + validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) + motion_model.per_block_list = per_block.per_block_list motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() 
motion_model.timestep_percent_range = (start_percent, end_percent) # add to beginning, so that after injection, it will be the earliest of prev_m_models to be run @@ -141,7 +146,8 @@ def INPUT_TYPES(s): "scale_multival": ("MULTIVAL",), "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), - "autosize": ("ADEAUTOSIZE", {"padding": 40}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -151,11 +157,12 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, motion_lora: MotionLoraList=None, - scale_multival=None, effect_multival=None, ad_keyframes=None): + scale_multival=None, effect_multival=None, ad_keyframes=None, + per_block: AllPerBlocks=None): # just a subset of normal ApplyAnimateDiffModelNode inputs return ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, motion_lora=motion_lora, scale_multival=scale_multival, effect_multival=effect_multival, - ad_keyframes=ad_keyframes) + ad_keyframes=ad_keyframes, per_block=per_block) class LoadAnimateDiffModelNode: diff --git a/animatediff/nodes_per_block.py b/animatediff/nodes_per_block.py new file mode 100644 index 0000000..5090583 --- /dev/null +++ b/animatediff/nodes_per_block.py @@ -0,0 +1,394 @@ +from typing import Union +from torch import Tensor + +from .motion_module_ad import PerBlock, PerBlockId, BlockType, AllPerBlocks +from .utils_model import ModelTypeSD + + +class ADBlockHolder: + def __init__(self, effect: Union[float, Tensor, None]=None, + scales: Union[list[float, Tensor], None]=list()): + self.effect = effect + self.scales = scales + + def has_effect(self): + return self.effect is not None + + def has_scale(self): + for scale in self.scales: + if scale is not None: + return True + return False + + def is_empty(self): + has_anything = self.has_effect() or self.has_scale() + return not has_anything + + +class ADBlockComboNode: + NodeID = 'ADE_ADBlockCombo' + NodeName = 'AD Block 🎭🅐🅓' + @classmethod + def 
INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect": ("MULTIVAL",), + "scale": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("AD_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "block_control" + + def block_control(self, effect: Union[float, Tensor, None]=None, scale: Union[float, Tensor, None]=None): + scales = [scale, scale] + block = ADBlockHolder(effect=effect, scales=scales) + if block.is_empty(): + block = None + return (block,) + + +class ADBlockIndivNode: + NodeID = 'ADE_ADBlockIndiv' + NodeName = 'AD Block+ 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect": ("MULTIVAL",), + "scale_0": ("MULTIVAL",), + "scale_1": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("AD_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "block_control" + + def block_control(self, effect: Union[float, Tensor, None]=None, + scale_0: Union[float, Tensor, None]=None, scale_1: Union[float, Tensor, None]=None): + scales = [scale_0, scale_1] + block = ADBlockHolder(effect=effect, scales=scales) + if block.is_empty(): + block = None + return (block,) + + +class PerBlockHighLevelNode: + NodeID = 'ADE_PerBlockHighLevel' + NodeName = 'AD Per Block 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN): down, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP): up, + } + for id, block in d.items(): + 
if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks),) + + +class PerBlock_SD15_MidLevelNode: + NodeID = 'ADE_PerBlock_SD15_MidLevel' + NodeName = 'AD Per Block+ (SD1.5) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0": ("AD_BLOCK",), + "down_1": ("AD_BLOCK",), + "down_2": ("AD_BLOCK",), + "down_3": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0": ("AD_BLOCK",), + "up_1": ("AD_BLOCK",), + "up_2": ("AD_BLOCK",), + "up_3": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0: Union[ADBlockHolder, None]=None, + down_1: Union[ADBlockHolder, None]=None, + down_2: Union[ADBlockHolder, None]=None, + down_3: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0: Union[ADBlockHolder, None]=None, + up_1: Union[ADBlockHolder, None]=None, + up_2: Union[ADBlockHolder, None]=None, + up_3: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0): down_0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1): down_1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2): down_2, + PerBlockId(block_type=BlockType.DOWN, block_idx=3): down_3, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0): up_0, + PerBlockId(block_type=BlockType.UP, block_idx=1): up_1, + PerBlockId(block_type=BlockType.UP, block_idx=2): up_2, + PerBlockId(block_type=BlockType.UP, block_idx=3): up_3, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) + + +class PerBlock_SD15_LowLevelNode: + NodeID = 
'ADE_PerBlock_SD15_LowLevel' + NodeName = 'AD Per Block++ (SD1.5) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0__0": ("AD_BLOCK",), + "down_0__1": ("AD_BLOCK",), + "down_1__0": ("AD_BLOCK",), + "down_1__1": ("AD_BLOCK",), + "down_2__0": ("AD_BLOCK",), + "down_2__1": ("AD_BLOCK",), + "down_3__0": ("AD_BLOCK",), + "down_3__1": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0__0": ("AD_BLOCK",), + "up_0__1": ("AD_BLOCK",), + "up_0__2": ("AD_BLOCK",), + "up_1__0": ("AD_BLOCK",), + "up_1__1": ("AD_BLOCK",), + "up_1__2": ("AD_BLOCK",), + "up_2__0": ("AD_BLOCK",), + "up_2__1": ("AD_BLOCK",), + "up_2__2": ("AD_BLOCK",), + "up_3__0": ("AD_BLOCK",), + "up_3__1": ("AD_BLOCK",), + "up_3__2": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0__0: Union[ADBlockHolder, None]=None, + down_0__1: Union[ADBlockHolder, None]=None, + down_1__0: Union[ADBlockHolder, None]=None, + down_1__1: Union[ADBlockHolder, None]=None, + down_2__0: Union[ADBlockHolder, None]=None, + down_2__1: Union[ADBlockHolder, None]=None, + down_3__0: Union[ADBlockHolder, None]=None, + down_3__1: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0__0: Union[ADBlockHolder, None]=None, + up_0__1: Union[ADBlockHolder, None]=None, + up_0__2: Union[ADBlockHolder, None]=None, + up_1__0: Union[ADBlockHolder, None]=None, + up_1__1: Union[ADBlockHolder, None]=None, + up_1__2: Union[ADBlockHolder, None]=None, + up_2__0: Union[ADBlockHolder, None]=None, + up_2__1: Union[ADBlockHolder, None]=None, + up_2__2: Union[ADBlockHolder, None]=None, + up_3__0: Union[ADBlockHolder, None]=None, + up_3__1: Union[ADBlockHolder, None]=None, + up_3__2: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=0): down_0__0, + 
PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=1): down_0__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=0): down_1__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=1): down_1__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=0): down_2__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=1): down_2__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=3, module_idx=0): down_3__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=3, module_idx=1): down_3__1, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=0): up_0__0, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=1): up_0__1, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=2): up_0__2, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=0): up_1__0, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=1): up_1__1, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=2): up_1__2, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=0): up_2__0, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=1): up_2__1, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=2): up_2__2, + PerBlockId(block_type=BlockType.UP, block_idx=3, module_idx=0): up_3__0, + PerBlockId(block_type=BlockType.UP, block_idx=3, module_idx=1): up_3__1, + PerBlockId(block_type=BlockType.UP, block_idx=3, module_idx=2): up_3__2, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) + + +class PerBlock_SDXL_MidLevelNode: + NodeID = 'ADE_PerBlock_SDXL_MidLevel' + NodeName = 'AD Per Block+ (SDXL) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0": ("AD_BLOCK",), + "down_1": ("AD_BLOCK",), + "down_2": 
("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0": ("AD_BLOCK",), + "up_1": ("AD_BLOCK",), + "up_2": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0: Union[ADBlockHolder, None]=None, + down_1: Union[ADBlockHolder, None]=None, + down_2: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0: Union[ADBlockHolder, None]=None, + up_1: Union[ADBlockHolder, None]=None, + up_2: Union[ADBlockHolder, None]=None): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0): down_0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1): down_1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2): down_2, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0): up_0, + PerBlockId(block_type=BlockType.UP, block_idx=1): up_1, + PerBlockId(block_type=BlockType.UP, block_idx=2): up_2, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) + + +class PerBlock_SDXL_LowLevelNode: + NodeID = 'ADE_PerBlock_SDXL_LowLevel' + NodeName = 'AD Per Block++ (SDXL) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "down_0__0": ("AD_BLOCK",), + "down_0__1": ("AD_BLOCK",), + "down_1__0": ("AD_BLOCK",), + "down_1__1": ("AD_BLOCK",), + "down_2__0": ("AD_BLOCK",), + "down_2__1": ("AD_BLOCK",), + "mid": ("AD_BLOCK",), + "up_0__0": ("AD_BLOCK",), + "up_0__1": ("AD_BLOCK",), + "up_0__2": ("AD_BLOCK",), + "up_1__0": ("AD_BLOCK",), + "up_1__1": ("AD_BLOCK",), + "up_1__2": ("AD_BLOCK",), + "up_2__0": ("AD_BLOCK",), + "up_2__1": ("AD_BLOCK",), + "up_2__2": ("AD_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = 
("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + def create_per_block(self, + down_0__0: Union[ADBlockHolder, None]=None, + down_0__1: Union[ADBlockHolder, None]=None, + down_1__0: Union[ADBlockHolder, None]=None, + down_1__1: Union[ADBlockHolder, None]=None, + down_2__0: Union[ADBlockHolder, None]=None, + down_2__1: Union[ADBlockHolder, None]=None, + mid: Union[ADBlockHolder, None]=None, + up_0__0: Union[ADBlockHolder, None]=None, + up_0__1: Union[ADBlockHolder, None]=None, + up_0__2: Union[ADBlockHolder, None]=None, + up_1__0: Union[ADBlockHolder, None]=None, + up_1__1: Union[ADBlockHolder, None]=None, + up_1__2: Union[ADBlockHolder, None]=None, + up_2__0: Union[ADBlockHolder, None]=None, + up_2__1: Union[ADBlockHolder, None]=None, + up_2__2: Union[ADBlockHolder, None]=None,): + blocks = [] + d = { + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=0): down_0__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=0, module_idx=1): down_0__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=0): down_1__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=1, module_idx=1): down_1__1, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=0): down_2__0, + PerBlockId(block_type=BlockType.DOWN, block_idx=2, module_idx=1): down_2__1, + PerBlockId(block_type=BlockType.MID): mid, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=0): up_0__0, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=1): up_0__1, + PerBlockId(block_type=BlockType.UP, block_idx=0, module_idx=2): up_0__2, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=0): up_1__0, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=1): up_1__1, + PerBlockId(block_type=BlockType.UP, block_idx=1, module_idx=2): up_1__2, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=0): up_2__0, + PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=1): up_2__1, + 
PerBlockId(block_type=BlockType.UP, block_idx=2, module_idx=2): up_2__2, + } + for id, block in d.items(): + if block is not None: + blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) + if len(blocks) == 0: + return (None,) + return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) diff --git a/animatediff/nodes_pia.py b/animatediff/nodes_pia.py index e65386a..5944485 100644 --- a/animatediff/nodes_pia.py +++ b/animatediff/nodes_pia.py @@ -125,7 +125,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -136,11 +137,12 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, vae: VAE, start_percent: float=0.0, end_percent: float=1.0, pia_input: InputPIA=None, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, ref_multival=None, + scale_multival=None, effect_multival=None, ref_multival=None, per_block=None, prev_m_models: MotionModelGroup=None,): new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, motion_lora=motion_lora, ad_keyframes=ad_keyframes, - scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) + scale_multival=scale_multival, effect_multival=effect_multival, per_block=per_block, + prev_m_models=prev_m_models) # most recent added model will always be first in list; curr_model = new_m_models[0].models[0] # confirm that model is PIA From 6813b0c19b377d6c2129670f0af43e35a59f3793 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 04:55:34 -0500 Subject: [PATCH 05/11] Fix sd_type check for AD Per Block node --- animatediff/model_injection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/animatediff/model_injection.py b/animatediff/model_injection.py index df5047c..0c4249f 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -1379,7 +1379,7 @@ def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionM def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_blocks: AllPerBlocks): - if all_per_blocks is None: + if all_per_blocks is None or all_per_blocks.sd_type is None: return mm_info = motion_model.model.mm_info if all_per_blocks.sd_type != mm_info.sd_type: From 96d280793140cad6cd68d4ec5a0623570fa267ff Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 05:53:27 -0500 Subject: [PATCH 06/11] Added Per Block Floats nodes --- animatediff/nodes.py | 7 ++- animatediff/nodes_per_block.py | 100 +++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 9463e23..a429789 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -40,7 +40,8 @@ from .nodes_scheduling import (PromptSchedulingNode, PromptSchedulingLatentsNode, ValueSchedulingNode, ValueSchedulingLatentsNode, AddValuesReplaceNode, FloatToFloatsNode) from .nodes_per_block import (ADBlockComboNode, ADBlockIndivNode, PerBlockHighLevelNode, - PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SDXL_LowLevelNode, PerBlock_SDXL_MidLevelNode) + PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SD15_FromFloatsNode, + PerBlock_SDXL_LowLevelNode, PerBlock_SDXL_MidLevelNode, PerBlock_SDXL_FromFloatsNode) from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival from .nodes_deprecated import (AnimateDiffLoader_Deprecated, AnimateDiffLoaderAdvanced_Deprecated, AnimateDiffCombine_Deprecated, AnimateDiffModelSettings, AnimateDiffModelSettingsSimple, AnimateDiffModelSettingsAdvanced, 
AnimateDiffModelSettingsAdvancedAttnStrengths) @@ -170,8 +171,10 @@ PerBlockHighLevelNode.NodeID: PerBlockHighLevelNode, PerBlock_SD15_MidLevelNode.NodeID: PerBlock_SD15_MidLevelNode, PerBlock_SD15_LowLevelNode.NodeID: PerBlock_SD15_LowLevelNode, + PerBlock_SD15_FromFloatsNode.NodeID: PerBlock_SD15_FromFloatsNode, PerBlock_SDXL_MidLevelNode.NodeID: PerBlock_SDXL_MidLevelNode, PerBlock_SDXL_LowLevelNode.NodeID: PerBlock_SDXL_LowLevelNode, + PerBlock_SDXL_FromFloatsNode.NodeID: PerBlock_SDXL_FromFloatsNode, # Extras Nodes "ADE_AnimateDiffUnload": AnimateDiffUnload, "ADE_EmptyLatentImageLarge": EmptyLatentImageLarge, @@ -335,8 +338,10 @@ PerBlockHighLevelNode.NodeID: PerBlockHighLevelNode.NodeName, PerBlock_SD15_MidLevelNode.NodeID: PerBlock_SD15_MidLevelNode.NodeName, PerBlock_SD15_LowLevelNode.NodeID: PerBlock_SD15_LowLevelNode.NodeName, + PerBlock_SD15_FromFloatsNode.NodeID: PerBlock_SD15_FromFloatsNode.NodeName, PerBlock_SDXL_MidLevelNode.NodeID: PerBlock_SDXL_MidLevelNode.NodeName, PerBlock_SDXL_LowLevelNode.NodeID: PerBlock_SDXL_LowLevelNode.NodeName, + PerBlock_SDXL_FromFloatsNode.NodeID: PerBlock_SDXL_FromFloatsNode.NodeName, # Extras Nodes "ADE_AnimateDiffUnload": "AnimateDiff Unload 🎭🅐🅓", "ADE_EmptyLatentImageLarge": "Empty Latent Image (Big Batch) 🎭🅐🅓", diff --git a/animatediff/nodes_per_block.py b/animatediff/nodes_per_block.py index 5090583..b06621f 100644 --- a/animatediff/nodes_per_block.py +++ b/animatediff/nodes_per_block.py @@ -1,8 +1,10 @@ from typing import Union from torch import Tensor +from .documentation import short_desc, register_description, coll, DocHelper from .motion_module_ad import PerBlock, PerBlockId, BlockType, AllPerBlocks from .utils_model import ModelTypeSD +from .utils_motion import extend_list_to_batch_size class ADBlockHolder: @@ -267,6 +269,55 @@ def create_per_block(self, return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) +class PerBlock_SD15_FromFloatsNode: + NodeID = 'ADE_PerBlock_SD15_FromFloats' + NodeName = 'AD Per Block 
Floats (SD1.5) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect_21_floats": ("FLOATS",), + "scale_21_floats": ("FLOATS",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + Desc = [ + short_desc('Use Floats from Value Schedules to select SD1.5 effect/scale values for blocks.'), + 'SD1.5 Motion Modules contain 21 blocks:', + 'idx 0 - start of down blocks (down_0__0)', + 'idx 7 - end of down blocks (down_3__1)', + 'idx 8 - mid block (mid)', + 'idx 9 - start of up blocks (up_0__0)', + 'idx 20 - end of up blocks (up_3__2)', + ] + register_description(NodeID, Desc) + + def create_per_block(self, + effect_21_floats: Union[list[float], None]=None, + scale_21_floats: Union[list[float], None]=None): + if effect_21_floats is None and scale_21_floats is None: + return (None,) + # SD1.5 has 21 blocks + block_total = 21 + holders = [ADBlockHolder() for _ in range(block_total)] + if effect_21_floats is not None: + effect_21_floats = extend_list_to_batch_size(effect_21_floats, block_total) + for effect, holder in zip(effect_21_floats, holders): + holder.effect = effect + if scale_21_floats is not None: + scale_21_floats = extend_list_to_batch_size(scale_21_floats, block_total) + for scale, holder in zip(scale_21_floats, holders): + holder.scales = [scale, scale] + return PerBlock_SD15_LowLevelNode.create_per_block(self, *holders) + + class PerBlock_SDXL_MidLevelNode: NodeID = 'ADE_PerBlock_SDXL_MidLevel' NodeName = 'AD Per Block+ (SDXL) 🎭🅐🅓' @@ -392,3 +443,52 @@ def create_per_block(self, if len(blocks) == 0: return (None,) return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) + + +class PerBlock_SDXL_FromFloatsNode: + NodeID = 'ADE_PerBlock_SDXL_FromFloats' + NodeName = 'AD Per Block Floats (SDXL) 🎭🅐🅓' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "effect_16_floats": 
("FLOATS",), + "scale_16_floats": ("FLOATS",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("PER_BLOCK",) + CATEGORY = "Animate Diff 🎭🅐🅓/per block" + FUNCTION = "create_per_block" + + Desc = [ + short_desc('Use Floats from Value Schedules to select SDXL effect/scale values for blocks.'), + 'SDXL Motion Modules contain 16 blocks:', + 'idx 0 - start of down blocks (down_0__0)', + 'idx 5 - end of down blocks (down_2__1)', + 'idx 6 - mid block (mid)', + 'idx 7 - start of up blocks (up_0__0)', + 'idx 15 - end of up blocks (up_2__2)', + ] + register_description(NodeID, Desc) + + def create_per_block(self, + effect_16_floats: Union[list[float], None]=None, + scale_16_floats: Union[list[float], None]=None): + if effect_16_floats is None and scale_16_floats is None: + return (None,) + # SDXL has 16 blocks + block_total = 16 + holders = [ADBlockHolder() for _ in range(block_total)] + if effect_16_floats is not None: + effect_16_floats = extend_list_to_batch_size(effect_16_floats, block_total) + for effect, holder in zip(effect_16_floats, holders): + holder.effect = effect + if scale_16_floats is not None: + scale_16_floats = extend_list_to_batch_size(scale_16_floats, block_total) + for scale, holder in zip(scale_16_floats, holders): + holder.scales = [scale, scale] + return PerBlock_SDXL_LowLevelNode.create_per_block(self, *holders) From 3a380d0a18bb93e662f3c3bd9f4a7b7d24a0f1a0 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 08:35:40 -0500 Subject: [PATCH 07/11] Added scale support to Per Block nodes --- animatediff/model_injection.py | 4 +- animatediff/motion_module_ad.py | 128 ++++++++++++++++++++------------ 2 files changed, 81 insertions(+), 51 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 0c4249f..0b2ed14 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -821,7 +821,7 @@ def patch_model_lowvram(self, device_to=None, 
lowvram_model_memory=0, force_patc def pre_run(self, model: ModelPatcherAndInjector): self.cleanup() - self.model.set_scale(self.scale_multival) + self.model.set_scale(self.scale_multival, self.per_block_list) self.model.set_effect(self.effect_multival, self.per_block_list) self.model.set_cameractrl_effect(self.cameractrl_multival) if self.model.img_encoder is not None: @@ -885,7 +885,7 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x) self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input) # apply scale and effect - self.model.set_scale(self.combined_scale) + self.model.set_scale(self.combined_scale, self.per_block_list) self.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list self.model.set_cameractrl_effect(self.combined_cameractrl_effect) # apply effect - if not within range, set effect to 0, effectively turning model off diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 5754408..7966b34 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -2,6 +2,7 @@ from typing import Iterable, Tuple, Union, TYPE_CHECKING import re from dataclasses import dataclass +from collections.abc import Iterable import torch from einops import rearrange, repeat @@ -21,8 +22,8 @@ from .adapter_animatelcm_i2v import AdapterEmbed if TYPE_CHECKING: # avoids circular import from .adapter_cameractrl import CameraPoseEncoder -from .utils_motion import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, prepare_mask_batch, - get_combined_multival) +from .utils_motion import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, extend_list_to_batch_size, + prepare_mask_batch, get_combined_multival) from .utils_model import BetaSchedules, ModelTypeSD from .logger import logger @@ -685,6 
+686,7 @@ def __init__( cross_frame_attention_mode=cross_frame_attention_mode, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, + block_id=self.id, ops=ops ) @@ -699,14 +701,14 @@ def set_video_length(self, video_length: int, full_length: int): self.temporal_transformer.set_video_length(video_length, full_length) def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): - self.temporal_transformer.set_scale(scale) + self.temporal_transformer.set_scale(scale, per_block_list) def set_effect(self, multival: Union[float, Tensor], per_block_list: Union[list[PerBlock], None]=None): if per_block_list is not None: for per_block in per_block_list: if self.id.matches(per_block.id) and per_block.effect is not None: multival = get_combined_multival(multival, per_block.effect) - logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") + #logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") break if type(multival) == Tensor: self.effect = multival @@ -826,13 +828,13 @@ def __init__( cross_frame_attention_mode=None, temporal_pe=False, temporal_pe_max_len=24, + block_id: PerBlockId=None, ops=comfy.ops.disable_weight_init, ): super().__init__() + self.id = block_id self.video_length = 16 self.full_length = 16 - self.raw_scale_mask: Union[Tensor, None] = None - self.temp_scale_mask: Union[Tensor, None] = None self.sub_idxs: Union[list[int], None] = None self.prev_hidden_states_batch = 0 @@ -871,31 +873,50 @@ def __init__( ) self.proj_out = ops.Linear(inner_dim, in_channels) + self.raw_scale_masks: Union[list[Tensor], None] = [None] * self.get_attention_count() + self.temp_scale_masks: Union[list[Tensor], None] = [None] * self.get_attention_count() + + def get_attention_count(self): + if len(self.transformer_blocks) > 0: + return len(self.transformer_blocks[0].attention_blocks) + return 0 + def set_video_length(self, 
video_length: int, full_length: int): self.video_length = video_length self.full_length = full_length - def set_scale_multiplier(self, multiplier: Union[float, None]): + def set_scale_multiplier(self, idx: int, multiplier: Union[float, list[float], None]): + # if not isinstance(multiplier, Iterable): + # multiplier = [multiplier] + # multiplier = extend_list_to_batch_size(multiplier, self.get_attention_count()) for block in self.transformer_blocks: - mult = multiplier - block.set_scale_multiplier(mult) - - def set_scale_mask(self, mask: Tensor): - self.raw_scale_mask = mask - self.temp_scale_mask = None - - def set_scale(self, scale: Union[float, Tensor, None], per_attn_list: Union[list[PerAttn], None]=None): - # if per_attn_list is not None: - # for per_attn in per_attn_list: - # if per_attn.attn_idx == idx: - # mult = per_attn.scale - # break - if type(scale) == Tensor: - self.set_scale_mask(scale) - self.set_scale_multiplier(None) - else: - self.set_scale_mask(None) - self.set_scale_multiplier(scale) + block.set_scale_multiplier(idx, multiplier) + + def set_scale_mask(self, idx: int, mask: Tensor): + self.raw_scale_masks[idx] = mask + self.temp_scale_masks[idx] = None + + def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list[PerBlock], None]=None): + if per_block_list is not None: + for per_block in per_block_list: + if self.id.matches(per_block.id) and len(per_block.scales) > 0: + scales = [] + for sub_scale in per_block.scales: + scales.append(get_combined_multival(scale, sub_scale)) + #logger.info(f"scale - block_type: {self.id.block_type}, block_idx: {self.id.block_idx}, module_idx: {self.id.module_idx}") + scale = scales + break + + if type(scale) == Tensor or not isinstance(scale, Iterable): + scale = [scale] + scale = extend_list_to_batch_size(scale, self.get_attention_count()) + for idx, sub_scale in enumerate(scale): + if type(sub_scale) == Tensor: + self.set_scale_mask(idx, sub_scale) + self.set_scale_multiplier(idx, None) + 
else: + self.set_scale_mask(idx, None) + self.set_scale_multiplier(idx, sub_scale) def set_cameractrl_effect(self, multival: Union[float, Tensor]): self.raw_cameractrl_effect = multival @@ -906,8 +927,8 @@ def set_sub_idxs(self, sub_idxs: list[int]): block.set_sub_idxs(sub_idxs) def reset_temp_vars(self): - del self.temp_scale_mask - self.temp_scale_mask = None + del self.temp_scale_masks + self.temp_scale_masks = [None] * self.get_attention_count() self.prev_hidden_states_batch = 0 del self.temp_cameractrl_effect self.temp_cameractrl_effect = None @@ -915,25 +936,36 @@ def reset_temp_vars(self): for block in self.transformer_blocks: block.reset_temp_vars() - def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: + def get_scale_masks(self, hidden_states: Tensor) -> Union[Tensor, None]: + masks = [] + prev_mask = None + prev_idx = 0 + for idx in range(len(self.raw_scale_masks)): + if prev_mask is self.raw_scale_masks[idx]: + masks.append(self.temp_scale_masks[prev_idx]) + else: + masks.append(self.get_scale_mask(idx=idx, hidden_states=hidden_states)) + prev_idx = idx + return masks + + def get_scale_mask(self, idx: int, hidden_states: Tensor) -> Union[Tensor, None]: # if no raw mask, return None - if self.raw_scale_mask is None: + if self.raw_scale_masks[idx] is None: return None shape = hidden_states.shape batch, channel, height, width = shape # if temp mask already calculated, return it - if self.temp_scale_mask != None: + if self.temp_scale_masks[idx] != None: # check if hidden_states batch matches if batch == self.prev_hidden_states_batch: if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask + return self.temp_scale_masks[idx][:, self.sub_idxs, :] + return self.temp_scale_masks[idx] # if does not match, reset cached temp_scale_mask and recalculate it - del self.temp_scale_mask - self.temp_scale_mask = None + self.temp_scale_masks[idx] = None # otherwise, calculate temp mask 
self.prev_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = prepare_mask_batch(self.raw_scale_masks[idx], shape=(self.full_length, 1, height, width)) mask = repeat_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: @@ -950,13 +982,13 @@ def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: if batched_number > 1: mask = torch.cat([mask] * batched_number, dim=0) # cache mask and set to proper device - self.temp_scale_mask = mask + self.temp_scale_masks[idx] = mask # move temp_scale_mask to proper dtype + device - self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) + self.temp_scale_masks[idx] = self.temp_scale_masks[idx].to(dtype=hidden_states.dtype, device=hidden_states.device) # return subset of masks, if needed if self.sub_idxs is not None: - return self.temp_scale_mask[:, self.sub_idxs, :] - return self.temp_scale_mask + return self.temp_scale_masks[idx][:, self.sub_idxs, :] + return self.temp_scale_masks[idx] def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, None]: # if no raw camera_Ctrl, return None @@ -1007,7 +1039,7 @@ def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, N def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None, mm_kwargs: dict[str]=None): batch, channel, height, width = hidden_states.shape residual = hidden_states - scale_mask = self.get_scale_mask(hidden_states) + scale_masks = self.get_scale_masks(hidden_states) cameractrl_effect = self.get_cameractrl_effect(hidden_states) # add some casts for fp8 purposes - does not affect speed otherwise hidden_states = self.norm(hidden_states).to(hidden_states.dtype) @@ -1024,7 +1056,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, 
attention_mask=None encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, video_length=self.video_length, - scale_mask=scale_mask, + scale_masks=scale_masks, cameractrl_effect=cameractrl_effect, view_options=view_options, mm_kwargs=mm_kwargs @@ -1097,10 +1129,8 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) - def set_scale_multiplier(self, multiplier: Union[float, None]): - for idx, block in enumerate(self.attention_blocks): - mult = multiplier - block.set_scale_multiplier(mult) + def set_scale_multiplier(self, idx: int, multiplier: Union[float, None]): + self.attention_blocks[idx].set_scale_multiplier(multiplier) def set_sub_idxs(self, sub_idxs: list[int]): for block in self.attention_blocks: @@ -1116,7 +1146,7 @@ def forward( encoder_hidden_states: Tensor=None, attention_mask: Tensor=None, video_length: int=None, - scale_mask: Tensor=None, + scale_masks: list[Tensor]=None, cameractrl_effect: Union[float, Tensor] = None, view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, @@ -1128,7 +1158,7 @@ def forward( elif view_options.context_length == video_length and not view_options.use_on_equal_length: view_options = None if not view_options: - for attention_block, norm in zip(self.attention_blocks, self.norms): + for attention_block, norm, scale_mask in zip(self.attention_blocks, self.norms, scale_masks): norm_hidden_states = norm(hidden_states).to(hidden_states.dtype) hidden_states = ( attention_block( @@ -1163,7 +1193,7 @@ def forward( sub_hidden_states = rearrange(hidden_states[:, sub_idxs], "b f d c -> (b f) d c") if has_camera_feature: mm_kwargs["camera_feature"] = orig_camera_feature[:, sub_idxs, :] - for attention_block, norm in zip(self.attention_blocks, self.norms): + for attention_block, norm, scale_mask in zip(self.attention_blocks, self.norms, scale_masks): norm_hidden_states = 
norm(sub_hidden_states).to(sub_hidden_states.dtype) sub_hidden_states = ( attention_block( From 3c4d5d45e87fa7b6917a95e56b8b887eb13ace1e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 08:38:25 -0500 Subject: [PATCH 08/11] version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9029be3..35964e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.1.2" +version = "1.1.3" license = { file = "LICENSE" } dependencies = [] From 541ece59155f844692c1705a5044eefea28f98af Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 08:42:43 -0500 Subject: [PATCH 09/11] cleanup conflicting type import --- animatediff/motion_module_ad.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 7966b34..19fda16 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -2,7 +2,7 @@ from typing import Iterable, Tuple, Union, TYPE_CHECKING import re from dataclasses import dataclass -from collections.abc import Iterable +from collections.abc import Iterable as IterColl import torch from einops import rearrange, repeat @@ -886,9 +886,6 @@ def set_video_length(self, video_length: int, full_length: int): self.full_length = full_length def set_scale_multiplier(self, idx: int, multiplier: Union[float, list[float], None]): - # if not isinstance(multiplier, Iterable): - # multiplier = [multiplier] - # multiplier = extend_list_to_batch_size(multiplier, self.get_attention_count()) for block in self.transformer_blocks: block.set_scale_multiplier(idx, multiplier) @@ -907,7 +904,7 @@ def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[lis scale = scales break - if type(scale) == Tensor or not isinstance(scale, 
Iterable): + if type(scale) == Tensor or not isinstance(scale, IterColl): scale = [scale] scale = extend_list_to_batch_size(scale, self.get_attention_count()) for idx, sub_scale in enumerate(scale): From 7eeed536d13dff58c7a8093b7d2f10852d77b3b8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 09:12:12 -0500 Subject: [PATCH 10/11] Fixed scale_masks param being None with CameraCtrl --- animatediff/motion_module_ad.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 19fda16..1830266 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1148,6 +1148,8 @@ def forward( view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, ): + if scale_masks is None: + scale_masks = [None] * len(self.attention_blocks) # make view_options None if context_length > video_length, or if equal and equal not allowed if view_options: if view_options.context_length > video_length: From 8c76a9ccf46f17c4a1c2dad46d9bf619e6fae4db Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 16 Aug 2024 09:18:02 -0500 Subject: [PATCH 11/11] Added per_block input to Apply AnimateDiff+CameraCtrl Model --- animatediff/nodes_cameractrl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/animatediff/nodes_cameractrl.py b/animatediff/nodes_cameractrl.py index 9352bae..6ba94e7 100644 --- a/animatediff/nodes_cameractrl.py +++ b/animatediff/nodes_cameractrl.py @@ -209,6 +209,7 @@ def INPUT_TYPES(s): "cameractrl_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), + "per_block": ("PER_BLOCK",), } } @@ -218,10 +219,10 @@ def INPUT_TYPES(s): def apply_motion_model(self, motion_model: MotionModelPatcher, cameractrl_poses: list[list[float]], start_percent: float=0.0, end_percent: float=1.0, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, - scale_multival=None, effect_multival=None, 
cameractrl_multival=None, + scale_multival=None, effect_multival=None, cameractrl_multival=None, per_block=None, prev_m_models: MotionModelGroup=None,): new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, - motion_lora=motion_lora, ad_keyframes=ad_keyframes, + motion_lora=motion_lora, ad_keyframes=ad_keyframes, per_block=per_block, scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) # most recent added model will always be first in list; curr_model = new_m_models[0].models[0]