Merge PR #405 from Kosinkadink/develop - PIA, Image Injection, Multival Expansion

PIA, Image Injection, and Multival Expansion
Kosinkadink authored Jun 17, 2024
2 parents 379044e + 528ae52 commit c98e8e6
Showing 13 changed files with 1,213 additions and 90 deletions.
13 changes: 9 additions & 4 deletions README.md
@@ -65,21 +65,26 @@ NOTE: you can also use custom locations for models/motion loras by making use of
- NOTE: Requires same settings as described for AnimateLCM above. Requires ```Apply AnimateLCM-I2V Model``` Gen2 node usage so that ```ref_latent``` can be provided; use ```Scale Ref Image and VAE Encode``` node to preprocess input images. While this was intended as an img2video model, I found it works best for vid2vid purposes with ```ref_drift=0.0```, and to use it for at least 1 step before switching over to other models via chaining with other Apply AnimateDiff Model (Adv.) nodes. The ```apply_ref_when_disabled``` can be set to True to allow the img_encoder to do its thing even when the ```end_percent``` is reached. AnimateLCM-I2V is also extremely useful for maintaining coherence at higher resolutions (with ControlNet and SD LoRAs active, I could easily upscale from 512x512 source to 1024x1024 in a single pass). TODO: add examples
- [CameraCtrl](https://github.com/hehao13/CameraCtrl) support, with the pruned model you must use here: [CameraCtrl_pruned.safetensors](https://huggingface.co/Kosinkadink/CameraCtrl/tree/main)
- NOTE: Requires AnimateDiff SD1.5 models, and was specifically trained for the v3 model. Gen2 only, with helper nodes provided under the Gen2/CameraCtrl submenu.
- [PIA](https://github.com/open-mmlab/PIA) support, with the model [pia.ckpt](https://huggingface.co/Leoxing/PIA/tree/main)
- NOTE: You will need to use ```autoselect``` or ```sqrt_linear (AnimateDiff)``` beta_schedule. Requires ```Apply AnimateDiff-PIA Model``` Gen2 node usage if you want to actually provide input images. The ```pia_input``` can be provided via the paper's presets (```PIA Input [Paper Presets]```) or by manually entering values (```PIA Input [Multival]```); a minimal sketch of the multival idea appears after this list.
- AnimateDiff Keyframes to change Scale and Effect at different points in the sampling process.
- fp8 support; requires newest ComfyUI and torch >= 2.1 (decreases VRAM usage, but changes outputs)
- Mac M1/M2/M3 support
- Usage of Context Options and Sample Settings outside of AnimateDiff via Gen2 Use Evolved Sampling node
- Maskable and Schedulable SD LoRA (and Models as LoRA) for both AnimateDiff and StableDiffusion usage via LoRA Hooks
- Per-frame GLIGEN coordinates control
- Currently requires GLIGENTextBoxApplyBatch from KJNodes to do so, but I will add native nodes to do this soon.
- Image Injection mid-sampling
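
A quick illustration of the "multival" concept referenced in the PIA note above: a multival is either a single float or a per-frame (optionally per-pixel) tensor, which is what lets strengths and masks be scheduled across an animation. This is a hypothetical sketch of the idea only, not the node API:

```python
import torch

frames = 16
# scalar multival: one strength applied to every frame
scalar_multival = 0.8
# per-frame multival: ramp strength from 1.0 down to 0.0 across the animation,
# shaped [frames, 1, 1] so it can broadcast over latent height/width
per_frame_multival = torch.linspace(1.0, 0.0, frames).view(frames, 1, 1)
```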

## Upcoming Features
- Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: mid-May)
- Maskable Motion LoRA (Goal: end of May/beginning of June)
- Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: before Elden Ring DLC releases. Working on it right now.)
- [UniCtrl](https://github.com/XuweiyiChen/UniCtrl) support
- Unet-Ref support so that a bunch of papers can be ported over
- [StoryDiffusion](https://github.com/HVision-NKU/StoryDiffusion) implementation
- Merging motion model weights/components, including per block customization
- Maskable Motion LoRA
- Timestep schedulable GLIGEN coordinates
- Dynamic memory management for motion models that load/unload at different start/end_percents
- [PIA](https://github.com/open-mmlab/PIA) support
- [UniCtrl](https://github.com/XuweiyiChen/UniCtrl) support
- Built-in prompt travel implementation
- Anything else AnimateDiff-related that comes out

129 changes: 119 additions & 10 deletions animatediff/model_injection.py
@@ -8,23 +8,26 @@
import uuid
import math

import comfy.conds
import comfy.lora
import comfy.model_management
import comfy.utils
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.sd import CLIP
from comfy.sd import CLIP, VAE

from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight
from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding
from .context import ContextOptions, ContextOptionsGroup
from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention,
has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len)
from .logger import logger
from .utils_motion import ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, get_combined_multival, ade_broadcast_image_to, normalize_min_max
from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA,
get_combined_multival, get_combined_input, get_combined_input_effect_multival,
ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch)
from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode
from .motion_lora import MotionLoraInfo, MotionLoraList
from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type
from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched
from .sample_settings import SampleSettings, SeedNoiseGeneration


@@ -138,7 +141,6 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
'''
Based on add_patches, but for hooked weights.
'''
# TODO: make this work with timestep scheduling
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
p = set()
model_sd = self.model.state_dict()
@@ -164,7 +166,6 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
'''
Based on add_hooked_patches, but intended for using a model's weights as lora hook.
'''
# TODO: make this work with timestep scheduling
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
p = set()
model_sd = self.model.state_dict()
@@ -180,6 +181,7 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
p.add(k)
current_patches: list[tuple] = current_hooked_patches.get(key, [])
# take difference between desired weight and existing weight to get diff
# TODO: create fix for fp8
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
current_hooked_patches[key] = current_patches
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
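
The diff construction above means applying the hooked patch at strength 1.0 reproduces the donor model's weight exactly, while intermediate strengths interpolate between the two models. A minimal numeric sketch, assuming the usual additive patch application `W + strength * diff` (illustration only, not the repo's patch code):

```python
import torch

base_w = torch.tensor([1.0, 2.0])    # weight already in the model
donor_w = torch.tensor([3.0, 0.0])   # weight from the model used "as a LoRA"

diff = donor_w - base_w              # what gets stored in current_patches
for strength in (0.0, 0.5, 1.0):
    # additive application: 0.0 keeps the base, 1.0 yields the donor
    print(strength, base_w + strength * diff)
```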
@@ -238,7 +240,7 @@ def patch_model_lowvram(self, *args, **kwargs):
self.model_params_lowvram_keys[f"{n}.weight"] = n
if getattr(m, "bias_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n
self.model_params_lowvram_keys[f"{n}.bias"] = n

def unpatch_model(self, device_to=None, unpatch_weights=True):
# first, eject motion model from unet
@@ -721,7 +723,16 @@ def __init__(self, *args, **kwargs):
self.orig_camera_entries: list[CameraEntry] = None
self.camera_features: list[Tensor] = None # temporary
self.camera_features_shape: tuple = None
self.cameractrl_multival = None
self.cameractrl_multival: Union[float, Tensor] = None

# PIA
self.orig_pia_images: Tensor = None
self.pia_vae: VAE = None
self.pia_input: InputPIA = None
self.cached_pia_c_concat: comfy.conds.CONDNoiseShape = None # cached
self.prev_pia_latents_shape: tuple = None
self.prev_current_pia_input: InputPIA = None
self.pia_multival: Union[float, Tensor] = None

# temporary variables
self.current_used_steps = 0
@@ -730,9 +741,12 @@ def __init__(self, *args, **kwargs):
self.current_scale: Union[float, Tensor] = None
self.current_effect: Union[float, Tensor] = None
self.current_cameractrl_effect: Union[float, Tensor] = None
self.current_pia_input: InputPIA = None
self.combined_scale: Union[float, Tensor] = None
self.combined_effect: Union[float, Tensor] = None
self.combined_cameractrl_effect: Union[float, Tensor] = None
self.combined_pia_mask: Union[float, Tensor] = None
self.combined_pia_effect: Union[float, Tensor] = None
self.was_within_range = False
self.prev_sub_idxs = None
self.prev_batched_number = None
@@ -774,7 +788,7 @@ def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

def prepare_current_keyframe(self, t: Tensor):
def prepare_current_keyframe(self, x: Tensor, t: Tensor):
curr_t: float = t[0]
prev_index = self.current_index
# if met guaranteed steps, look for next keyframe in case need to switch
@@ -802,6 +816,10 @@ def prepare_current_keyframe(self, t: Tensor):
self.current_cameractrl_effect = self.current_keyframe.cameractrl_multival
elif not self.current_keyframe.inherit_missing:
self.current_cameractrl_effect = None
if self.current_keyframe.has_pia_input():
self.current_pia_input = self.current_keyframe.pia_input
elif not self.current_keyframe.inherit_missing:
self.current_pia_input = None
# if guarantee_steps greater than zero, stop searching for other keyframes
if self.current_keyframe.guarantee_steps > 0:
break
@@ -814,6 +832,8 @@ def prepare_current_keyframe(self, t: Tensor):
self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale)
self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect)
self.combined_cameractrl_effect = get_combined_multival(self.cameractrl_multival, self.current_cameractrl_effect)
self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x)
self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input)
# apply scale and effect
self.model.set_scale(self.combined_scale)
self.model.set_effect(self.combined_effect)
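
For context, the `get_combined_multival` calls above merge a base multival with the current keyframe's override. A minimal sketch, assuming the two values combine multiplicatively and `None` acts as identity; the real helper lives in `utils_motion.py` and may handle tensor broadcasting differently:

```python
import torch
from torch import Tensor
from typing import Optional, Union

Multival = Union[float, Tensor]

def combined_multival_sketch(orig: Optional[Multival],
                             new: Optional[Multival]) -> Optional[Multival]:
    # None means "not set": fall back to whichever side is present
    if orig is None:
        return new
    if new is None:
        return orig
    # floats multiply directly; a tensor times a float (or a broadcastable
    # tensor) scales per-frame/per-pixel strength
    return orig * new
```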
@@ -889,6 +909,72 @@ def prepare_camera_features(self, x: Tensor, cond_or_uncond: list[int], ad_param
self.prev_sub_idxs = sub_idxs
self.prev_batched_number = batched_number

def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor:
# if have cached shape, check if matches - if so, return cached pia_latents
if self.prev_pia_latents_shape is not None:
if self.prev_pia_latents_shape[0] == x.shape[0] and self.prev_pia_latents_shape[2] == x.shape[2] and self.prev_pia_latents_shape[3] == x.shape[3]:
# if mask is also the same for this timestep, then return cached
if self.prev_current_pia_input == self.current_pia_input:
return self.cached_pia_c_concat
# otherwise, adjust new mask, and create new cached_pia_c_concat
b, c, h, w = x.shape
mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
mask = extend_to_batch_size(mask, b)
# make sure to update prev_current_pia_input to know when is changed
self.prev_current_pia_input = self.current_pia_input
# TODO: handle self.combined_pia_effect eventually (feature hidden for now)
# the first index in dim=1 is the mask that needs to be updated - update in place
self.cached_pia_c_concat.cond[:, :1, :, :] = mask
return self.cached_pia_c_concat
self.prev_pia_latents_shape = None
# otherwise, x shape should be the cached pia_latents_shape
# get currently used models so they can be properly reloaded after performing VAE Encoding
if hasattr(comfy.model_management, "loaded_models"):
cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
else:
cached_loaded_models: list[ModelPatcherAndInjector] = [x.model for x in comfy.model_management.current_loaded_models]
try:
b, c, h, w = x.shape
usable_ref = self.orig_pia_images[:b]
# in diffusers, the image is scaled to [-1, 1] instead of the default [0, 1],
# but from my testing, that blows out the images here, so I skip it
# usable_images = usable_images * 2 - 1
# resize images to latent's dims
usable_ref = usable_ref.movedim(-1,1)
usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.pia_vae.downscale_ratio, height=h*self.pia_vae.downscale_ratio,
upscale_method="bilinear", crop="center")
usable_ref = usable_ref.movedim(1,-1)
# VAE encode images
logger.info("VAE Encoding PIA input images...")
usable_ref = model.process_latent_in(vae_encode_raw_batched(vae=self.pia_vae, pixels=usable_ref, show_pbar=False))
logger.info("VAE Encoding PIA input images complete.")
# make pia_latents match expected length
usable_ref = extend_to_batch_size(usable_ref, b)
self.prev_pia_latents_shape = x.shape
# now, take care of the mask
mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
mask = extend_to_batch_size(mask, b)
#mask = mask.unsqueeze(1)
self.prev_current_pia_input = self.current_pia_input
if type(self.combined_pia_effect) == Tensor or not math.isclose(self.combined_pia_effect, 1.0):
real_pia_effect = self.combined_pia_effect
if type(self.combined_pia_effect) == Tensor:
real_pia_effect = extend_to_batch_size(prepare_mask_batch(self.combined_pia_effect, x.shape), b)
zero_mask = torch.zeros_like(mask)
mask = mask * real_pia_effect + zero_mask * (1.0 - real_pia_effect)
del zero_mask
zero_usable_ref = torch.zeros_like(usable_ref)
usable_ref = usable_ref * real_pia_effect + zero_usable_ref * (1.0 - real_pia_effect)
del zero_usable_ref
# cache pia c_concat
self.cached_pia_c_concat = comfy.conds.CONDNoiseShape(torch.cat([mask, usable_ref], dim=1))
return self.cached_pia_c_concat
finally:
comfy.model_management.load_models_gpu(cached_loaded_models)

def is_pia(self):
return self.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None
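
The cached condition built above concatenates the mask (one channel) in front of the VAE-encoded reference latents (four channels for SD1.5), which is why the cache-hit path can refresh just `cond[:, :1, :, :]`. A shape-only sketch under those assumptions:

```python
import torch

b, h, w = 16, 64, 64            # frame count and latent dims (assumed values)
mask = torch.ones(b, 1, h, w)   # per-frame PIA mask
ref = torch.zeros(b, 4, h, w)   # stand-in for VAE-encoded reference latents

c_concat = torch.cat([mask, ref], dim=1)
assert c_concat.shape == (b, 5, h, w)
# on later keyframes only the mask channel needs to change:
c_concat[:, :1, :, :] = mask * 0.5
```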

def cleanup(self):
if self.model is not None:
self.model.cleanup()
@@ -900,6 +986,9 @@ def cleanup(self):
del self.camera_features
self.camera_features = None
self.camera_features_shape = None
# PIA
self.combined_pia_mask = None
self.combined_pia_effect = None
# Default
self.current_used_steps = 0
self.current_keyframe = None
@@ -943,6 +1032,11 @@ def clone(self):
# CameraCtrl
n.orig_camera_entries = self.orig_camera_entries
n.cameractrl_multival = self.cameractrl_multival
# PIA
n.orig_pia_images = self.orig_pia_images
n.pia_vae = self.pia_vae
n.pia_input = self.pia_input
n.pia_multival = self.pia_multival
return n


@@ -995,9 +1089,16 @@ def cleanup(self):
for motion_model in self.models:
motion_model.cleanup()

def prepare_current_keyframe(self, t: Tensor):
def prepare_current_keyframe(self, x: Tensor, t: Tensor):
for motion_model in self.models:
motion_model.prepare_current_keyframe(t=t)
motion_model.prepare_current_keyframe(x=x, t=t)

def get_pia_models(self):
pia_motion_models: list[MotionModelPatcher] = []
for motion_model in self.models:
if motion_model.is_pia():
pia_motion_models.append(motion_model)
return pia_motion_models

def get_name_string(self, show_version=False):
identifiers = []
@@ -1161,6 +1262,14 @@ def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: M
motion_model.model.img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict())


def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher):
motion_model.model.init_conv_in(w_pia.model.state_dict())
motion_model.model.conv_in.to(comfy.model_management.unet_dtype())
motion_model.model.conv_in.to(comfy.model_management.unet_offload_device())
motion_model.model.conv_in.load_state_dict(w_pia.model.conv_in.state_dict())
motion_model.model.mm_info.mm_format = AnimateDiffFormat.PIA
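
For context, PIA widens the UNet's first convolution so it can ingest the mask and reference latents alongside the noisy latent (4 + 1 + 4 = 9 input channels for SD1.5), which is presumably what `init_conv_in` constructs before the pretrained weights are copied in. A sketch with assumed SD1.5 hyperparameters:

```python
import torch.nn as nn

# SD1.5's conv_in maps 4 latent channels to 320 features with a 3x3 kernel;
# the PIA variant accepts 9 channels: 4 noisy latent + 1 mask + 4 ref latent
pia_conv_in = nn.Conv2d(in_channels=9, out_channels=320,
                        kernel_size=3, padding=1)
```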


def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ctrl_name: str):
camera_ctrl_path = get_motion_model_path(camera_ctrl_name)
full_state_dict = comfy.utils.load_torch_file(camera_ctrl_path, safe_load=True)
