Skip to content

Commit

Permalink
Merge PR #422 from Kosinkadink/develop - Custom CFG Improvements + GP…
Browse files Browse the repository at this point in the history
…U Noise

Custom CFG Improvements + GPU Noise
  • Loading branch information
Kosinkadink authored Jul 10, 2024
2 parents 7c75983 + b74d56c commit f3b24c1
Show file tree
Hide file tree
Showing 11 changed files with 525 additions and 56 deletions.
91 changes: 91 additions & 0 deletions animatediff/cfg_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Union

import inspect
import torch
from torch import Tensor

import comfy.model_patcher
import comfy.samplers

from .utils_motion import extend_to_batch_size, prepare_mask_batch


################################################################################
# helpers for modifying model_options to apply cfg function patches;
# taken from comfy/model_patcher.py
def set_model_options_sampler_cfg_function(model_options: dict[str], sampler_cfg_function, disable_cfg1_optimization=False):
    """Store a sampler CFG function in ``model_options`` and return that same dict.

    Args:
        model_options: Options dict that is modified in place.
        sampler_cfg_function: Callable to register. When its signature has
            exactly three parameters it is treated as the old calling
            convention and wrapped so that it receives ``args["cond"]``,
            ``args["uncond"]`` and ``args["cond_scale"]`` positionally;
            otherwise it is stored unchanged.
        disable_cfg1_optimization: When truthy, also sets
            ``model_options["disable_cfg1_optimization"]`` to ``True``.

    Returns:
        The ``model_options`` object that was passed in.
    """
    param_count = len(inspect.signature(sampler_cfg_function).parameters)

    if param_count == 3:
        # Old way: adapt (cond, uncond, cond_scale) callables to the args-dict form.
        def registered(args):
            return sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])
    else:
        registered = sampler_cfg_function

    model_options["sampler_cfg_function"] = registered
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
#-------------------------------------------------------------------------------


# this is a modified version of PerturbedAttentionGuidance from comfy_extras/nodes_pag.py
def perturbed_attention_guidance_patch(scale_multival: Union[float, Tensor]):
    """Build a post-CFG function that applies perturbed attention guidance (PAG).

    Args:
        scale_multival: Guidance scale, either a plain number or a Tensor.
            A plain number equal to 0 disables the patch entirely.

    Returns:
        A callable taking the post-CFG ``args`` dict (keys read: ``"model"``,
        ``"cond_denoised"``, ``"cond"``, ``"denoised"``, ``"sigma"``,
        ``"model_options"``, ``"input"``) and returning
        ``denoised + (cond_denoised - pag) * scale``.
    """
    unet_block = "middle"
    unet_block_id = 0

    def perturbed_attention(q, k, v, extra_options, mask=None):
        # Attention replacement: the values are returned untouched.
        return v

    def post_cfg_function(args):
        cfg_result = args["denoised"]

        # isinstance (not an exact type comparison) so that Tensor subclasses
        # are never sent through the numeric `== 0` check, whose elementwise
        # result has no single truth value.
        if not isinstance(scale_multival, Tensor) and scale_multival == 0:
            return cfg_result

        model = args["model"]
        cond_pred: Tensor = args["cond_denoised"]
        cond = args["cond"]
        sigma = args["sigma"]
        x = args["input"]
        # Copied only once we know the patch is active; the caller's dict is
        # not the object handed to the patch helper below.
        model_options = args["model_options"].copy()

        scale = scale_multival
        if isinstance(scale, Tensor):
            # NOTE(review): presumably reshapes/resizes the scale so it broadcasts
            # against cond_pred and matches its batch size - confirm in utils_motion.
            scale = prepare_mask_batch(scale.to(cond_pred.dtype).to(cond_pred.device), cond_pred.shape)
            scale = extend_to_batch_size(scale, cond_pred.shape[0])

        # Replace Self-attention with PAG
        model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id)
        (pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)

        return cfg_result + (cond_pred - pag) * scale

    return post_cfg_function


