Merge PR #460 from Kosinkadink/develop - New ComfyUI version changes
Work with new ComfyUI ModelPatcher update (NOT backwards compatible)
Kosinkadink authored Aug 20, 2024
2 parents d8af8fe + 790f3a9 commit c5c2780
Showing 3 changed files with 85 additions and 117 deletions.
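At the core of this PR: newer ComfyUI collapses regular and low-VRAM weight patching into a single ModelPatcher.load() entry point, so each patch_model_lowvram override below becomes a load override. A minimal sketch of that pattern (BasePatcher is an illustrative stand-in for ComfyUI's ModelPatcher, not the real class):

    # Illustrative stand-in for ComfyUI's ModelPatcher; only the shape of the
    # API matters here, not the real patching logic.
    class BasePatcher:
        def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
            # newer ComfyUI routes both full and low-VRAM patching through load()
            print(f"patching weights, lowvram budget = {lowvram_model_memory}")

    class InjectingPatcher(BasePatcher):
        def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
            to_return = super().load(device_to=device_to,
                                     lowvram_model_memory=lowvram_model_memory,
                                     *args, **kwargs)
            if lowvram_model_memory > 0:
                # extra bookkeeping is only needed when lowvram mode is active
                self._patch_lowvram_extras()
            return to_return

        def _patch_lowvram_extras(self):
            print("recording low-VRAM weight/bias keys")

    InjectingPatcher().load(lowvram_model_memory=1024)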
119 changes: 64 additions & 55 deletions animatediff/model_injection.py
@@ -220,53 +220,54 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup):
                 combined_patches[key] = current_patches
         return combined_patches
 
-    def model_patches_to(self, device):
-        super().model_patches_to(device)
 
-    def patch_model(self, device_to=None, patch_weights=True):
+    def patch_model(self, *args, **kwargs):
+        was_injected = False
+        if self.currently_injected:
+            self.eject_model()
+            was_injected = True
-        # first, perform model patching
-        if patch_weights: # TODO: keep only 'else' portion when don't need to worry about past comfy versions
-            patched_model = super().patch_model(device_to)
-        else:
-            patched_model = super().patch_model(device_to, patch_weights)
-        # finally, perform motion model injection
-        self.inject_model()
+        patched_model = super().patch_model(*args, **kwargs)
+        # bring injection back to original state
+        if was_injected and not self.currently_injected:
+            self.inject_model()
         return patched_model
 
-    def patch_model_lowvram(self, *args, **kwargs):
+    def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
+        self.eject_model()
         try:
-            return super().patch_model_lowvram(*args, **kwargs)
+            return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
         finally:
-            # check if any modules have weight_function or bias_function that is not None
-            # NOTE: this serves no purpose currently, but I have it here for future reasons
-            for n, m in self.model.named_modules():
-                if not hasattr(m, "comfy_cast_weights"):
-                    continue
-                if getattr(m, "weight_function", None) is not None:
-                    self.model_params_lowvram = True
-                    self.model_params_lowvram_keys[f"{n}.weight"] = n
-                if getattr(m, "bias_function", None) is not None:
-                    self.model_params_lowvram = True
-                    self.model_params_lowvram_keys[f"{n}.bias"] = n
+            self.inject_model()
+            if lowvram_model_memory > 0:
+                self._patch_lowvram_extras()
+
+    def _patch_lowvram_extras(self):
+        # check if any modules have weight_function or bias_function that is not None
+        # NOTE: this serves no purpose currently, but I have it here for future reasons
+        self.model_params_lowvram = False
+        self.model_params_lowvram_keys.clear()
+        for n, m in self.model.named_modules():
+            if not hasattr(m, "comfy_cast_weights"):
+                continue
+            if getattr(m, "weight_function", None) is not None:
+                self.model_params_lowvram = True
+                self.model_params_lowvram_keys[f"{n}.weight"] = n
+            if getattr(m, "bias_function", None) is not None:
+                self.model_params_lowvram = True
+                self.model_params_lowvram_keys[f"{n}.bias"] = n
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         # first, eject motion model from unet
         self.eject_model()
         # finally, do normal model unpatching
-        if unpatch_weights: # TODO: keep only 'else' portion when don't need to worry about past comfy versions
+        if unpatch_weights:
             # handle hooked_patches first
             self.clean_hooks()
-            try:
-                return super().unpatch_model(device_to)
-            finally:
-                self.model_params_lowvram = False
-                self.model_params_lowvram_keys.clear()
-        else:
-            try:
-                return super().unpatch_model(device_to, unpatch_weights)
-            finally:
-                self.model_params_lowvram = False
-                self.model_params_lowvram_keys.clear()
+        try:
+            return super().unpatch_model(device_to, unpatch_weights)
+        finally:
+            self.model_params_lowvram = False
+            self.model_params_lowvram_keys.clear()
 
     def partially_load(self, *args, **kwargs):
         # partially_load calls patch_model, but we don't want to inject model in the intermediate call;
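The rewritten patch_model above enforces an eject-patch-reinject invariant: motion modules must be out of the UNet while ComfyUI walks and patches weights, and are restored only if they were injected beforehand. The same invariant could be phrased as a context manager; a hypothetical sketch (ejected() is not a helper in this repo):

    from contextlib import contextmanager

    @contextmanager
    def ejected(patcher):
        """Temporarily eject injected modules, restoring the prior state on exit."""
        was_injected = patcher.currently_injected
        if was_injected:
            patcher.eject_model()
        try:
            yield patcher  # e.g. `with ejected(mp): mp.patch_weights()`
        finally:
            # re-inject only if injection was active before and nothing else re-injected
            if was_injected and not patcher.currently_injected:
                patcher.inject_model()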
@@ -625,31 +626,37 @@ def patch_hooked_replace_weight_to_device(self, model_sd: dict, replace_patches:
         else:
             comfy.utils.set_attr_param(self.model, key, out_weight)
 
-    def patch_model(self, device_to=None, patch_weights=True, *args, **kwargs):
+    def patch_model(self, device_to=None, *args, **kwargs):
         if self.desired_lora_hooks is not None:
             self.patches_backup = self.patches.copy()
             relevant_patches = self.get_combined_hooked_patches(lora_hooks=self.desired_lora_hooks)
             for key in relevant_patches:
                 self.patches.setdefault(key, [])
                 self.patches[key].extend(relevant_patches[key])
             self.current_lora_hooks = self.desired_lora_hooks
-        return super().patch_model(device_to, patch_weights, *args, **kwargs)
+        return super().patch_model(device_to, *args, **kwargs)
 
-    def patch_model_lowvram(self, *args, **kwargs):
+    def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
         try:
-            return super().patch_model_lowvram(*args, **kwargs)
+            return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
         finally:
-            # check if any modules have weight_function or bias_function that is not None
-            # NOTE: this serves no purpose currently, but I have it here for future reasons
-            for n, m in self.model.named_modules():
-                if not hasattr(m, "comfy_cast_weights"):
-                    continue
-                if getattr(m, "weight_function", None) is not None:
-                    self.model_params_lowvram = True
-                    self.model_params_lowvram_keys[f"{n}.weight"] = n
-                if getattr(m, "bias_function", None) is not None:
-                    self.model_params_lowvram = True
-                    self.model_params_lowvram_keys[f"{n}.weight"] = n
+            if lowvram_model_memory > 0:
+                self._patch_lowvram_extras()
+
+    def _patch_lowvram_extras(self):
+        # check if any modules have weight_function or bias_function that is not None
+        # NOTE: this serves no purpose currently, but I have it here for future reasons
+        self.model_params_lowvram = False
+        self.model_params_lowvram_keys.clear()
+        for n, m in self.model.named_modules():
+            if not hasattr(m, "comfy_cast_weights"):
+                continue
+            if getattr(m, "weight_function", None) is not None:
+                self.model_params_lowvram = True
+                self.model_params_lowvram_keys[f"{n}.weight"] = n
+            if getattr(m, "bias_function", None) is not None:
+                self.model_params_lowvram = True
+                self.model_params_lowvram_keys[f"{n}.weight"] = n
 
     def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
         try:
@@ -797,10 +804,14 @@ def __init__(self, *args, **kwargs):
         self.was_within_range = False
         self.prev_sub_idxs = None
         self.prev_batched_number = None
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, *args, **kwargs):
-        patched_model = super().patch_model_lowvram(device_to, lowvram_model_memory, force_patch_weights, *args, **kwargs)
-
+    def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
+        to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
+        if lowvram_model_memory > 0:
+            self._patch_lowvram_extras(device_to=device_to)
+        return to_return
+
+    def _patch_lowvram_extras(self, device_to=None):
         # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
         remaining_tensors = list(self.model.state_dict().keys())
         named_modules = []
@@ -817,8 +828,6 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc
         if device_to is not None:
             comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to))
 
-        return patched_model
-
     def pre_run(self, model: ModelPatcherAndInjector):
         self.cleanup()
         self.model.set_scale(self.scale_multival, self.per_block_list)
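The _patch_lowvram_extras split above exists because ComfyUI's low-VRAM path casts weights per named module, so bare tensors hanging off the root model (such as AnimateDiff's positional-encoding "pe" buffers) are skipped and must be found and moved by hand. A self-contained PyTorch sketch of that discovery step (plain torch, illustrative module, no comfy imports):

    import torch
    import torch.nn as nn

    class ModelWithExtras(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            # a bare buffer on the root module: per-module casting passes over it
            self.register_buffer("pe", torch.zeros(1, 8, 4))

    model = ModelWithExtras()
    remaining = list(model.state_dict().keys())
    for name, module in model.named_modules():
        if name == "":
            continue  # skip the root module itself
        for key in module.state_dict().keys():
            full_key = f"{name}.{key}"
            if full_key in remaining:
                remaining.remove(full_key)

    print(remaining)  # ['pe'] -> tensors that still need a manual .to(device)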
81 changes: 20 additions & 61 deletions animatediff/sampling.py
@@ -12,12 +12,7 @@
 from comfy.ldm.modules.diffusionmodules import openaimodel
 import comfy.model_management
 import comfy.samplers
-import comfy.sample
-SAMPLE_FALLBACK = False
-try:
-    import comfy.sampler_helpers
-except ImportError:
-    SAMPLE_FALLBACK = True
+import comfy.sampler_helpers
 import comfy.utils
 from comfy.controlnet import ControlBase
 from comfy.model_base import BaseModel
@@ -291,10 +286,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
         self.orig_diffusion_model_forward = model.model.diffusion_model.forward
         self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers
         self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult
-        if SAMPLE_FALLBACK: # for backwards compatibility, for now
-            self.orig_get_additional_models = comfy.sample.get_additional_models
-        else:
-            self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models
+        self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models
         self.orig_apply_model = model.model.apply_model
         # Inject Functions
         openaimodel.forward_timestep_embed = forward_timestep_embed_factory()
@@ -324,10 +316,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
             del info
         comfy.samplers.sampling_function = evolved_sampling_function
         comfy.samplers.get_area_and_mult = get_area_and_mult_ADE
-        if SAMPLE_FALLBACK: # for backwards compatibility, for now
-            comfy.sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
-        else:
-            comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
+        comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
         # create temp_uninjector to help facilitate uninjecting functions
         self.temp_uninjector = GroupnormUninjectHelper(self)
 
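inject_functions/restore_functions follow a save-swap-restore monkeypatching discipline: stash the module-level original on self, replace it with a wrapper built by a factory that closes over the original, and put the original back afterwards. A generic sketch of that discipline (demo module and function only, not the repo's code):

    import math

    def traced_factory(orig_fn):
        # factory closing over the original, mirroring get_additional_models_factory
        def traced(x):
            print(f"called with {x}")
            return orig_fn(x)
        return traced

    orig_sqrt = math.sqrt                   # save the original
    math.sqrt = traced_factory(orig_sqrt)   # inject the wrapper
    try:
        assert math.sqrt(9.0) == 3.0
    finally:
        math.sqrt = orig_sqrt               # always restore, even if sampling errors out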
@@ -341,10 +330,7 @@ def restore_functions(self, model: ModelPatcherAndInjector):
         model.model.diffusion_model.forward = self.orig_diffusion_model_forward
         comfy.samplers.sampling_function = self.orig_sampling_function
         comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult
-        if SAMPLE_FALLBACK: # for backwards compatibility, for now
-            comfy.sample.get_additional_models = self.orig_get_additional_models
-        else:
-            comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models
+        comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models
         model.model.apply_model = self.orig_apply_model
     except AttributeError:
         logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \
@@ -505,17 +491,8 @@ def ad_callback(step, x0, x, total_steps):
             if is_custom:
                 iter_kwargs[IterationOptions.SAMPLER] = None #args[-5]
             else:
-                if SAMPLE_FALLBACK: # backwards compatibility, for now
-                    # in older comfy, model needs to be loaded to get proper model_sampling to be used for sigmas
-                    comfy.model_management.load_model_gpu(model)
-                    iter_model = model.model
-                else:
-                    iter_model = model
-                current_device = None
-                if hasattr(model, "current_device"): # backwards compatibility, for now
-                    current_device = model.current_device
-                else:
-                    current_device = model.model.device
+                iter_model = model
+                current_device = model.model.device
                 iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler(
                     iter_model, steps=999, #steps=args[-7],
                     device=current_device, sampler=args[-5],
@@ -653,35 +630,20 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond,
         model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params()
 
         if not ADGS.is_using_sliding_context():
-            cond_pred, uncond_pred = calc_cond_uncond_batch_wrapper(model, [cond, uncond_], x, timestep, model_options)
+            cond_pred, uncond_pred = calc_conds_batch_wrapper(model, [cond, uncond_], x, timestep, model_options)
         else:
             cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options)
 
-        if hasattr(comfy.samplers, "cfg_function"):
-            if ADGS.sample_settings.custom_cfg is not None:
-                cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred)
-                model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options)
-            try:
-                cached_calc_cond_batch = comfy.samplers.calc_cond_batch
-                # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch
-                comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch)
-                return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond)
-            finally:
-                comfy.samplers.calc_cond_batch = cached_calc_cond_batch
-        else: # for backwards compatibility, for now
-            if "sampler_cfg_function" in model_options:
-                args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
-                        "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
-                cfg_result = x - model_options["sampler_cfg_function"](args)
-            else:
-                cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
-
-            for fn in model_options.get("sampler_post_cfg_function", []):
-                args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
-                        "sigma": timestep, "model_options": model_options, "input": x}
-                cfg_result = fn(args)
-
-            return cfg_result
+        if ADGS.sample_settings.custom_cfg is not None:
+            cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred)
+            model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options)
+        try:
+            cached_calc_cond_batch = comfy.samplers.calc_cond_batch
+            # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch
+            comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch)
+            return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond)
+        finally:
+            comfy.samplers.calc_cond_batch = cached_calc_cond_batch
     finally:
         ADGS.restore_special_model_features(model)
 
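With the hasattr branch deleted, every call now flows through comfy.samplers.cfg_function; the removed fallback shows the classic classifier-free-guidance combine it ultimately performs. A minimal sketch of that combine, kept here for reference (a simplified stand-in that ignores sampler_cfg_function and post-cfg hooks):

    import torch

    def cfg_combine(cond_pred, uncond_pred, cond_scale):
        # classic CFG: push the prediction from unconditional toward conditional
        return uncond_pred + (cond_pred - uncond_pred) * cond_scale

    cond = torch.randn(1, 4, 8, 8)
    uncond = torch.randn(1, 4, 8, 8)
    guided = cfg_combine(cond, uncond, cond_scale=7.5)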
@@ -745,7 +707,7 @@ def wrapped_cfg_sliding_calc_cond_batch(model, conds, x_in, timestep, model_opti
         # when inside sliding_calc_conds_batch, should return to original calc_cond_batch
         comfy.samplers.calc_cond_batch = orig_calc_cond_batch
         if not ADGS.is_using_sliding_context():
-            return calc_cond_uncond_batch_wrapper(model, conds, x_in, timestep, model_options)
+            return calc_conds_batch_wrapper(model, conds, x_in, timestep, model_options)
         else:
             return sliding_calc_conds_batch(model, conds, x_in, timestep, model_options)
     finally:
@@ -922,7 +884,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list
                 model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF
             #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}")
 
-            sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options)
+            sub_conds_out = calc_conds_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options)
 
             if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
                 full_length = ADGS.params.full_length
@@ -1008,7 +970,7 @@ def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseSh
     return new_conds
 
 
-def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options):
+def calc_conds_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options):
     # check if conds or unconds contain lora_hook or default_cond
     contains_lora_hooks = False
     has_default_cond = False
@@ -1028,9 +990,6 @@ def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, times
         ADGS.hooks_initialize(model, hook_groups=hook_groups)
         ADGS.prepare_hooks_current_keyframes(timestep, hook_groups=hook_groups)
         return calc_conds_batch_lora_hook(model, conds, x_in, timestep, model_options, has_default_cond)
-    # keep for backwards compatibility, for now
-    if not hasattr(comfy.samplers, "calc_cond_batch"):
-        return comfy.samplers.calc_cond_uncond_batch(model, conds[0], conds[1], x_in, timestep, model_options)
     return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options)
 
 
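The rename from calc_cond_uncond_batch_wrapper tracks ComfyUI's batching API: the legacy calc_cond_uncond_batch took a fixed cond/uncond pair and returned two tensors, while calc_cond_batch accepts a list of cond groups and returns one output per group, which is why conds is now passed straight through. A hedged sketch of the list-style call shape (stub function, not ComfyUI's implementation):

    # Stub with the same call shape as comfy.samplers.calc_cond_batch: list in, list out.
    def calc_cond_batch(model, conds, x_in, timestep, model_options):
        return [f"pred({c})" for c in conds]

    # callers unpack one output per cond group, e.g. a [cond, uncond] pair:
    cond_pred, uncond_pred = calc_cond_batch(None, ["cond", "uncond"], None, None, {})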
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-animatediff-evolved"
 description = "Improved AnimateDiff integration for ComfyUI."
-version = "1.1.4"
+version = "1.2.0"
 license = { file = "LICENSE" }
 dependencies = []
 
