diff --git a/README.md b/README.md index 1b440fc..9b59cf4 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ NOTE: you can also use custom locations for models/motion loras by making use of - NOTE: Requires same settings as described for AnimateLCM above. Requires ```Apply AnimateLCM-I2V Model``` Gen2 node usage so that ```ref_latent``` can be provided; use ```Scale Ref Image and VAE Encode``` node to preprocess input images. While this was intended as an img2video model, I found it works best for vid2vid purposes with ```ref_drift=0.0```, and to use it for at least 1 step before switching over to other models via chaining with other Apply AnimateDiff Model (Adv.) nodes. The ```apply_ref_when_disabled``` can be set to True to allow the img_encoder to do its thing even when the ```end_percent``` is reached. AnimateLCM-I2V is also extremely useful for maintaining coherence at higher resolutions (with ControlNet and SD LoRAs active, I could easily upscale from 512x512 source to 1024x1024 in a single pass). TODO: add examples - [CameraCtrl](https://github.com/hehao13/CameraCtrl) support, with the pruned model you must use here: [CameraCtrl_pruned.safetensors](https://huggingface.co/Kosinkadink/CameraCtrl/tree/main) - NOTE: Requires AnimateDiff SD1.5 models, and was specifically trained for the v3 model. Gen2 only, with helper nodes provided under Gen2/CameraCtrl submenu. +- [PIA](https://github.com/open-mmlab/PIA) support, with the model [pia.ckpt](https://huggingface.co/Leoxing/PIA/tree/main) + - NOTE: You will need to use ```autoselect``` or ```sqrt_linear (AnimateDiff)``` beta_schedule. Requires ```Apply AnimateDiff-PIA Model``` Gen2 node usage if you want to actually provide input images. The ```pia_input``` can be provided via the paper's presets (```PIA Input [Paper Presets]```) or by manually entering values (```PIA Input [Multival]```). - AnimateDiff Keyframes to change Scale and Effect at different points in the sampling process. - fp8 support; requires newest ComfyUI and torch >= 2.1 (decreases VRAM usage, but changes outputs) - Mac M1/M2/M3 support @@ -72,14 +74,17 @@ NOTE: you can also use custom locations for models/motion loras by making use of - Maskable and Schedulable SD LoRA (and Models as LoRA) for both AnimateDiff and StableDiffusion usage via LoRA Hooks - Per-frame GLIGEN coordinates control - Currently requires GLIGENTextBoxApplyBatch from KJNodes to do so, but I will add native nodes to do this soon. +- Image Injection mid-sampling ## Upcoming Features -- Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: mid-May) -- Maskable Motion LoRA (Goal: end of May/beginning of June) +- Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long YouTube video showing all features (Goal: before Elden Ring DLC releases. Working on it right now.) 
+- [UniCtrl](https://github.com/XuweiyiChen/UniCtrl) support +- Unet-Ref support so that a bunch of papers can be ported over +- [StoryDiffusion](https://github.com/HVision-NKU/StoryDiffusion) implementation +- Merging motion model weights/components, including per block customization +- Maskable Motion LoRA - Timestep schedulable GLIGEN coordinates - Dynamic memory management for motion models that load/unload at different start/end_percents -- [PIA](https://github.com/open-mmlab/PIA) support -- [UniCtrl](https://github.com/XuweiyiChen/UniCtrl) support - Built-in prompt travel implementation - Anything else AnimateDiff-related that comes out diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index e5a666c..aecc2c0 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -8,12 +8,13 @@ import uuid import math +import comfy.conds import comfy.lora import comfy.model_management import comfy.utils from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel -from comfy.sd import CLIP +from comfy.sd import CLIP, VAE from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding @@ -21,10 +22,12 @@ from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) from .logger import logger -from .utils_motion import ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, get_combined_multival, ade_broadcast_image_to, normalize_min_max +from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, + get_combined_multival, get_combined_input, get_combined_input_effect_multival, + ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch) from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode from .motion_lora import MotionLoraInfo, MotionLoraList -from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type +from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched from .sample_settings import SampleSettings, SeedNoiseGeneration @@ -138,7 +141,6 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s ''' Based on add_patches, but for hooked weights. ''' - # TODO: make this work with timestep scheduling current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {}) p = set() model_sd = self.model.state_dict() @@ -164,7 +166,6 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng ''' Based on add_hooked_patches, but intended for using a model's weights as lora hook. 
''' - # TODO: make this work with timestep scheduling current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {}) p = set() model_sd = self.model.state_dict() @@ -180,6 +181,7 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng p.add(k) current_patches: list[tuple] = current_hooked_patches.get(key, []) # take difference between desired weight and existing weight to get diff + # TODO: create fix for fp8 current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset)) current_hooked_patches[key] = current_patches self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches @@ -238,7 +240,7 @@ def patch_model_lowvram(self, *args, **kwargs): self.model_params_lowvram_keys[f"{n}.weight"] = n if getattr(m, "bias_function", None) is not None: self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n + self.model_params_lowvram_keys[f"{n}.bias"] = n def unpatch_model(self, device_to=None, unpatch_weights=True): # first, eject motion model from unet @@ -721,7 +723,16 @@ def __init__(self, *args, **kwargs): self.orig_camera_entries: list[CameraEntry] = None self.camera_features: list[Tensor] = None # temporary self.camera_features_shape: tuple = None - self.cameractrl_multival = None + self.cameractrl_multival: Union[float, Tensor] = None + + # PIA + self.orig_pia_images: Tensor = None + self.pia_vae: VAE = None + self.pia_input: InputPIA = None + self.cached_pia_c_concat: comfy.conds.CONDNoiseShape = None # cached + self.prev_pia_latents_shape: tuple = None + self.prev_current_pia_input: InputPIA = None + self.pia_multival: Union[float, Tensor] = None # temporary variables self.current_used_steps = 0 @@ -730,9 +741,12 @@ def __init__(self, *args, **kwargs): self.current_scale: Union[float, Tensor] = None self.current_effect: Union[float, Tensor] = None self.current_cameractrl_effect: Union[float, Tensor] = None + self.current_pia_input: InputPIA = None self.combined_scale: Union[float, Tensor] = None self.combined_effect: Union[float, Tensor] = None self.combined_cameractrl_effect: Union[float, Tensor] = None + self.combined_pia_mask: Union[float, Tensor] = None + self.combined_pia_effect: Union[float, Tensor] = None self.was_within_range = False self.prev_sub_idxs = None self.prev_batched_number = None @@ -774,7 +788,7 @@ def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, t: Tensor): + def prepare_current_keyframe(self, x: Tensor, t: Tensor): curr_t: float = t[0] prev_index = self.current_index # if met guaranteed steps, look for next keyframe in case need to switch @@ -802,6 +816,10 @@ def prepare_current_keyframe(self, t: Tensor): self.current_cameractrl_effect = self.current_keyframe.cameractrl_multival elif not self.current_keyframe.inherit_missing: self.current_cameractrl_effect = None + if self.current_keyframe.has_pia_input(): + self.current_pia_input = self.current_keyframe.pia_input + elif not self.current_keyframe.inherit_missing: + self.current_pia_input = None # if guarantee_steps greater than zero, stop searching for other keyframes if self.current_keyframe.guarantee_steps > 0: break @@ -814,6 +832,8 @@ def prepare_current_keyframe(self, t: Tensor): self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale) self.combined_effect = 
get_combined_multival(self.effect_multival, self.current_effect) self.combined_cameractrl_effect = get_combined_multival(self.cameractrl_multival, self.current_cameractrl_effect) + self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x) + self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input) # apply scale and effect self.model.set_scale(self.combined_scale) self.model.set_effect(self.combined_effect) @@ -889,6 +909,72 @@ def prepare_camera_features(self, x: Tensor, cond_or_uncond: list[int], ad_param self.prev_sub_idxs = sub_idxs self.prev_batched_number = batched_number + def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: + # if have cached shape, check if matches - if so, return cached pia_latents + if self.prev_pia_latents_shape is not None: + if self.prev_pia_latents_shape[0] == x.shape[0] and self.prev_pia_latents_shape[2] == x.shape[2] and self.prev_pia_latents_shape[3] == x.shape[3]: + # if mask is also the same for this timestep, then return cached + if self.prev_current_pia_input == self.current_pia_input: + return self.cached_pia_c_concat + # otherwise, adjust new mask, and create new cached_pia_c_concat + b, c, h, w = x.shape + mask = prepare_mask_batch(self.combined_pia_mask, x.shape) + mask = extend_to_batch_size(mask, b) + # make sure to update prev_current_pia_input to know when it is changed + self.prev_current_pia_input = self.current_pia_input + # TODO: handle self.combined_pia_effect eventually (feature hidden for now) + # the first index in dim=1 is the mask that needs to be updated - update in place + self.cached_pia_c_concat.cond[:, :1, :, :] = mask + return self.cached_pia_c_concat + self.prev_pia_latents_shape = None + # otherwise, create new pia_latents, and cache x's shape for reuse + # get currently used models so they can be properly reloaded after performing VAE Encoding + if hasattr(comfy.model_management, "loaded_models"): + cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + else: + cached_loaded_models: list[ModelPatcherAndInjector] = [x.model for x in comfy.model_management.current_loaded_models] + try: + b, c, h, w = x.shape + usable_ref = self.orig_pia_images[:b] + # in diffusers, the image is scaled from [-1, 1] instead of default [0, 1], + # but from my testing, that blows out the images here, so I skip it + # usable_images = usable_images * 2 - 1 + # resize images to latent's dims + usable_ref = usable_ref.movedim(-1,1) + usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.pia_vae.downscale_ratio, height=h*self.pia_vae.downscale_ratio, + upscale_method="bilinear", crop="center") + usable_ref = usable_ref.movedim(1,-1) + # VAE encode images + logger.info("VAE Encoding PIA input images...") + usable_ref = model.process_latent_in(vae_encode_raw_batched(vae=self.pia_vae, pixels=usable_ref, show_pbar=False)) + logger.info("VAE Encoding PIA input images complete.") + # make pia_latents match expected length + usable_ref = extend_to_batch_size(usable_ref, b) + self.prev_pia_latents_shape = x.shape + # now, take care of the mask + mask = prepare_mask_batch(self.combined_pia_mask, x.shape) + mask = extend_to_batch_size(mask, b) + #mask = mask.unsqueeze(1) + self.prev_current_pia_input = self.current_pia_input + if type(self.combined_pia_effect) == Tensor or not math.isclose(self.combined_pia_effect, 1.0): + real_pia_effect = self.combined_pia_effect + if type(self.combined_pia_effect) == Tensor: + real_pia_effect = 
extend_to_batch_size(prepare_mask_batch(self.combined_pia_effect, x.shape), b) + zero_mask = torch.zeros_like(mask) + mask = mask * real_pia_effect + zero_mask * (1.0 - real_pia_effect) + del zero_mask + zero_usable_ref = torch.zeros_like(usable_ref) + usable_ref = usable_ref * real_pia_effect + zero_usable_ref * (1.0 - real_pia_effect) + del zero_usable_ref + # cache pia c_concat + self.cached_pia_c_concat = comfy.conds.CONDNoiseShape(torch.cat([mask, usable_ref], dim=1)) + return self.cached_pia_c_concat + finally: + comfy.model_management.load_models_gpu(cached_loaded_models) + + def is_pia(self): + return self.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None + def cleanup(self): if self.model is not None: self.model.cleanup() @@ -900,6 +986,9 @@ def cleanup(self): del self.camera_features self.camera_features = None self.camera_features_shape = None + # PIA + self.combined_pia_mask = None + self.combined_pia_effect = None # Default self.current_used_steps = 0 self.current_keyframe = None @@ -943,6 +1032,11 @@ def clone(self): # CameraCtrl n.orig_camera_entries = self.orig_camera_entries n.cameractrl_multival = self.cameractrl_multival + # PIA + n.orig_pia_images = self.orig_pia_images + n.pia_vae = self.pia_vae + n.pia_input = self.pia_input + n.pia_multival = self.pia_multival return n @@ -995,9 +1089,16 @@ def cleanup(self): for motion_model in self.models: motion_model.cleanup() - def prepare_current_keyframe(self, t: Tensor): + def prepare_current_keyframe(self, x: Tensor, t: Tensor): + for motion_model in self.models: + motion_model.prepare_current_keyframe(x=x, t=t) + + def get_pia_models(self): + pia_motion_models: list[MotionModelPatcher] = [] for motion_model in self.models: - motion_model.prepare_current_keyframe(t=t) + if motion_model.is_pia(): + pia_motion_models.append(motion_model) + return pia_motion_models def get_name_string(self, show_version=False): identifiers = [] @@ -1161,6 +1262,14 @@ def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: M motion_model.model.img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict()) +def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher): + motion_model.model.init_conv_in(w_pia.model.state_dict()) + motion_model.model.conv_in.to(comfy.model_management.unet_dtype()) + motion_model.model.conv_in.to(comfy.model_management.unet_offload_device()) + motion_model.model.conv_in.load_state_dict(w_pia.model.conv_in.state_dict()) + motion_model.model.mm_info.mm_format = AnimateDiffFormat.PIA + + def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ctrl_name: str): camera_ctrl_path = get_motion_model_path(camera_ctrl_name) full_state_dict = comfy.utils.load_torch_file(camera_ctrl_path, safe_load=True) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 48d147d..4ee3b63 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -8,6 +8,7 @@ from comfy.ldm.modules.attention import FeedForward, SpatialTransformer from comfy.model_patcher import ModelPatcher +from comfy.model_base import BaseModel from comfy.ldm.modules.diffusionmodules import openaimodel from comfy.ldm.modules.diffusionmodules.openaimodel import SpatialTransformer from comfy.controlnet import broadcast_image_to @@ -35,6 +36,9 @@ class AnimateDiffFormat: ANIMATEDIFF = "AnimateDiff" HOTSHOTXL = "HotshotXL" ANIMATELCM = "AnimateLCM" + PIA = "PIA" + + _LIST = [ANIMATEDIFF, HOTSHOTXL, 
ANIMATELCM, PIA] class AnimateDiffVersion: @@ -42,6 +46,8 @@ class AnimateDiffVersion: V2 = "v2" V3 = "v3" + _LIST = [V1, V2, V3] + class AnimateDiffInfo: def __init__(self, sd_type: str, mm_format: str, mm_version: str, mm_name: str): @@ -70,6 +76,13 @@ def is_animatelcm(mm_state_dict: dict[str, Tensor]) -> bool: return True +def is_pia(mm_state_dict: dict[str, Tensor]) -> bool: + # check if conv_in.weight and .bias are present + if "conv_in.weight" in mm_state_dict and "conv_in.bias" in mm_state_dict: + return True + return False + + def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int: # keep track of biggest down_block count in module biggest_block = 0 @@ -140,6 +153,8 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T mm_format = AnimateDiffFormat.HOTSHOTXL if is_animatelcm(mm_state_dict): mm_format = AnimateDiffFormat.ANIMATELCM + if is_pia(mm_state_dict): + mm_format = AnimateDiffFormat.PIA # for AnimateLCM-I2V purposes, check for img_encoder keys contains_img_encoder = has_img_encoder(mm_state_dict) # remove all non-temporal keys (in case model has extra stuff in it) @@ -147,6 +162,8 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T if "temporal" not in key: if mm_format == AnimateDiffFormat.ANIMATELCM and contains_img_encoder and key.startswith("img_encoder."): continue + if mm_format == AnimateDiffFormat.PIA and key.startswith("conv_in."): + continue del mm_state_dict[key] # determine the model's version mm_version = AnimateDiffVersion.V1 @@ -215,11 +232,18 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): self.mid_block = MotionModule(1280, temporal_pe=self.has_position_encoding, temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops) self.AD_video_length: int = 24 - # create AdapterEmbed if keys present for it + self.effect_model = 1.0 + # AnimateLCM-I2V stuff - create AdapterEmbed if keys present for it self.img_encoder: AdapterEmbed = None if has_img_encoder(mm_state_dict): self.init_img_encoder() + # CameraCtrl stuff self.camera_encoder: 'CameraPoseEncoder' = None + # PIA stuff - create conv_in if keys are present for it + self.conv_in: comfy.ops.disable_weight_init.Conv2d = None + self.orig_conv_in: comfy.ops.disable_weight_init.Conv2d = None + if is_pia(mm_state_dict): + self.init_conv_in(mm_state_dict) def init_img_encoder(self): del self.img_encoder @@ -229,6 +253,20 @@ def set_camera_encoder(self, camera_encoder: 'CameraPoseEncoder'): del self.camera_encoder self.camera_encoder = camera_encoder + def init_conv_in(self, mm_state_dict: dict[str, Tensor]): + ''' + Used for PIA + ''' + del self.conv_in + # hardcoded values, for now + # dim=2, in_channels=9, model_channels=320, kernel=3, padding=1, + # dtype=comfy.model_management.unet_dtype(), device=offload_device + in_channels = mm_state_dict["conv_in.weight"].size(1) # expected to be 9 + model_channels = mm_state_dict["conv_in.weight"].size(0) # expected to be 320 + # create conv_in with proper params + self.conv_in = self.ops.conv_nd(2, in_channels, model_channels, 3, padding=1, + dtype=comfy.model_management.unet_dtype(), device=comfy.model_management.unet_offload_device()) + def get_device_debug(self): return self.down_blocks[0].motion_modules[0].temporal_transformer.proj_in.weight.device @@ -338,6 +376,35 @@ def _eject(self, unet_blocks: nn.ModuleList): for idx in sorted(idx_to_pop, reverse=True): block.pop(idx) + def inject_unet_conv_in_pia(self, model: BaseModel): + if self.conv_in is 
None: + return + # TODO: make sure works with lowvram + # expected conv_in is in the first input block, and is the first module + self.orig_conv_in = model.diffusion_model.input_blocks[0][0] + + present_state_dict: dict[str, Tensor] = self.orig_conv_in.state_dict() + new_state_dict: dict[str, Tensor] = self.conv_in.state_dict() + # bias stays the same, but weight needs to inherit first in_channels from model + combined_state_dict = {} + combined_state_dict["bias"] = present_state_dict["bias"] + combined_state_dict["weight"] = torch.cat([present_state_dict["weight"], + new_state_dict["weight"][:, 4:, :, :].to(dtype=present_state_dict["weight"].dtype, + device=present_state_dict["weight"].device)], dim=1) + # create combined_conv_in with proper params + in_channels = new_state_dict["weight"].size(1) # expected to be 9 + model_channels = present_state_dict["weight"].size(0) # expected to be 320 + combined_conv_in = self.ops.conv_nd(2, in_channels, model_channels, 3, padding=1, + dtype=present_state_dict["weight"].dtype, device=present_state_dict["weight"].device) + combined_conv_in.load_state_dict(combined_state_dict) + # now can apply combined_conv_in to unet block + model.diffusion_model.input_blocks[0][0] = combined_conv_in + + def restore_unet_conv_in_pia(self, model: BaseModel): + if self.orig_conv_in is not None: + model.diffusion_model.input_blocks[0][0] = self.orig_conv_in.to(model.diffusion_model.input_blocks[0][0].weight.device) + self.orig_conv_in = None + def set_video_length(self, video_length: int, full_length: int): self.AD_video_length = video_length if self.down_blocks is not None: @@ -360,6 +427,12 @@ def set_scale(self, multival: Union[float, Tensor]): self._set_scale_mask(None) def set_effect(self, multival: Union[float, Tensor]): + # keep track of if model is in effect + if multival is None: + self.effect_model = 1.0 + else: + self.effect_model = multival + # pass down effect multival to all blocks if self.down_blocks is not None: for block in self.down_blocks: block.set_effect(multival) @@ -369,6 +442,11 @@ def set_effect(self, multival: Union[float, Tensor]): if self.mid_block is not None: self.mid_block.set_effect(multival) + def is_in_effect(self): + if type(self.effect_model) == Tensor: + return True + return not math.isclose(self.effect_model, 0.0) + def set_cameractrl_effect(self, multival: Union[float, Tensor]): # cameractrl should only impact down and up blocks if self.down_blocks is not None: diff --git a/animatediff/nodes.py b/animatediff/nodes.py index e9f9669..6a8e7cf 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -9,7 +9,8 @@ from .nodes_cameractrl import (LoadAnimateDiffModelWithCameraCtrl, ApplyAnimateDiffWithCameraCtrl, CameraCtrlADKeyframeNode, LoadCameraPoses, CameraCtrlPoseBasic, CameraCtrlPoseCombo, CameraCtrlPoseAdvanced, CameraCtrlManualAppendPose, CameraCtrlReplaceCameraParameters, CameraCtrlSetOriginalAspectRatio) -from .nodes_multival import MultivalDynamicNode, MultivalScaledMaskNode +from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode) +from .nodes_multival import MultivalDynamicNode, MultivalScaledMaskNode, MultivalDynamicFloatInputNode, MultivalConvertToMaskNode from .nodes_conditioning import (MaskableLoraLoader, MaskableLoraLoaderModelOnly, MaskableSDModelLoader, MaskableSDModelLoaderModelOnly, SetModelLoraHook, SetClipLoraHook, CombineLoraHooks, CombineLoraHookFourOptional, CombineLoraHookEightOptional, @@ -19,7 +20,7 @@ 
ConditioningTimestepsNode, SetLoraHookKeyframes, CreateLoraHookKeyframe, CreateLoraHookKeyframeInterpolation, CreateLoraHookKeyframeFromStrengthList) from .nodes_sample import (FreeInitOptionsNode, NoiseLayerAddWeightedNode, SampleSettingsNode, NoiseLayerAddNode, NoiseLayerReplaceNode, IterationOptionsNode, - CustomCFGNode, CustomCFGKeyframeNode) + CustomCFGNode, CustomCFGKeyframeNode, NoisedImageInjectionNode, NoisedImageInjectOptionsNode) from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode) from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode) @@ -45,7 +46,9 @@ "ADE_AnimateDiffKeyframe": ADKeyframeNode, # Multival Nodes "ADE_MultivalDynamic": MultivalDynamicNode, + "ADE_MultivalDynamicFloatInput": MultivalDynamicFloatInputNode, "ADE_MultivalScaledMask": MultivalScaledMaskNode, + "ADE_MultivalConvertToMask": MultivalConvertToMaskNode, # Context Opts "ADE_StandardStaticContextOptions": StandardStaticContextOptionsNode, "ADE_StandardUniformContextOptions": StandardUniformContextOptionsNode, @@ -104,6 +107,8 @@ "ADE_SigmaScheduleWeightedAverage": WeightedAverageSigmaScheduleNode, "ADE_SigmaScheduleWeightedAverageInterp": InterpolatedWeightedAverageSigmaScheduleNode, "ADE_SigmaScheduleSplitAndCombine": SplitAndCombineSigmaScheduleNode, + "ADE_NoisedImageInjection": NoisedImageInjectionNode, + "ADE_NoisedImageInjectOptions": NoisedImageInjectOptionsNode, # Extras Nodes "ADE_AnimateDiffUnload": AnimateDiffUnload, "ADE_EmptyLatentImageLarge": EmptyLatentImageLarge, @@ -132,8 +137,12 @@ "ADE_CameraManualPoseAppend": CameraCtrlManualAppendPose, "ADE_ReplaceCameraParameters": CameraCtrlReplaceCameraParameters, "ADE_ReplaceOriginalPoseAspectRatio": CameraCtrlSetOriginalAspectRatio, - # MaskedLoraLoader - #"ADE_MaskedLoadLora": MaskedLoraLoader, + # PIA Nodes + "ADE_ApplyAnimateDiffModelWithPIA": ApplyAnimateDiffPIAModel, + "ADE_InputPIA_Multival": InputPIA_MultivalNode, + "ADE_InputPIA_PaperPresets": InputPIA_PaperPresetsNode, + "ADE_PIA_AnimateDiffKeyframe": PIA_ADKeyframeNode, + "ADE_InjectPIAIntoAnimateDiffModel": LoadAnimateDiffAndInjectPIANode, # Deprecated Nodes "AnimateDiffLoaderV1": AnimateDiffLoader_Deprecated, "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvanced_Deprecated, @@ -150,7 +159,9 @@ "ADE_AnimateDiffKeyframe": "AnimateDiff Keyframe πŸŽ­πŸ…πŸ…“", # Multival Nodes "ADE_MultivalDynamic": "Multival Dynamic πŸŽ­πŸ…πŸ…“", + "ADE_MultivalDynamicFloatInput": "Multival Dynamic [Float List] πŸŽ­πŸ…πŸ…“", "ADE_MultivalScaledMask": "Multival Scaled Mask πŸŽ­πŸ…πŸ…“", + "ADE_MultivalConvertToMask": "Multival to Mask πŸŽ­πŸ…πŸ…“", # Context Opts "ADE_StandardStaticContextOptions": "Context Optionsβ—†Standard Static πŸŽ­πŸ…πŸ…“", "ADE_StandardUniformContextOptions": "Context Optionsβ—†Standard Uniform πŸŽ­πŸ…πŸ…“", @@ -209,6 +220,8 @@ "ADE_SigmaScheduleWeightedAverage": "Sigma Schedule Weighted Mean πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interpolated Mean πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleSplitAndCombine": "Sigma Schedule Split Combine πŸŽ­πŸ…πŸ…“", + "ADE_NoisedImageInjection": "Image Injection πŸŽ­πŸ…πŸ…“", + "ADE_NoisedImageInjectOptions": "Image 
Injection Options πŸŽ­πŸ…πŸ…“", # Extras Nodes "ADE_AnimateDiffUnload": "AnimateDiff Unload πŸŽ­πŸ…πŸ…“", "ADE_EmptyLatentImageLarge": "Empty Latent Image (Big Batch) πŸŽ­πŸ…πŸ…“", @@ -237,8 +250,12 @@ "ADE_CameraManualPoseAppend": "Manual Append CameraCtrl Poses πŸŽ­πŸ…πŸ…“β‘‘", "ADE_ReplaceCameraParameters": "Replace Camera Parameters πŸŽ­πŸ…πŸ…“β‘‘", "ADE_ReplaceOriginalPoseAspectRatio": "Replace Orig. Pose Aspect Ratio πŸŽ­πŸ…πŸ…“β‘‘", - # MaskedLoraLoader - #"ADE_MaskedLoadLora": "Load LoRA (Masked) πŸŽ­πŸ…πŸ…“", + # PIA Nodes + "ADE_ApplyAnimateDiffModelWithPIA": "Apply AnimateDiff-PIA Model πŸŽ­πŸ…πŸ…“β‘‘", + "ADE_InputPIA_Multival": "PIA Input [Multival] πŸŽ­πŸ…πŸ…“β‘‘", + "ADE_InputPIA_PaperPresets": "PIA Input [Paper Presets] πŸŽ­πŸ…πŸ…“β‘‘", + "ADE_PIA_AnimateDiffKeyframe": "AnimateDiff-PIA Keyframe πŸŽ­πŸ…πŸ…“", + "ADE_InjectPIAIntoAnimateDiffModel": "πŸ§ͺInject PIA into AnimateDiff Model πŸŽ­πŸ…πŸ…“β‘‘", # Deprecated Nodes "AnimateDiffLoaderV1": "🚫AnimateDiff Loader [DEPRECATED] πŸŽ­πŸ…πŸ…“", "ADE_AnimateDiffLoaderV1Advanced": "🚫AnimateDiff Loader (Advanced) [DEPRECATED] πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index 3b9866a..f06e819 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -7,7 +7,7 @@ from .context import ContextOptionsGroup from .logger import logger from .utils_model import BIGMAX, BetaSchedules, get_available_motion_models -from .utils_motion import ADKeyframeGroup, ADKeyframe +from .utils_motion import ADKeyframeGroup, ADKeyframe, InputPIA from .motion_lora import MotionLoraList from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2) @@ -203,13 +203,14 @@ def INPUT_TYPES(s): def load_keyframe(self, start_percent: float, prev_ad_keyframes=None, scale_multival: Union[float, torch.Tensor]=None, effect_multival: Union[float, torch.Tensor]=None, - cameractrl_multival: Union[float, torch.Tensor]=None, + cameractrl_multival: Union[float, torch.Tensor]=None, pia_input: InputPIA=None, inherit_missing: bool=True, guarantee_steps: int=1): if not prev_ad_keyframes: prev_ad_keyframes = ADKeyframeGroup() prev_ad_keyframes = prev_ad_keyframes.clone() keyframe = ADKeyframe(start_percent=start_percent, - scale_multival=scale_multival, effect_multival=effect_multival, cameractrl_multival=cameractrl_multival, + scale_multival=scale_multival, effect_multival=effect_multival, + cameractrl_multival=cameractrl_multival, pia_input=pia_input, inherit_missing=inherit_missing, guarantee_steps=guarantee_steps) prev_ad_keyframes.add(keyframe) return (prev_ad_keyframes,) diff --git a/animatediff/nodes_multival.py b/animatediff/nodes_multival.py index 0abfe28..25d1098 100644 --- a/animatediff/nodes_multival.py +++ b/animatediff/nodes_multival.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size +from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size, extend_list_to_batch_size class ScaleType: @@ -40,7 +40,7 @@ def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optiona if mask_optional is not None: if len(float_val) < mask_optional.shape[0]: # copies last entry enough times to match mask shape - float_val = float_val + float_val[-1]*(mask_optional.shape[0]-len(float_val)) + float_val = 
extend_list_to_batch_size(float_val, mask_optional.shape[0]) if mask_optional.shape[0] < len(float_val): mask_optional = extend_to_batch_size(mask_optional, len(float_val)) float_val = float_val[:mask_optional.shape[0]] @@ -84,12 +84,25 @@ def INPUT_TYPES(s): FUNCTION = "create_multival" def create_multival(self, min_float_val: float, max_float_val: float, mask: Tensor, scaling: str=ScaleType.ABSOLUTE): - # TODO: allow min_float_val and max_float_val to be list[float] + lengths = [mask.shape[0]] + iterable_inputs = [False, False] + val_inputs = [min_float_val, max_float_val] if isinstance(min_float_val, Iterable): - raise ValueError(f"min_float_val must be type float (no lists allowed here), not {type(min_float_val).__name__}.") + iterable_inputs[0] = True + val_inputs[0] = list(min_float_val) + lengths.append(len(min_float_val)) if isinstance(max_float_val, Iterable): - raise ValueError(f"max_float_val must be type float (no lists allowed here), not {type(max_float_val).__name__}.") - + iterable_inputs[1] = True + val_inputs[1] = list(max_float_val) + lengths.append(len(max_float_val)) + # make sure mask and any iterable float_vals match max length + max_length = max(lengths) + mask = extend_to_batch_size(mask, max_length) + for i in range(len(iterable_inputs)): + if iterable_inputs[i] == True: + # make sure tensors will match dimensions of mask + val_inputs[i] = torch.tensor(extend_list_to_batch_size(val_inputs[i], max_length)).unsqueeze(-1).unsqueeze(-1) + min_float_val, max_float_val = val_inputs if scaling == ScaleType.ABSOLUTE: mask = linear_conversion(mask.clone(), new_min=min_float_val, new_max=max_float_val) elif scaling == ScaleType.RELATIVE: @@ -134,3 +147,26 @@ def INPUT_TYPES(s): def create_multival(self, float_val: Union[float, list[float]]=None): return MultivalDynamicNode.create_multival(self, float_val=float_val) + + +class MultivalConvertToMaskNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "multival": ("MULTIVAL",) + } + } + + RETURN_TYPES = ("MASK",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/multival" + FUNCTION = "convert_multival_to_mask" + + def convert_multival_to_mask(self, multival: Union[float, Tensor]): + # if already tensor, assume is a valid mask + if type(multival) == Tensor: + return (multival,) + # otherwise, make a single 1x1 mask with the proper value + shape = (1,1,1) + converted_multival = torch.ones(shape) * multival + return (converted_multival,) diff --git a/animatediff/nodes_pia.py b/animatediff/nodes_pia.py new file mode 100644 index 0000000..fc8024d --- /dev/null +++ b/animatediff/nodes_pia.py @@ -0,0 +1,264 @@ +from typing import Union +import torch +from torch import Tensor +import math + +from comfy.sd import VAE + +from .ad_settings import AnimateDiffSettings +from .logger import logger +from .utils_model import BIGMIN, BIGMAX, get_available_motion_models +from .utils_motion import ADKeyframeGroup, InputPIA, InputPIA_Multival, extend_list_to_batch_size, extend_to_batch_size, prepare_mask_batch +from .motion_lora import MotionLoraList +from .model_injection import MotionModelGroup, MotionModelPatcher, load_motion_module_gen2, inject_pia_conv_in_into_model +from .motion_module_ad import AnimateDiffFormat +from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode + + +# Preset values ported over from PIA repository: +# https://github.com/open-mmlab/PIA/blob/main/animatediff/utils/util.py +class PIA_RANGES: + ANIMATION_SMALL = "Animation (Small Motion)" + ANIMATION_MEDIUM = "Animation (Medium Motion)" + 
ANIMATION_LARGE = "Animation (Large Motion)" + LOOP_SMALL = "Loop (Small Motion)" + LOOP_MEDIUM = "Loop (Medium Motion)" + LOOP_LARGE = "Loop (Large Motion)" + STYLE_TRANSFER_SMALL = "Style Transfer (Small Motion)" + STYLE_TRANSFER_MEDIUM = "Style Transfer (Medium Motion)" + STYLE_TRANSFER_LARGE = "Style Transfer (Large Motion)" + + _LOOPED = [LOOP_SMALL, LOOP_MEDIUM, LOOP_LARGE] + _LIST_ALL = [ANIMATION_SMALL, ANIMATION_MEDIUM, ANIMATION_LARGE, + LOOP_SMALL, LOOP_MEDIUM, LOOP_LARGE, + STYLE_TRANSFER_SMALL, STYLE_TRANSFER_MEDIUM, STYLE_TRANSFER_LARGE] + + _MAPPING = { + ANIMATION_SMALL: [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], + ANIMATION_MEDIUM: [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], + ANIMATION_LARGE: [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], + LOOP_SMALL: [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], + LOOP_MEDIUM: [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], + LOOP_LARGE: [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5], + STYLE_TRANSFER_SMALL: [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], + STYLE_TRANSFER_MEDIUM: [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], + STYLE_TRANSFER_LARGE: [0.5, 0.2], + } + + @classmethod + def get_preset(cls, preset: str) -> list[float]: + if preset in cls._MAPPING: + return cls._MAPPING[preset] + raise Exception(f"PIA Preset '{preset}' is not recognized.") + + @classmethod + def is_looped(cls, preset: str) -> bool: + return preset in cls._LOOPED + + class InputPIA_PaperPresets(InputPIA): + def __init__(self, preset: str, index: int, mult_multival: Union[float, Tensor]=None, effect_multival: Union[float, Tensor]=None): + super().__init__(effect_multival=effect_multival) + self.preset = preset + self.index = index + self.mult_multival = mult_multival if mult_multival is not None else 1.0 + + def get_mask(self, x: Tensor): + b, c, h, w = x.shape + values = PIA_RANGES.get_preset(self.preset) + # if preset is looped, make values loop + if PIA_RANGES.is_looped(self.preset): + # even length + if b % 2 == 0: + # extend to half length to get half of the loop + values = extend_list_to_batch_size(values, b // 2) + # apply second half of loop (just reverse it) + values += list(reversed(values)) + # odd length + else: + inter_values = extend_list_to_batch_size(values, b // 2) + middle_vals = [values[min(len(inter_values), len(values)-1)]] + # make middle vals long enough to fill in gaps (or none if not needed) + middle_vals = middle_vals * (max(0, b-2*len(inter_values))) + values = inter_values + middle_vals + list(reversed(inter_values)) + # otherwise, just extend values to desired length + else: + values = extend_list_to_batch_size(values, b) + assert len(values) == b + + index = self.index + # handle negative index + if index < 0: + index = b + index + # constrain index between 0 and b-1 + index = max(0, min(b-1, index)) + # center values around target index + order = [abs(i - index) for i in range(b)] + real_values = [values[order[i]] for i in range(b)] + # using real values, generate masks + tensor_values = torch.tensor(real_values).unsqueeze(-1).unsqueeze(-1) + mask = torch.ones(size=(b, h, w)) * tensor_values + # apply mult_multival to mask + if type(self.mult_multival) == Tensor or not math.isclose(self.mult_multival, 1.0): + real_mult = self.mult_multival + if type(real_mult) == Tensor: + real_mult = extend_to_batch_size(prepare_mask_batch(real_mult, x.shape), b).squeeze(1) + mask = mask * real_mult + return mask + + class ApplyAnimateDiffPIAModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "motion_model": ("MOTION_MODEL_ADE",), + "image": 
("IMAGE",), + "vae": ("VAE",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "optional": { + "pia_input": ("PIA_INPUT",), + "motion_lora": ("MOTION_LORA",), + "scale_multival": ("MULTIVAL",), + "effect_multival": ("MULTIVAL",), + "ad_keyframes": ("AD_KEYFRAMES",), + "prev_m_models": ("M_MODELS",), + } + } + + RETURN_TYPES = ("M_MODELS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA" + FUNCTION = "apply_motion_model" + + def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, vae: VAE, + start_percent: float=0.0, end_percent: float=1.0, pia_input: InputPIA=None, + motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, + scale_multival=None, effect_multival=None, ref_multival=None, + prev_m_models: MotionModelGroup=None,): + new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, + motion_lora=motion_lora, ad_keyframes=ad_keyframes, + scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) + # most recent added model will always be first in list; + curr_model = new_m_models[0].models[0] + # confirm that model is PIA + if curr_model.model.mm_info.mm_format != AnimateDiffFormat.PIA: + raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' is not a PIA model; cannot be used with Apply AnimateDiff-PIA Model node.") + curr_model.orig_pia_images = image + curr_model.pia_vae = vae + if pia_input is None: + pia_input = InputPIA_Multival(1.0) + curr_model.pia_input = pia_input + #curr_model.pia_multival = ref_multival + return new_m_models + + +class LoadAnimateDiffAndInjectPIANode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (get_available_motion_models(),), + "motion_model": ("MOTION_MODEL_ADE",), + }, + "optional": { + "ad_settings": ("AD_SETTINGS",), + } + } + + RETURN_TYPES = ("MOTION_MODEL_ADE",) + RETURN_NAMES = ("MOTION_MODEL",) + + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA/πŸ§ͺexperimental" + FUNCTION = "load_motion_model" + + def load_motion_model(self, model_name: str, motion_model: MotionModelPatcher, ad_settings: AnimateDiffSettings=None): + # make sure model actually has PIA conv_in + if motion_model.model.conv_in is None: + raise Exception("Passed-in motion model was expected to be PIA (contain conv_in), but did not.") + # load motion module and motion settings, if included + loaded_motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings) + inject_pia_conv_in_into_model(motion_model=loaded_motion_model, w_pia=motion_model) + return (loaded_motion_model,) + + +class PIA_ADKeyframeNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ), + }, + "optional": { + "prev_ad_keyframes": ("AD_KEYFRAMES", ), + "scale_multival": ("MULTIVAL",), + "effect_multival": ("MULTIVAL",), + "pia_input": ("PIA_INPUT",), + "inherit_missing": ("BOOLEAN", {"default": True}, ), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + } + } + + RETURN_TYPES = ("AD_KEYFRAMES", ) + FUNCTION = "load_keyframe" + + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA" + + def load_keyframe(self, + start_percent: float, prev_ad_keyframes=None, + scale_multival: Union[float, 
torch.Tensor]=None, effect_multival: Union[float, torch.Tensor]=None, + pia_input: InputPIA=None, + inherit_missing: bool=True, guarantee_steps: int=1): + return ADKeyframeNode.load_keyframe(self, + start_percent=start_percent, prev_ad_keyframes=prev_ad_keyframes, + scale_multival=scale_multival, effect_multival=effect_multival, pia_input=pia_input, + inherit_missing=inherit_missing, guarantee_steps=guarantee_steps + ) + + class InputPIA_MultivalNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "multival": ("MULTIVAL",), + }, + # "optional": { + # "effect_multival": ("MULTIVAL",), + # } + } + + RETURN_TYPES = ("PIA_INPUT",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA" + FUNCTION = "create_pia_input" + + def create_pia_input(self, multival: Union[float, Tensor], effect_multival: Union[float, Tensor]=None): + return (InputPIA_Multival(multival, effect_multival),) + + class InputPIA_PaperPresetsNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "preset": (PIA_RANGES._LIST_ALL,), + "batch_index": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}), + }, + "optional": { + "mult_multival": ("MULTIVAL",), + "print_values": ("BOOLEAN", {"default": False},), + #"effect_multival": ("MULTIVAL",), + } + } + + RETURN_TYPES = ("PIA_INPUT",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA" + FUNCTION = "create_pia_input" + + def create_pia_input(self, preset: str, batch_index: int, mult_multival: Union[float, Tensor]=None, print_values: bool=False, effect_multival: Union[float, Tensor]=None): + # verify preset exists - function will throw an error if it does not + values = PIA_RANGES.get_preset(preset) + if print_values: + logger.info(f"PIA Preset '{preset}': {values}") + return (InputPIA_PaperPresets(preset=preset, index=batch_index, mult_multival=mult_multival, effect_multival=effect_multival),) diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index e43e55a..0744c66 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -1,11 +1,14 @@ from typing import Union from torch import Tensor +from comfy.sd import VAE + from .freeinit import FreeInitFilter from .sample_settings import (FreeInitOptions, IterationOptions, NoiseLayerAdd, NoiseLayerAddWeighted, NoiseLayerGroup, NoiseLayerReplace, NoiseLayerType, - SeedNoiseGeneration, SampleSettings, CustomCFGKeyframeGroup, CustomCFGKeyframe) -from .utils_model import BIGMIN, BIGMAX, SigmaSchedule + SeedNoiseGeneration, SampleSettings, CustomCFGKeyframeGroup, CustomCFGKeyframe, + NoisedImageToInjectGroup, NoisedImageToInject, NoisedImageInjectOptions) +from .utils_model import BIGMIN, BIGMAX, MAX_RESOLUTION, SigmaSchedule class SampleSettingsNode: @@ -25,6 +28,7 @@ def INPUT_TYPES(s): "adapt_denoise_steps": ("BOOLEAN", {"default": False},), "custom_cfg": ("CUSTOM_CFG",), "sigma_schedule": ("SIGMA_SCHEDULE",), + "image_inject": ("IMAGE_INJECT",), } } @@ -35,10 +39,10 @@ def INPUT_TYPES(s): def create_settings(self, batch_offset: int, noise_type: str, seed_gen: str, seed_offset: int, noise_layers: NoiseLayerGroup=None, iteration_opts: IterationOptions=None, seed_override: int=None, adapt_denoise_steps=False, - custom_cfg: CustomCFGKeyframeGroup=None, sigma_schedule: SigmaSchedule=None): + custom_cfg: CustomCFGKeyframeGroup=None, sigma_schedule: SigmaSchedule=None, image_inject: NoisedImageToInjectGroup=None): sampling_settings = SampleSettings(batch_offset=batch_offset, noise_type=noise_type, seed_gen=seed_gen, seed_offset=seed_offset, 
noise_layers=noise_layers, iteration_opts=iteration_opts, seed_override=seed_override, adapt_denoise_steps=adapt_denoise_steps, - custom_cfg=custom_cfg, sigma_schedule=sigma_schedule) + custom_cfg=custom_cfg, sigma_schedule=sigma_schedule, image_injection=image_inject) return (sampling_settings,) @@ -253,3 +257,61 @@ def create_custom_cfg(self, cfg_multival: Union[float, Tensor], start_percent: f keyframe = CustomCFGKeyframe(cfg_multival=cfg_multival, start_percent=start_percent, guarantee_steps=guarantee_steps) prev_custom_cfg.add(keyframe) return (prev_custom_cfg,) + + + class NoisedImageInjectionNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE", ), + "vae": ("VAE", ), + }, + "optional": { + "mask_opt": ("MASK", ), + "invert_mask": ("BOOLEAN", {"default": False}), + "resize_image": ("BOOLEAN", {"default": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 1, "max": BIGMAX}), + "img_inject_opts": ("IMAGE_INJECT_OPTIONS", ), + "strength_multival": ("MULTIVAL", ), + "prev_image_inject": ("IMAGE_INJECT", ), + } + } + + RETURN_TYPES = ("IMAGE_INJECT",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/sample settings/image inject" + FUNCTION = "create_image_inject" + + def create_image_inject(self, image: Tensor, vae: VAE, invert_mask: bool, resize_image: bool, start_percent: float, + mask_opt: Tensor=None, strength_multival: Union[float, Tensor]=None, prev_image_inject: NoisedImageToInjectGroup=None, guarantee_steps=1, + img_inject_opts=None): + if not prev_image_inject: + prev_image_inject = NoisedImageToInjectGroup() + prev_image_inject = prev_image_inject.clone() + to_inject = NoisedImageToInject(image=image, mask=mask_opt, vae=vae, invert_mask=invert_mask, resize_image=resize_image, strength_multival=strength_multival, + start_percent=start_percent, guarantee_steps=guarantee_steps, + img_inject_opts=img_inject_opts) + prev_image_inject.add(to_inject) + return (prev_image_inject,) + + + class NoisedImageInjectOptionsNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "composite_x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "composite_y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + } + } + + RETURN_TYPES = ("IMAGE_INJECT_OPTIONS",) + RETURN_NAMES = ("IMG_INJECT_OPTS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/sample settings/image inject" + FUNCTION = "create_image_inject_opts" + + def create_image_inject_opts(self, composite_x=0, composite_y=0): + return (NoisedImageInjectOptions(x=composite_x, y=composite_y),) diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index f46d907..818bba7 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -5,8 +5,10 @@ import comfy.sample import comfy.samplers +import comfy.model_management from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel +from comfy.sd import VAE from . 
import freeinit from .conditioning import LoraHookMode @@ -54,7 +56,7 @@ class NoiseNormalize: class SampleSettings: def __init__(self, batch_offset: int=0, noise_type: str=None, seed_gen: str=None, seed_offset: int=0, noise_layers: 'NoiseLayerGroup'=None, iteration_opts=None, seed_override:int=None, negative_cond_flipflop=False, adapt_denoise_steps: bool=False, - custom_cfg: 'CustomCFGKeyframeGroup'=None, sigma_schedule: SigmaSchedule=None): + custom_cfg: 'CustomCFGKeyframeGroup'=None, sigma_schedule: SigmaSchedule=None, image_injection: 'NoisedImageToInjectGroup'=None): self.batch_offset = batch_offset self.noise_type = noise_type if noise_type is not None else NoiseLayerType.DEFAULT self.seed_gen = seed_gen if seed_gen is not None else SeedNoiseGeneration.COMFY @@ -66,6 +68,7 @@ def __init__(self, batch_offset: int=0, noise_type: str=None, seed_gen: str=None self.adapt_denoise_steps = adapt_denoise_steps self.custom_cfg = custom_cfg.clone() if custom_cfg else custom_cfg self.sigma_schedule = sigma_schedule + self.image_injection = image_injection.clone() if image_injection else NoisedImageToInjectGroup() def prepare_noise(self, seed: int, latents: Tensor, noise: Tensor, extra_seed_offset=0, extra_args:dict={}, force_create_noise=True): if self.seed_override is not None: @@ -93,15 +96,20 @@ def prepare_noise(self, seed: int, latents: Tensor, noise: Tensor, extra_seed_of def pre_run(self, model: ModelPatcher): if self.custom_cfg is not None: self.custom_cfg.reset() + if self.image_injection is not None: + self.image_injection.reset() def cleanup(self): if self.custom_cfg is not None: self.custom_cfg.reset() + if self.image_injection is not None: + self.image_injection.reset() def clone(self): return SampleSettings(batch_offset=self.batch_offset, noise_type=self.noise_type, seed_gen=self.seed_gen, seed_offset=self.seed_offset, noise_layers=self.noise_layers.clone(), iteration_opts=self.iteration_opts, seed_override=self.seed_override, - negative_cond_flipflop=self.negative_cond_flipflop, adapt_denoise_steps=self.adapt_denoise_steps, custom_cfg=self.custom_cfg, sigma_schedule=self.sigma_schedule) + negative_cond_flipflop=self.negative_cond_flipflop, adapt_denoise_steps=self.adapt_denoise_steps, custom_cfg=self.custom_cfg, + sigma_schedule=self.sigma_schedule, image_injection=self.image_injection) class NoiseLayer: @@ -554,3 +562,166 @@ def cfg_multival(self): if self._current_keyframe != None: return self._current_keyframe.cfg_multival return None + + +class NoisedImageInjectOptions: + def __init__(self, x=0, y=0): + self.x = x + self.y = y + + def clone(self): + return NoisedImageInjectOptions(x=self.x, y=self.y) + + +class NoisedImageToInject: + def __init__(self, image: Tensor, mask: Tensor, vae: VAE, start_percent: float, guarantee_steps: int=1, + invert_mask=False, resize_image=True, strength_multival=None, + img_inject_opts: NoisedImageInjectOptions=None): + self.image = image + self.mask = mask + self.vae = vae + self.invert_mask = invert_mask + self.resize_image = resize_image + self.strength_multival = 1.0 if strength_multival is None else strength_multival + if img_inject_opts is None: + img_inject_opts = NoisedImageInjectOptions() + self.img_inject_opts = img_inject_opts + # scheduling + self.start_percent = float(start_percent) + self.start_t = 999999999.9 + self.start_timestep = 999 + self.guarantee_steps = guarantee_steps + + def clone(self): + cloned = NoisedImageToInject(image=self.image, vae=self.vae, start_percent=self.start_percent, + guarantee_steps=self.guarantee_steps, 
invert_mask=self.invert_mask, resize_image=self.resize_image, + mask=self.mask, strength_multival=self.strength_multival, img_inject_opts=self.img_inject_opts) + cloned.start_t = self.start_t + cloned.start_timestep = self.start_timestep + return cloned + + + class NoisedImageToInjectGroup: + def __init__(self): + self.injections: list[NoisedImageToInject] = [] + self._current_index: int = -1 + self._current_used_steps: int = 0 + + @property + def current_injection(self): + return self.injections[self._current_index] + + def reset(self): + self._current_index = -1 + self._current_used_steps: int = 0 + + def add(self, to_inject: NoisedImageToInject): + # add to end of list, then sort + self.injections.append(to_inject) + self.injections = get_sorted_list_via_attr(self.injections, "start_percent") + + def is_empty(self) -> bool: + return len(self.injections) == 0 + + def has_index(self, index: int) -> bool: + return index >= 0 and index < len(self.injections) + + def clone(self): + cloned = NoisedImageToInjectGroup() + for to_inject in self.injections: + cloned.injections.append(to_inject) + return cloned + + def initialize_timesteps(self, model: BaseModel): + for to_inject in self.injections: + to_inject.start_t = model.model_sampling.percent_to_sigma(to_inject.start_percent) + to_inject.start_timestep = model.model_sampling.timestep(torch.tensor(to_inject.start_t)) + + def ksampler_get_injections(self, model: ModelPatcher, scheduler: str, sampler_name: str, denoise: float, force_full_denoise: bool, start_step: int, last_step: int, total_steps: int) -> tuple[list[list[int]], list[NoisedImageToInject]]: + actual_last_step = min(last_step, total_steps) + steps = list(range(start_step, actual_last_step+1)) + # create sampler that will be used to get sigmas + sampler = comfy.samplers.KSampler(model, steps=total_steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) + # replicate KSampler.sample function to get the exact sigmas + sigmas = sampler.sigmas + if last_step is not None and last_step < (len(sigmas) - 1): + sigmas = sigmas[:last_step + 1] + if force_full_denoise: + sigmas[-1] = 0 + if start_step is not None: + if start_step < (len(sigmas) - 1): + sigmas = sigmas[start_step:] + else: + return [[start_step, actual_last_step], []] + assert len(steps) == len(sigmas) + model_sampling = model.get_model_object("model_sampling") + timesteps = [model_sampling.timestep(x) for x in sigmas] + # get actual ranges + injections + ranges, injections = self._prepare_injections(timesteps=timesteps) + # ranges are given with end-exclusive index, so subtract by 1 to get real step value + steps_list = [[steps[x[0]], steps[x[1]-1]] for x in ranges] + return steps_list, injections + + def custom_ksampler_get_injections(self, model: ModelPatcher, sigmas: Tensor) -> tuple[list[list[Tensor]], list[NoisedImageToInject]]: + model_sampling = model.get_model_object("model_sampling") + timesteps = [] + for i in range(sigmas.shape[0]): + timesteps.append(model_sampling.timestep(sigmas[i])) + # get actual ranges + injections + ranges, injections = self._prepare_injections(timesteps=timesteps) + sigmas_list = [sigmas[x[0]:x[1]] for x in ranges] + return sigmas_list, injections + + def _prepare_injections(self, timesteps: list[Tensor]) -> tuple[list[list[Tensor]], list[NoisedImageToInject]]: + range_start = timesteps[0] + range_end = timesteps[-1] + # if nothing to inject, return all indexes of timesteps and no injections + if self.is_empty(): + return ([(0, len(timesteps))], []) + # otherwise, need to populate 
lists + timesteps_list: list[list[Tensor]] = [] + injection_list: list[NoisedImageToInject] = [] + remaining_timesteps = timesteps.copy() + remaining_offset = 0 + # NOTE: timesteps start at 999 and end at 0; the smaller the timestep, the 'later' the step + for eval_c in self.injections: + if len(remaining_timesteps) <= 2: + break + current_used_steps = 0 + # if start_timestep is greater than range_start, ignore it + if eval_c.start_timestep > range_start: + continue + # if start_timestep is less than range_end, ignore it + if eval_c.start_timestep < range_end: + continue + while current_used_steps < eval_c.guarantee_steps: + if len(remaining_timesteps) <= 2: + break + # otherwise, make a split in timesteps + broken_nicely = False + for i in range(1, len(remaining_timesteps)-1): + # if smaller than timestep, look at next timestep + if eval_c.start_timestep < remaining_timesteps[i]: + continue + # if only one timestep would be leftover, then end + if len(remaining_timesteps[i:]) < 2: + broken_nicely = True + break + new_timestep_range = (remaining_offset, remaining_offset+i+1) + timesteps_list.append(new_timestep_range) + injection_list.append(eval_c) + current_used_steps += 1 + remaining_timesteps = remaining_timesteps[i:] + remaining_offset += i + # expected break + broken_nicely = True + break + # did not find a match for the timestep, so should break out of while loop + if not broken_nicely: + break + + # add remaining timestep range + timesteps_list.append((remaining_offset, remaining_offset+len(remaining_timesteps))) + # return lists - timesteps list len should be one greater than injection list len (fenceposts problem) + assert len(timesteps_list) == len(injection_list) + 1 + return timesteps_list, injection_list diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 32aa5ca..fb83362 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -21,12 +21,14 @@ import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel +import comfy.conds import comfy.ops -from .conditioning import COND_CONST, LoraHookGroup +from .conditioning import COND_CONST, LoraHookGroup, conditioning_set_values from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows -from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration -from .utils_model import ModelTypeSD +from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject +from .utils_model import ModelTypeSD, vae_encode_raw_batched, vae_decode_raw_batched +from .utils_motion import composite_extend, get_combined_multival, prepare_mask_batch, extend_to_batch_size from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule from .logger import logger @@ -41,6 +43,8 @@ def __init__(self): self.motion_models: MotionModelGroup = None self.params: InjectionParams = None self.sample_settings: SampleSettings = None + self.callback_output_dict: dict[str] = {} + self.function_injections: FunctionInjectionHolder = None self.reset() def initialize(self, model: BaseModel): @@ -63,9 +67,9 @@ def hooks_initialize(self, model: BaseModel, hook_groups: list[LoraHookGroup]): hook.reset() hook.initialize_timesteps(model) - def prepare_current_keyframes(self, timestep: Tensor): + def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if 
diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 32aa5ca..fb83362 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -21,12 +21,14 @@ import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel +import comfy.conds import comfy.ops -from .conditioning import COND_CONST, LoraHookGroup +from .conditioning import COND_CONST, LoraHookGroup, conditioning_set_values from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows -from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration -from .utils_model import ModelTypeSD +from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject +from .utils_model import ModelTypeSD, vae_encode_raw_batched, vae_decode_raw_batched +from .utils_motion import composite_extend, get_combined_multival, prepare_mask_batch, extend_to_batch_size from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule from .logger import logger @@ -41,6 +43,8 @@ def __init__(self): self.motion_models: MotionModelGroup = None self.params: InjectionParams = None self.sample_settings: SampleSettings = None + self.callback_output_dict: dict[str, Tensor] = {} + self.function_injections: FunctionInjectionHolder = None self.reset() def initialize(self, model: BaseModel): @@ -63,9 +67,9 @@ def hooks_initialize(self, model: BaseModel, hook_groups: list[LoraHookGroup]): hook.reset() hook.initialize_timesteps(model) - def prepare_current_keyframes(self, timestep: Tensor): + def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if self.motion_models is not None: - self.motion_models.prepare_current_keyframe(t=timestep) + self.motion_models.prepare_current_keyframe(x=x, t=timestep) if self.params.context_options is not None: self.params.context_options.prepare_current_context(t=timestep) if self.sample_settings.custom_cfg is not None: @@ -75,6 +79,24 @@ def prepare_hooks_current_keyframes(self, timestep: Tensor, hook_groups: list[Lo if self.model_patcher is not None: self.model_patcher.prepare_hooked_patches_current_keyframe(t=timestep, hook_groups=hook_groups) + def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor): + if self.motion_models is not None: + pia_models = self.motion_models.get_pia_models() + if len(pia_models) > 0: + for pia_model in pia_models: + if pia_model.model.is_in_effect(): + pia_model.model.inject_unet_conv_in_pia(model) + conds = get_conds_with_c_concat(conds, + pia_model.get_pia_c_concat(model, x_in)) + return conds + + def restore_special_model_features(self, model: BaseModel): + if self.motion_models is not None: + pia_models = self.motion_models.get_pia_models() + if len(pia_models) > 0: + for pia_model in reversed(pia_models): + pia_model.model.restore_unet_conv_in_pia(model) + def reset(self): self.initialized = False self.hooks_initialized = False @@ -82,6 +104,8 @@ def reset(self): self.last_step: int = 0 self.current_step: int = 0 self.total_steps: int = 0 + self.callback_output_dict.clear() + self.callback_output_dict = {} if self.model_patcher is not None: self.model_patcher.clean_hooks() del self.model_patcher @@ -95,6 +119,9 @@ def reset(self): if self.sample_settings is not None: del self.sample_settings self.sample_settings = None + if self.function_injections is not None: + del self.function_injections + self.function_injections = None def update_with_inject_params(self, params: InjectionParams): self.params = params @@ -196,6 +223,14 @@ def apply_model_ade_wrapper(self, *args, **kwargs): del x return orig_apply_model(*args, **kwargs) return apply_model_ade_wrapper + +def diffusion_model_forward_groupnormed_factory(orig_diffusion_model_forward: Callable, inject_helper: 'GroupnormInjectHelper'): + def diffusion_model_forward_groupnormed(*args, **kwargs): + with inject_helper: + return orig_diffusion_model_forward(*args, **kwargs) + return diffusion_model_forward_groupnormed + + ###################################################################### ################################################################################## @@ -242,7 +277,8 @@ def apply_params_to_motion_models(motion_models: MotionModelGroup, params: Injec class FunctionInjectionHolder: def __init__(self): - pass + self.temp_uninjector: GroupnormUninjectHelper = GroupnormUninjectHelper() + self.groupnorm_injector: GroupnormInjectHelper = GroupnormInjectHelper() def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionParams): # Save Original Functions - order must match between here and restore_functions @@ -250,6 +286,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights + self.orig_diffusion_model_forward = model.model.diffusion_model.forward
self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult if SAMPLE_FALLBACK: # for backwards compatibility, for now @@ -262,12 +299,15 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara if params.unlimited_area_hack: model.model.memory_required = unlimited_memory_required if model.motion_models is not None: - # only apply groupnorm hack if not [v3 or ([not Hotshot] and SD1.5 and v2 and apply_v2_properly)] + # only apply groupnorm hack if PIA, v2 and not properly applied, or v1 info: AnimateDiffInfo = model.motion_models[0].model.mm_info - if not (info.mm_version == AnimateDiffVersion.V3 or - (info.mm_format not in [AnimateDiffFormat.HOTSHOTXL] and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_properly)): - torch.nn.GroupNorm.forward = groupnorm_mm_factory(params) - comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) + if ((info.mm_format == AnimateDiffFormat.PIA) or + (info.mm_version == AnimateDiffVersion.V2 and not params.apply_v2_properly) or + (info.mm_version == AnimateDiffVersion.V1)): + self.inject_groupnorm_forward = groupnorm_mm_factory(params) + self.inject_groupnorm_forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) + self.groupnorm_injector = GroupnormInjectHelper(self) + model.model.diffusion_model.forward = diffusion_model_forward_groupnormed_factory(self.orig_diffusion_model_forward, self.groupnorm_injector) # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack try: if model.load_device.type == "mps": @@ -286,6 +326,8 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara comfy.sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) else: comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) + # create temp_uninjector to help facilitate uninjecting functions + self.temp_uninjector = GroupnormUninjectHelper(self) def restore_functions(self, model: ModelPatcherAndInjector): # Restoration @@ -294,6 +336,7 @@ def restore_functions(self, model: ModelPatcherAndInjector): openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights + model.model.diffusion_model.forward = self.orig_diffusion_model_forward comfy.samplers.sampling_function = self.orig_sampling_function comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult if SAMPLE_FALLBACK: # for backwards compatibility, for now @@ -306,6 +349,60 @@ def restore_functions(self, model: ModelPatcherAndInjector): "to save original functions before injection, and a more specific error was thrown by ComfyUI.")
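The two helper classes added below are thin context managers around the same save/swap/restore idea. A condensed sketch of the pattern, using a hypothetical generic `patched_attr` helper written for this note:

```python
import contextlib

# Condensed sketch (hypothetical helper, not part of the patch) of the
# pattern behind GroupnormInjectHelper and GroupnormUninjectHelper:
# __enter__ remembers whatever function is currently active and swaps in
# a replacement; __exit__ puts the remembered function back.
@contextlib.contextmanager
def patched_attr(owner, name, replacement):
    previous = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        # restore the *previous* value, not a fixed global original,
        # which is what lets nested patching unwind correctly
        setattr(owner, name, previous)

# e.g. with patched_attr(torch.nn.GroupNorm, "forward", hacked_forward): ...
```

Because `__exit__` restores the previously active function rather than a fixed original, the uninject helper can temporarily strip the groupnorm hack even while the inject helper is wrapped around the UNet forward, and the two can nest safely.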
+ +class GroupnormUninjectHelper: + def __init__(self, holder: FunctionInjectionHolder=None): + self.holder = holder + self.previous_gn_forward = None + self.previous_dwi_gn_cast_weights = None + + def __enter__(self): + if self.holder is None: + return self + # backup current groupnorm funcs + self.previous_gn_forward = torch.nn.GroupNorm.forward + self.previous_dwi_gn_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights + # restore groupnorm to default state + torch.nn.GroupNorm.forward = self.holder.orig_groupnorm_forward + comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.holder.orig_groupnorm_forward_comfy_cast_weights + return self + + def __exit__(self, *args, **kwargs): + if self.holder is None: + return + # bring groupnorm back to previous state + torch.nn.GroupNorm.forward = self.previous_gn_forward + comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.previous_dwi_gn_cast_weights + self.previous_gn_forward = None + self.previous_dwi_gn_cast_weights = None + + +class GroupnormInjectHelper: + def __init__(self, holder: FunctionInjectionHolder=None): + self.holder = holder + self.previous_gn_forward = None + self.previous_dwi_gn_cast_weights = None + + def __enter__(self): + if self.holder is None: + return self + # store previous gn_forward + self.previous_gn_forward = torch.nn.GroupNorm.forward + self.previous_dwi_gn_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights + # inject groupnorm functions + torch.nn.GroupNorm.forward = self.holder.inject_groupnorm_forward + comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.holder.inject_groupnorm_forward_comfy_cast_weights + return self + + def __exit__(self, *args, **kwargs): + if self.holder is None: + return + # bring groupnorm back to previous state + torch.nn.GroupNorm.forward = self.previous_gn_forward + comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.previous_dwi_gn_cast_weights + self.previous_gn_forward = None + self.previous_dwi_gn_cast_weights = None + + def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable: def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs): # check if model is intended for injecting @@ -349,12 +446,16 @@ def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs def ad_callback(step, x0, x, total_steps): if original_callback is not None: original_callback(step, x0, x, total_steps) + # store denoised latents if image_injection will be used + if not model.sample_settings.image_injection.is_empty(): + ADGS.callback_output_dict["x0"] = x0 # update GLOBALSTATE for next iteration ADGS.current_step = ADGS.start_step + step + 1 kwargs["callback"] = ad_callback ADGS.model_patcher = model ADGS.motion_models = model.motion_models ADGS.sample_settings = model.sample_settings + ADGS.function_injections = function_injections # apply adapt_denoise_steps args = list(args) @@ -416,7 +517,73 @@ def ad_callback(step, x0, x, total_steps): model.motion_models.pre_run(model) if model.sample_settings is not None: model.sample_settings.pre_run(model) - latents = orig_comfy_sample(model, noise, *args, **kwargs) + + if ADGS.sample_settings.image_injection.is_empty(): + latents = orig_comfy_sample(model, noise, *args, **kwargs) + else: + ADGS.sample_settings.image_injection.initialize_timesteps(model.model) + # start from the latent image the sampler was given + latents = args[-1] + # separate handling for KSampler vs Custom KSampler + if is_custom: + sigmas = args[2] + sigmas_list, injection_list = ADGS.sample_settings.image_injection.custom_ksampler_get_injections(model, sigmas) + # useful logging + if len(injection_list) > 0: + inj_str = "s" if len(injection_list) > 1 else "" + logger.info(f"Found {len(injection_list)} applicable image injection{inj_str}; sampling will be split into {len(sigmas_list)} parts.") + else: + logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.")
+ is_first = True + new_noise = noise + for i in range(len(sigmas_list)): + args[2] = sigmas_list[i] + args[-1] = latents + latents = orig_comfy_sample(model, new_noise, *args, **kwargs) + if is_first: + new_noise = torch.zeros_like(latents) + is_first = False + # if injection expected, perform injection + if i < len(injection_list): + to_inject = injection_list[i] + latents = perform_image_injection(model.model, latents, to_inject) + else: + is_ksampler_advanced = kwargs.get("start_step", None) is not None + # force_full_denoise should be respected on final sampling - should be True for normal KSampler + final_force_full_denoise = kwargs.get("force_full_denoise", False) + new_kwargs = kwargs.copy() + if not is_ksampler_advanced: + final_force_full_denoise = True + new_kwargs["start_step"] = 0 + new_kwargs["last_step"] = 10000 + steps_list, injection_list = ADGS.sample_settings.image_injection.ksampler_get_injections(model, scheduler=args[-4], sampler_name=args[-5], denoise=kwargs["denoise"], force_full_denoise=final_force_full_denoise, + start_step=new_kwargs["start_step"], last_step=new_kwargs["last_step"], total_steps=args[0]) + # useful logging + if len(injection_list) > 0: + inj_str = "s" if len(injection_list) > 1 else "" + logger.info(f"Found {len(injection_list)} applicable image injection{inj_str}; sampling will be split into {len(steps_list)} parts.") + else: + logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.") + is_first = True + new_noise = noise + for i in range(len(steps_list)): + steps_range = steps_list[i] + args[-1] = latents + # first run will respect original disable_noise, but should have no effect on anything + # as disable_noise only does something in the functions that call this one + if not is_first: + new_kwargs["disable_noise"] = True + new_kwargs["start_step"] = steps_range[0] + new_kwargs["last_step"] = steps_range[1] + # if is last, respect original sampler's force_full_denoise + if i == len(steps_list)-1: + new_kwargs["force_full_denoise"] = final_force_full_denoise + else: + new_kwargs["force_full_denoise"] = False + latents = orig_comfy_sample(model, new_noise, *args, **new_kwargs) + if is_first: + new_noise = torch.zeros_like(latents) + is_first = False + # if injection expected, perform injection + if i < len(injection_list): + to_inject = injection_list[i] + latents = perform_image_injection(model.model, latents, to_inject) return latents finally: del latents @@ -434,49 +601,104 @@ def ad_callback(step, x0, x, total_steps): return motion_sample
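In both branches above, the sampler is re-invoked once per segment: only the first call consumes real noise, later segments continue from the previous segment's latents, and injections happen between segments (one fewer injection than segments, per the fencepost invariant). A compressed restatement with hypothetical stand-in names (`sample_segment` for `orig_comfy_sample`, `inject` for `perform_image_injection`):

```python
import torch

# Compressed restatement (hypothetical names) of the segmented flow above:
# real noise is used exactly once, and injections run between segments.
def sample_in_segments(latent_image, noise, segments, injections, sample_segment, inject):
    latents = latent_image
    for i, segment in enumerate(segments):
        latents = sample_segment(latents, noise, segment)
        noise = torch.zeros_like(latents)  # later segments must not re-add noise
        if i < len(injections):  # len(segments) == len(injections) + 1
            latents = inject(latents, injections[i])
    return latents
```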
-def evolved_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options: dict={}, seed=None): +def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cond_scale, model_options: dict={}, seed=None): ADGS.initialize(model) - ADGS.prepare_current_keyframes(timestep=timestep) - - # never use cfg1 optimization if using custom_cfg (since can have timesteps and such) - if ADGS.sample_settings.custom_cfg is None and math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: - uncond_ = None - else: - uncond_ = uncond + ADGS.prepare_current_keyframes(x=x, timestep=timestep) + try: + cond, uncond = ADGS.perform_special_model_features(model, [cond, uncond], x) - # add AD/evolved-sampling params to model_options (transformer_options) - model_options = model_options.copy() - if "tranformer_options" not in model_options: - model_options["tranformer_options"] = {} - model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() + # never use cfg1 optimization if using custom_cfg (since can have timesteps and such) + if ADGS.sample_settings.custom_cfg is None and math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None + else: + uncond_ = uncond - if not ADGS.is_using_sliding_context(): - cond_pred, uncond_pred = calc_cond_uncond_batch_wrapper(model, [cond, uncond_], x, timestep, model_options) - else: - cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options) + # add AD/evolved-sampling params to model_options (transformer_options) + model_options = model_options.copy() + if "transformer_options" not in model_options: + model_options["transformer_options"] = {} + model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() - if hasattr(comfy.samplers, "cfg_function"): - try: - cached_calc_cond_batch = comfy.samplers.calc_cond_batch - # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch - comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch) - return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) - finally: - comfy.samplers.calc_cond_batch = cached_calc_cond_batch - else: # for backwards compatibility, for now - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) + if not ADGS.is_using_sliding_context(): + cond_pred, uncond_pred = calc_cond_uncond_batch_wrapper(model, [cond, uncond_], x, timestep, model_options) else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options) - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, - "sigma": timestep, "model_options": model_options, "input": x} - cfg_result = fn(args) + if hasattr(comfy.samplers, "cfg_function"): + try: + cached_calc_cond_batch = comfy.samplers.calc_cond_batch + # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch + comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch) + return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) + finally: + comfy.samplers.calc_cond_batch = cached_calc_cond_batch + else: # for backwards compatibility, for now + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale - return cfg_result + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result + finally: + ADGS.restore_special_model_features(model) + + +def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: NoisedImageToInject) -> Tensor: + # NOTE: the latents here have already been process_latent_out'ed + # get currently used models so they can be properly reloaded after performing VAE Encoding + if hasattr(comfy.model_management, "loaded_models"): + cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + else: + cached_loaded_models: list[ModelPatcherAndInjector] = [x.model for x in comfy.model_management.current_loaded_models] + try: + orig_device = latents.device + orig_dtype = latents.dtype + # follow same steps as in KSampler Custom to get same denoised_x0 value + x0 = ADGS.callback_output_dict.get("x0", None) + if x0 is None: + return latents + # x0 should be process_latent_out'ed to match expected state of latents between nodes + x0 = model.process_latent_out(x0) + + # first, decode x0 into images, and then re-encode + decoded_images = vae_decode_raw_batched(to_inject.vae, x0) + encoded_x0 = vae_encode_raw_batched(to_inject.vae, decoded_images) + + # get difference between sampled latents and encoded_x0 + encoded_x0 = latents - encoded_x0 + + # get mask, or default to full mask + mask = to_inject.mask + b, c, h, w = encoded_x0.shape + # need to resize images and masks to match expected dims + if mask is None: + mask = torch.ones(1, h, w) + if to_inject.invert_mask: + mask = 1.0 - mask + opts = to_inject.img_inject_opts + # composite decoded_x0 with image to inject; + # make sure to move dims to match expectation of (b,c,h,w) + composited = composite_extend(destination=decoded_images.movedim(-1, 1), source=to_inject.image.movedim(-1, 1), x=opts.x, y=opts.y, mask=mask, + multiplier=to_inject.vae.downscale_ratio, resize_source=to_inject.resize_image).movedim(1, -1) + # encode composited to get latent representation + composited = vae_encode_raw_batched(to_inject.vae, composited) + # add encoded_x0 diff to composited + composited += encoded_x0 + if type(to_inject.strength_multival) == float and math.isclose(1.0, to_inject.strength_multival): + return composited.to(dtype=orig_dtype, device=orig_device) + strength = to_inject.strength_multival + if type(strength) == Tensor: + strength = extend_to_batch_size(prepare_mask_batch(strength, composited.shape), b) + return composited * strength + latents * (1.0 - strength) + finally: + comfy.model_management.load_models_gpu(cached_loaded_models)
def wrapped_cfg_sliding_calc_cond_batch_factory(orig_calc_cond_batch): @@ -627,6 +849,30 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list return conds_final +def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseShape): + new_conds = [] + for cond in conds: + resized_cond = None + if cond is not None: + # reuse or resize cond items to match context requirements + resized_cond = [] + # cond object is a list containing a dict - outer list is irrelevant, so just loop through it + for actual_cond in cond: + resized_actual_cond = actual_cond.copy() + # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary + for key in actual_cond: + if key == "model_conds": + new_model_conds = actual_cond[key].copy() + if "c_concat" in new_model_conds: + new_model_conds["c_concat"] = comfy.conds.CONDNoiseShape(torch.cat([new_model_conds["c_concat"].cond, c_concat.cond], dim=1)) + else: + new_model_conds["c_concat"] = c_concat + resized_actual_cond[key] = new_model_conds + resized_cond.append(resized_actual_cond) + new_conds.append(resized_cond) + return new_conds + + def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options): # check if conds or unconds contain lora_hook or default_cond contains_lora_hooks = False diff --git a/animatediff/utils_model.py b/animatediff/utils_model.py index d52f885..c59af31 100644 --- a/animatediff/utils_model.py +++ b/animatediff/utils_model.py @@ -5,6 +5,7 @@ from time import time import copy +from torch import Tensor import torch import numpy as np @@ -12,14 +13,60 @@ from comfy.model_base import SD21UNCLIP, SDXL, BaseModel, SDXLRefiner, SVD_img2vid, model_sampling, ModelType from comfy.model_management import xformers_enabled from comfy.model_patcher import ModelPatcher +from comfy.sd import VAE +from comfy.utils import ProgressBar import comfy.model_sampling import comfy_extras.nodes_model_advanced +from .logger import logger BIGMIN = -(2**53-1) BIGMAX = (2**53-1) +MAX_RESOLUTION = 16384 # mirrors ComfyUI's nodes.py MAX_RESOLUTION + + +def vae_encode_raw_dynamic_batched(vae: VAE, pixels: Tensor, max_batch=16, min_batch=1, max_size=512*512, show_pbar=False): + b, h, w, c = pixels.shape + actual_size = h*w + actual_batch_size = int(max(min_batch, min(max_batch, max_batch // max((actual_size / max_size), 1.0)))) + logger.debug(f"actual_batch_size: {actual_batch_size}") + return vae_encode_raw_batched(vae=vae, pixels=pixels, per_batch=actual_batch_size, show_pbar=show_pbar) + + +def vae_decode_raw_dynamic_batched(vae: VAE, latents: Tensor, max_batch=16, min_batch=1, max_size=512*512, show_pbar=False): + b, c, h, w = latents.shape + actual_size = (h*vae.downscale_ratio)*(w*vae.downscale_ratio) + actual_batch_size = int(max(min_batch, min(max_batch, max_batch // max((actual_size / max_size), 1.0)))) + return vae_decode_raw_batched(vae=vae, latents=latents, per_batch=actual_batch_size, show_pbar=show_pbar) + + +def vae_encode_raw_batched(vae: VAE, pixels: Tensor, per_batch=16, show_pbar=False): + encoded = [] + pbar = None + if show_pbar: + pbar = ProgressBar(pixels.shape[0]) + for start_idx in range(0, pixels.shape[0], per_batch): + sub_encoded = vae.encode(pixels[start_idx:start_idx+per_batch][:,:,:,:3]) + encoded.append(sub_encoded) + if pbar is not None: + pbar.update(sub_encoded.shape[0]) + return torch.cat(encoded, dim=0) + + +def vae_decode_raw_batched(vae: VAE, latents: Tensor, per_batch=16, show_pbar=False): + decoded = [] + pbar = None + if show_pbar: + pbar = ProgressBar(latents.shape[0]) + for start_idx in range(0, latents.shape[0], per_batch): + sub_decoded = vae.decode(latents[start_idx:start_idx+per_batch]) + decoded.append(sub_decoded) + if pbar is not None: + pbar.update(sub_decoded.shape[0]) + return torch.cat(decoded, dim=0) + class ModelSamplingConfig: def __init__(self, beta_schedule: str, linear_start: float=None, linear_end: float=None): @@ -338,6 +385,8 @@ class ModelTypeSD: SDXL_REFINER = "SDXL_Refiner" SVD = "SVD" + _LIST = [SD1_5, SD2_1, SDXL, SDXL_REFINER, SVD] + def get_sd_model_type(model: ModelPatcher) -> str: if model is None:
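The new utils_model helpers above chunk VAE work along the batch (frame) dimension; the dynamic variants additionally shrink the chunk size as resolution grows (with the default max_size of 512*512, a 1024x1024 input has 4x the reference area, so max_batch=16 drops to 4). A usage sketch, assuming `vae` is an already-loaded ComfyUI VAE object:

```python
import torch
from animatediff.utils_model import vae_encode_raw_batched, vae_decode_raw_batched

# Usage sketch (assumes `vae` is a loaded comfy.sd.VAE): round-trip a long
# frame sequence through the VAE in chunks of 16 to keep VRAM bounded.
pixels = torch.rand(64, 512, 512, 3)  # (frames, H, W, C) image batch
latents = vae_encode_raw_batched(vae, pixels, per_batch=16)    # 4 chunks of 16
restored = vae_decode_raw_batched(vae, latents, per_batch=16)  # back to pixels
assert restored.shape[0] == pixels.shape[0]
```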
diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index a2d7510..fcc205a 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F from torch import Tensor, nn +from abc import ABC, abstractmethod import comfy.model_management as model_management import comfy.ops @@ -96,15 +97,11 @@ def forward(self, input: Tensor) -> Tensor: # applies min-max normalization, from: # https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch -def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0): +def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0): return linear_conversion(x, x_min=x.min(), x_max=x.max(), new_min=new_min, new_max=new_max) def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0): - x_min = float(x_min) - x_max = float(x_max) - new_min = float(new_min) - new_max = float(new_max) return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min @@ -126,6 +123,14 @@ def extend_to_batch_size(tensor: Tensor, batch_size: int): return tensor +def extend_list_to_batch_size(_list: list, batch_size: int): + if len(_list) > batch_size: + return _list[:batch_size] + elif len(_list) < batch_size: + return _list + _list[-1:]*(batch_size-len(_list)) + return _list.copy() + + # from comfy/controlnet.py def ade_broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -146,6 +151,42 @@ def ade_broadcast_image_to(tensor, target_batch_size, batched_number): return torch.cat([tensor] * batched_number, dim=0) +# originally from comfy_extras/nodes_mask.py::composite function +def composite_extend(destination: Tensor, source: Tensor, x: int, y: int, mask: Tensor = None, multiplier = 8, resize_source = False): + source = source.to(destination.device) + if resize_source: + source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") + + source = extend_to_batch_size(source, destination.shape[0]) + + x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) + y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) + + left, top = (x // multiplier, y // multiplier) + right, bottom = (left + source.shape[3], top + source.shape[2],) + + if mask is None: + mask = torch.ones_like(source) + else: + mask = mask.to(destination.device, copy=True) + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = extend_to_batch_size(mask, source.shape[0]) + + # calculate the bounds of the source that will be overlapping the destination + # this prevents the source trying to overwrite latent pixels that are out of bounds + # of the destination + visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + + mask = mask[:, :, :visible_height, :visible_width] + inverse_mask = torch.ones_like(mask) - mask + + source_portion = mask * source[:, :, :visible_height, :visible_width] + destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + + destination[:, :, top:bottom, left:right] = source_portion + destination_portion + return destination + + def get_sorted_list_via_attr(objects: list, attr: str) -> list: if not objects: return objects @@ -174,6 +215,29 @@ class MotionCompatibilityError(ValueError): pass +class InputPIA(ABC): + def __init__(self, effect_multival: Union[float, Tensor]=None): + self.effect_multival = effect_multival if effect_multival is not None else 1.0 + + @abstractmethod + def get_mask(self, x: Tensor): + pass + + +class
InputPIA_Multival(InputPIA): + def __init__(self, multival: Union[float, Tensor], effect_multival: Union[float, Tensor]=None): + super().__init__(effect_multival=effect_multival) + self.multival = multival + + def get_mask(self, x: Tensor): + if type(self.multival) is Tensor: + return self.multival + # if not Tensor, then is float, and simply return a mask with the right dimensions + value + b, c, h, w = x.shape + mask = torch.ones(size=(b, h, w)) + return mask * self.multival + + def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]: # if one is None, use the other if multivalA == None: @@ -200,12 +264,29 @@ def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[floa return multivalA * multivalB +def get_combined_input(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None], x: Tensor): + if inputA is None: + inputA = InputPIA_Multival(1.0) + if inputB is None: + inputB = InputPIA_Multival(1.0) + return get_combined_multival(inputA.get_mask(x), inputB.get_mask(x)) + + +def get_combined_input_effect_multival(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None]): + if inputA is None: + inputA = InputPIA_Multival(1.0) + if inputB is None: + inputB = InputPIA_Multival(1.0) + return get_combined_multival(inputA.effect_multival, inputB.effect_multival) + + class ADKeyframe: def __init__(self, start_percent: float = 0.0, scale_multival: Union[float, Tensor]=None, effect_multival: Union[float, Tensor]=None, cameractrl_multival: Union[float, Tensor]=None, + pia_input: InputPIA=None, inherit_missing: bool=True, guarantee_steps: int=1, default: bool=False, @@ -215,6 +296,7 @@ def __init__(self, self.scale_multival = scale_multival self.effect_multival = effect_multival self.cameractrl_multival = cameractrl_multival + self.pia_input = pia_input self.inherit_missing = inherit_missing self.guarantee_steps = guarantee_steps self.default = default @@ -227,6 +309,9 @@ def has_effect(self): def has_cameractrl_effect(self): return self.cameractrl_multival is not None + + def has_pia_input(self): + return self.pia_input is not None class ADKeyframeGroup: diff --git a/pyproject.toml b/pyproject.toml index 0cfd061..deaca26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.0.4" +version = "1.0.5" license = "LICENSE" dependencies = []
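For reference, the PIA input plumbing above composes masks multiplicatively: `InputPIA_Multival.get_mask` broadcasts a float into a (b, h, w) mask shaped like the latents, and `get_combined_input` merges two inputs via `get_combined_multival`. A small sketch of the expected behavior, assuming same-shape masks are simply multiplied elementwise (per the `return multivalA * multivalB` path above):

```python
import torch
from animatediff.utils_motion import InputPIA_Multival, get_combined_input

x = torch.zeros(2, 4, 8, 8)  # latent-shaped (b, c, h, w) tensor
input_a = InputPIA_Multival(0.5)                         # float multival
input_b = InputPIA_Multival(torch.full((2, 8, 8), 0.5))  # mask multival
combined = get_combined_input(input_a, input_b, x)
# both masks resolve to (2, 8, 8); elementwise product gives 0.25 everywhere
assert torch.allclose(combined, torch.full((2, 8, 8), 0.25))
```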