# this is a modified version of RescaleCFG from comfy_extras/nodes_model_advanced.py
def rescale_cfg_patch(multiplier_multival: Union[float, Tensor]):
    """Build a sampler CFG function that applies RescaleCFG.

    Args:
        multiplier_multival: Blend factor between the rescaled CFG result
            (factor 1) and the plain CFG result (factor 0); either a plain
            number or a Tensor.

    Returns:
        A callable taking the CFG ``args`` dict (keys read: ``"cond"``,
        ``"uncond"``, ``"cond_scale"``, ``"sigma"``, ``"input"``) and
        returning the blended prediction, converted back from the
        intermediate v-pred form used for the rescale.
    """
    def cfg_function(args):
        cond: Tensor = args["cond"]
        uncond = args["uncond"]
        cond_scale = args["cond_scale"]
        x_orig = args["input"]
        sigma = args["sigma"]
        # Give sigma one trailing singleton axis per non-batch axis of cond.
        sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
        # Statistics are taken per sample over every non-batch axis, matching
        # the rank-agnostic sigma reshape above; for 4-D input this is (1, 2, 3).
        reduce_dims = tuple(range(1, cond.ndim))

        #rescale cfg has to be done on v-pred model output
        x = x_orig / (sigma * sigma + 1.0)
        cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
        uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)

        #rescalecfg
        x_cfg = uncond + cond_scale * (cond - uncond)
        ro_pos = torch.std(cond, dim=reduce_dims, keepdim=True)
        ro_cfg = torch.std(x_cfg, dim=reduce_dims, keepdim=True)

        multiplier = multiplier_multival
        if isinstance(multiplier, Tensor):
            # NOTE(review): presumably reshapes/resizes the multiplier so it broadcasts
            # against cond and matches its batch size - confirm in utils_motion.
            multiplier = prepare_mask_batch(multiplier.to(cond.dtype).to(cond.device), cond.shape)
            multiplier = extend_to_batch_size(multiplier, cond.shape[0])

        # Scale the CFG result to the std of the positive prediction, then blend.
        x_rescaled = x_cfg * (ro_pos / ro_cfg)
        x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg

        return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)

    return cfg_function
31 changes: 31 additions & 0 deletions animatediff/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Callable, Optional, Union

import torchvision
import PIL

import numpy as np
from torch import Tensor

Expand Down Expand Up @@ -473,3 +476,31 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta


##########################
# Context Visualization
##########################
class Colors:
    """Named three-component integer color tuples (0 or 255 per component)."""

    # No components / all components set
    BLACK = (0, 0, 0)
    WHITE = (255, 255, 255)

    # Exactly one component set
    RED = (255, 0, 0)
    GREEN = (0, 255, 0)
    BLUE = (0, 0, 255)

    # Exactly two components set
    YELLOW = (255, 255, 0)
    MAGENTA = (255, 0, 255)
    CYAN = (0, 255, 255)


class VisualizeSettings:
    """Holds the dimensions and helpers used when drawing a context visualization.

    Attributes set by ``__init__``:
        img_width, img_height, video_length: stored as passed in.
        grid: ``img_width // video_length`` (presumably the pixel width
            allotted to one frame - TODO confirm once the drawing code exists).
        pil_to_tensor: a torchvision ``Compose`` wrapping a single ``PILToTensor``.
    """

    def __init__(self, img_width, img_height, video_length):
        self.img_width, self.img_height = img_width, img_height
        self.video_length = video_length
        # Integer division: any remainder of the width is left unused.
        self.grid = img_width // video_length
        to_tensor = torchvision.transforms.PILToTensor()
        self.pil_to_tensor = torchvision.transforms.Compose([to_tensor])


def generate_context_visualization(context_opts: ContextOptionsGroup, model: BaseModel, width=1440, height=200, video_length=32, start_step=0, end_step=20):
    """Placeholder for rendering a visualization of ``context_opts``.

    Currently this only constructs a ``VisualizeSettings`` from ``width``,
    ``height`` and ``video_length`` and returns ``None``; ``context_opts``,
    ``model``, ``start_step`` and ``end_step`` are not used yet.
    """
    vs = VisualizeSettings(img_width=width, img_height=height, video_length=video_length)
    return None
4 changes: 2 additions & 2 deletions animatediff/motion_module_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ def has_img_encoder(mm_state_dict: dict[str, Tensor]):

def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> Tuple[dict[str, Tensor], AnimateDiffInfo]:
# from pathlib import Path
# with open(Path(__file__).parent.parent.parent / f"keys_{mm_name}.txt", "w") as afile:
# log_name = mm_name.split('\\')[-1]
# with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile:
# for key, value in mm_state_dict.items():
# afile.write(f"{key}:\t{value.shape}\n")

# determine what SD model the motion module is intended for
sd_type: str = None
down_block_max = get_down_block_max(mm_state_dict)
Expand Down
34 changes: 27 additions & 7 deletions animatediff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
ConditioningTimestepsNode, SetLoraHookKeyframes,
CreateLoraHookKeyframe, CreateLoraHookKeyframeInterpolation, CreateLoraHookKeyframeFromStrengthList)
from .nodes_sample import (FreeInitOptionsNode, NoiseLayerAddWeightedNode, SampleSettingsNode, NoiseLayerAddNode, NoiseLayerReplaceNode, IterationOptionsNode,
CustomCFGNode, CustomCFGKeyframeNode, NoisedImageInjectionNode, NoisedImageInjectOptionsNode)
CustomCFGNode, CustomCFGSimpleNode, CustomCFGKeyframeNode, CustomCFGKeyframeSimpleNode,
CFGExtrasPAGNode, CFGExtrasPAGSimpleNode, CFGExtrasRescaleCFGNode, CFGExtrasRescaleCFGSimpleNode,
NoisedImageInjectionNode, NoisedImageInjectOptionsNode)
from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode)
from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode,
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode)
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsInt)
from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode,
WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode,
WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode)
from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect
from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival
from .nodes_deprecated import (AnimateDiffLoader_Deprecated, AnimateDiffLoaderAdvanced_Deprecated, AnimateDiffCombine_Deprecated,
AnimateDiffModelSettings, AnimateDiffModelSettingsSimple, AnimateDiffModelSettingsAdvanced, AnimateDiffModelSettingsAdvancedAttnStrengths)
from .nodes_lora import AnimateDiffLoraLoader
Expand Down Expand Up @@ -56,6 +58,7 @@
"ADE_ViewsOnlyContextOptions": ViewAsContextOptionsNode,
"ADE_BatchedContextOptions": BatchedContextOptionsNode,
"ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy
#"ADE_VisualizeContextOptions": VisualizeContextOptionsInt,
# View Opts
"ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode,
"ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode,
Expand Down Expand Up @@ -100,7 +103,9 @@
"ADE_AdjustWeightIndivAttnAdd": WeightAdjustIndivAttnAddNode,
"ADE_AdjustWeightIndivAttnMult": WeightAdjustIndivAttnMultNode,
# Sample Settings
"ADE_CustomCFGSimple": CustomCFGSimpleNode,
"ADE_CustomCFG": CustomCFGNode,
"ADE_CustomCFGKeyframeSimple": CustomCFGKeyframeSimpleNode,
"ADE_CustomCFGKeyframe": CustomCFGKeyframeNode,
"ADE_SigmaSchedule": SigmaScheduleNode,
"ADE_RawSigmaSchedule": RawSigmaScheduleNode,
Expand All @@ -109,10 +114,16 @@
"ADE_SigmaScheduleSplitAndCombine": SplitAndCombineSigmaScheduleNode,
"ADE_NoisedImageInjection": NoisedImageInjectionNode,
"ADE_NoisedImageInjectOptions": NoisedImageInjectOptionsNode,
"ADE_CFGExtrasPAGSimple": CFGExtrasPAGSimpleNode,
"ADE_CFGExtrasPAG": CFGExtrasPAGNode,
"ADE_CFGExtrasRescaleCFGSimple": CFGExtrasRescaleCFGSimpleNode,
"ADE_CFGExtrasRescaleCFG": CFGExtrasRescaleCFGNode,
# Extras Nodes
"ADE_AnimateDiffUnload": AnimateDiffUnload,
"ADE_EmptyLatentImageLarge": EmptyLatentImageLarge,
"CheckpointLoaderSimpleWithNoiseSelect": CheckpointLoaderSimpleWithNoiseSelect,
"ADE_PerturbedAttentionGuidanceMultival": PerturbedAttentionGuidanceMultival,
"ADE_RescaleCFGMultival": RescaleCFGMultival,
# Gen1 Nodes
"ADE_AnimateDiffLoaderGen1": AnimateDiffLoaderGen1,
"ADE_AnimateDiffLoaderWithContext": LegacyAnimateDiffLoaderWithContext,
Expand Down Expand Up @@ -158,8 +169,8 @@
"ADE_AnimateDiffSamplingSettings": "Sample Settings 🎭🅐🅓",
"ADE_AnimateDiffKeyframe": "AnimateDiff Keyframe 🎭🅐🅓",
# Multival Nodes
"ADE_MultivalDynamic": "Multival Dynamic 🎭🅐🅓",
"ADE_MultivalDynamicFloatInput": "Multival Dynamic [Float List] 🎭🅐🅓",
"ADE_MultivalDynamic": "Multival 🎭🅐🅓",
"ADE_MultivalDynamicFloatInput": "Multival [Float List] 🎭🅐🅓",
"ADE_MultivalScaledMask": "Multival Scaled Mask 🎭🅐🅓",
"ADE_MultivalConvertToMask": "Multival to Mask 🎭🅐🅓",
# Context Opts
Expand All @@ -169,6 +180,7 @@
"ADE_ViewsOnlyContextOptions": "Context Options◆Views Only [VRAM⇈] 🎭🅐🅓",
"ADE_BatchedContextOptions": "Context Options◆Batched [Non-AD] 🎭🅐🅓",
"ADE_AnimateDiffUniformContextOptions": "Context Options◆Looped Uniform 🎭🅐🅓", # Legacy
"ADE_VisualizeContextOptions": "Visualize Context Options 🎭🅐🅓",
# View Opts
"ADE_StandardStaticViewOptions": "View Options◆Standard Static 🎭🅐🅓",
"ADE_StandardUniformViewOptions": "View Options◆Standard Uniform 🎭🅐🅓",
Expand Down Expand Up @@ -213,19 +225,27 @@
"ADE_AdjustWeightIndivAttnAdd": "Adjust Weight [Indiv-Attn◆Add] 🎭🅐🅓",
"ADE_AdjustWeightIndivAttnMult": "Adjust Weight [Indiv-Attn◆Mult] 🎭🅐🅓",
# Sample Settings
"ADE_CustomCFG": "Custom CFG 🎭🅐🅓",
"ADE_CustomCFGKeyframe": "Custom CFG Keyframe 🎭🅐🅓",
"ADE_CustomCFGSimple": "Custom CFG 🎭🅐🅓",
"ADE_CustomCFG": "Custom CFG [Multival] 🎭🅐🅓",
"ADE_CustomCFGKeyframeSimple": "Custom CFG Keyframe 🎭🅐🅓",
"ADE_CustomCFGKeyframe": "Custom CFG Keyframe [Multival] 🎭🅐🅓",
"ADE_SigmaSchedule": "Create Sigma Schedule 🎭🅐🅓",
"ADE_RawSigmaSchedule": "Create Raw Sigma Schedule 🎭🅐🅓",
"ADE_SigmaScheduleWeightedAverage": "Sigma Schedule Weighted Mean 🎭🅐🅓",
"ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interpolated Mean 🎭🅐🅓",
"ADE_SigmaScheduleSplitAndCombine": "Sigma Schedule Split Combine 🎭🅐🅓",
"ADE_NoisedImageInjection": "Image Injection 🎭🅐🅓",
"ADE_NoisedImageInjectOptions": "Image Injection Options 🎭🅐🅓",
"ADE_CFGExtrasPAGSimple": "CFG Extras◆PAG 🎭🅐🅓",
"ADE_CFGExtrasPAG": "CFG Extras◆PAG [Multival] 🎭🅐🅓",
"ADE_CFGExtrasRescaleCFGSimple": "CFG Extras◆RescaleCFG 🎭🅐🅓",
"ADE_CFGExtrasRescaleCFG": "CFG Extras◆RescaleCFG [Multival] 🎭🅐🅓",
# Extras Nodes
"ADE_AnimateDiffUnload": "AnimateDiff Unload 🎭🅐🅓",
"ADE_EmptyLatentImageLarge": "Empty Latent Image (Big Batch) 🎭🅐🅓",
"CheckpointLoaderSimpleWithNoiseSelect": "Load Checkpoint w/ Noise Select 🎭🅐🅓",
"ADE_PerturbedAttentionGuidanceMultival": "PerturbedAttnGuide [Multival] 🎭🅐🅓",
"ADE_RescaleCFGMultival": "RescaleCFG [Multival] 🎭🅐🅓",
# Gen1 Nodes
"ADE_AnimateDiffLoaderGen1": "AnimateDiff Loader 🎭🅐🅓①",
"ADE_AnimateDiffLoaderWithContext": "AnimateDiff Loader [Legacy] 🎭🅐🅓①",
Expand Down
30 changes: 30 additions & 0 deletions animatediff/nodes_context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import torch
from torch import Tensor

from comfy.model_patcher import ModelPatcher

from .context import ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules
from .utils_model import BIGMAX

Expand Down Expand Up @@ -346,3 +351,28 @@ def create_options(self, view_length: int, view_overlap: int, view_stride: int,
use_on_equal_length=use_on_equal_length,
)
return (view_options,)


class VisualizeContextOptionsInt:
    """Node that outputs an IMAGE batch for visualizing context options.

    The current implementation returns an all-zero image batch; ``model``,
    ``context_opts``, ``start_step`` and ``end_step`` are accepted but unused.
    """

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
    FUNCTION = "visualize"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        required = {
            "model": ("MODEL",),
            "context_opts": ("CONTEXT_OPTIONS",),
        }
        optional = {
            "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
            "start_step": ("INT", {"min": 0, "max": BIGMAX, "default": 0}),
            "end_step": ("INT", {"min": 1, "max": BIGMAX, "default": 20}),
        }
        return {"required": required, "optional": optional}

    def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup,
                  latents_length=32, start_step=0, end_step=20):
        """Return a one-tuple holding a zero tensor of shape (latents_length, 256, 256, 3)."""
        shape = (latents_length, 256, 256, 3)
        return (torch.zeros(shape),)
8 changes: 4 additions & 4 deletions animatediff/nodes_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def INPUT_TYPES(s):
},
"optional": {
"mask_motion_scale": ("MASK",),
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
}
}

Expand Down Expand Up @@ -321,7 +321,7 @@ def INPUT_TYPES(s):
"mask_motion_scale": ("MASK",),
"min_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
"max_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
}
}

Expand Down Expand Up @@ -360,7 +360,7 @@ def INPUT_TYPES(s):
"mask_motion_scale": ("MASK",),
"min_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
"max_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
}
}

Expand Down Expand Up @@ -415,7 +415,7 @@ def INPUT_TYPES(s):
"mask_motion_scale": ("MASK",),
"min_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
"max_motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
"optional": {"deprecation_warning": ("ADEWARN", {"text": "Deprecated"})},
"deprecation_warning": ("ADEWARN", {"text": "Deprecated"}),
}
}

Expand Down
50 changes: 50 additions & 0 deletions animatediff/nodes_extras.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from typing import Union

import torch
from torch import Tensor

import folder_paths
import nodes as comfy_nodes
from comfy.model_patcher import ModelPatcher
import comfy.model_patcher
import comfy.samplers
from comfy.sd import load_checkpoint_guess_config

from .logger import logger
from .utils_model import BetaSchedules
from .utils_motion import extend_to_batch_size, prepare_mask_batch
from .model_injection import get_vanilla_model_patcher
from .cfg_extras import perturbed_attention_guidance_patch, rescale_cfg_patch


class AnimateDiffUnload:
Expand Down Expand Up @@ -76,3 +83,46 @@ def INPUT_TYPES(s):
def generate(self, width, height, batch_size=1):
    """Create a zero-filled latent batch.

    Returns:
        A one-tuple containing ``{"samples": tensor}``, where ``tensor`` is
        all zeros with shape ``[batch_size, 4, height // 8, width // 8]``.
    """
    shape = [batch_size, 4, height // 8, width // 8]
    return ({"samples": torch.zeros(shape)},)


class PerturbedAttentionGuidanceMultival:
    """Node that attaches the multival PAG post-CFG function to a cloned model."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "Animate Diff 🎭🅐🅓/extras"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""
        required_inputs = {
            "model": ("MODEL",),
            "scale_multival": ("MULTIVAL",),
        }
        return {"required": required_inputs}

    def patch(self, model: ModelPatcher, scale_multival: Union[float, Tensor]):
        """Clone ``model``, register the PAG post-CFG function on the clone, and return it in a one-tuple."""
        patched = model.clone()
        post_cfg = perturbed_attention_guidance_patch(scale_multival)
        patched.set_model_sampler_post_cfg_function(post_cfg)
        return (patched,)


class RescaleCFGMultival:
    """Node that attaches the multival RescaleCFG sampler function to a cloned model."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "Animate Diff 🎭🅐🅓/extras"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""
        required_inputs = {
            "model": ("MODEL",),
            "mult_multival": ("MULTIVAL",),
        }
        return {"required": required_inputs}

    def patch(self, model: ModelPatcher, mult_multival: Union[float, Tensor]):
        """Clone ``model``, register the RescaleCFG function on the clone, and return it in a one-tuple."""
        patched = model.clone()
        cfg_fn = rescale_cfg_patch(mult_multival)
        patched.set_model_sampler_cfg_function(cfg_fn)
        return (patched,)
Loading

0 comments on commit f3b24c1

Please sign in to comment.