Skip to content

Commit

Permalink
Merge PR #425 from Kosinkadink/develop - Visualize Context Options nodes
Browse files Browse the repository at this point in the history
Visualize Context Options nodes
  • Loading branch information
Kosinkadink authored Jul 11, 2024
2 parents f3b24c1 + c8480f9 commit 1b660e5
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 19 deletions.
191 changes: 182 additions & 9 deletions animatediff/context.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Callable, Optional, Union
from typing import Union

import torch
import torchvision
import PIL
from PIL import Image, ImageFont, ImageDraw

import numpy as np
from torch import Tensor

import comfy.samplers
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher

from .utils_motion import get_sorted_list_via_attr

Expand Down Expand Up @@ -76,7 +79,7 @@ def __init__(self):
self._current_context: ContextOptions = None
self._current_used_steps: int = 0
self._current_index: int = 0
self.step = 0
self._step = 0

def reset(self):
self._current_context = None
Expand All @@ -85,6 +88,15 @@ def reset(self):
self.step = 0
self._set_first_as_current()

@property
def step(self):
    """Return the step number currently stored on this object."""
    return self._step

@step.setter
def step(self, value: int):
    """Store ``value`` as the step and mirror it onto the current context.

    The current context is only updated when one is set (not ``None``).
    """
    self._step = value
    active = self._current_context
    if active is None:
        return
    active.step = value

@classmethod
def default(cls):
def_context = ContextOptions()
Expand Down Expand Up @@ -492,15 +504,176 @@ class Colors:
CYAN = (0, 255, 255)


class BorderWidth:
    """Outline widths passed as ``width=`` when drawing grid rectangles."""
    # outline of the squares in the index (first) row
    INDEXES = 2
    # outline of the squares drawn for context/view sub-indexes
    CONTEXT = 4


class VisualizeSettings:
def __init__(self, img_width, img_height, video_length):
self.img_width = img_width
self.img_height = img_height
def __init__(self, img_width: int, video_length: int):
self.video_length = video_length
self.img_width = img_width
self.grid = img_width // video_length
self.img_height = self.grid * 5
self.pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()])
self.font_size = int(self.grid * 0.5)
self.font = ImageFont.load_default(size=self.font_size)
#self.title_font = ImageFont.load_default(size=int(self.font_size * 1.5))
self.title_font = ImageFont.load_default(size=int(self.font_size * 1.2))

self.background_color = Colors.BLACK
self.grid_outline_color = Colors.WHITE
self.start_idx_fill_color = Colors.MAGENTA
self.subidx_end_color = Colors.YELLOW

def generate_context_visualization(context_opts: ContextOptionsGroup, model: BaseModel, width=1440, height=200, video_length=32, start_step=0, end_step=20):
vs = VisualizeSettings(width, height, video_length)
pass
self.context_color = Colors.GREEN
self.view_color = Colors.RED


class GridDisplay:
    """Bundle a drawing handle, the visualization settings and an origin offset.

    ``home_x``/``home_y`` are the offsets that the ``draw_*`` helpers add to the
    coordinates they compute before drawing.
    """

    def __init__(self, draw: ImageDraw.ImageDraw, vs: VisualizeSettings, home_x: int=0, home_y: int=0):
        self.draw, self.vs = draw, vs
        self.home_x, self.home_y = home_x, home_y


def get_text_xy(input: str, font: ImageFont, x: int, y: int, centered=True):
    """Return the ``(x, y)`` position at which a piece of text should be drawn.

    Currently a pass-through: ``input``, ``font`` and ``centered`` are accepted
    but not used, so the coordinates come back unchanged as a tuple.
    """
    position = (x, y)
    return position


def draw_text(text: str, font: ImageFont, gd: GridDisplay, x: int, y: int, color=Colors.WHITE, centered=True):
    """Draw ``text`` on ``gd.draw`` at ``(x, y)`` relative to the display's home offset.

    The position is first passed through ``get_text_xy`` and then shifted by
    ``gd.home_x``/``gd.home_y``; ``color`` is used as the text fill.
    """
    text_x, text_y = get_text_xy(text, font, x, y, centered=centered)
    absolute_xy = (gd.home_x+text_x, gd.home_y+text_y)
    gd.draw.text(xy=absolute_xy, text=text, fill=color, font=font)


def draw_first_grid_row(total_length: int, gd: GridDisplay, start_idx=-1):
    """Draw the index row: one outlined square per index, labelled with its number.

    The square whose index equals ``start_idx`` is filled with the start-index
    color; every other square is left unfilled. The default of -1 matches no
    index produced by ``range(total_length)``.
    """
    settings = gd.vs
    cell = settings.grid
    # the row sits at the display's home position and is one grid cell tall
    top = gd.home_y
    bottom = top + cell
    for idx in range(total_length):
        left = gd.home_x+(cell*idx)
        highlight = settings.start_idx_fill_color if idx == start_idx else None
        gd.draw.rectangle(xy=(left, top, left + cell, bottom), fill=highlight,
                          outline=settings.grid_outline_color, width=BorderWidth.INDEXES)
        # label the square with its index; draw_text applies the home offset itself
        draw_text(text=str(idx), font=settings.font, gd=gd, x=cell*idx, y=0)


def draw_subidxs(window: list[int], gd: GridDisplay, y_grid_offset: int, color: tuple):
    """Draw one filled square per index in ``window`` on a given grid row.

    The row is ``y_grid_offset`` grid cells below the display's home position.
    Every square is outlined in ``color``; the first and last entries of
    ``window`` are filled with the sub-index end color, all others with ``color``.
    """
    settings = gd.vs
    cell = settings.grid
    top = gd.home_y+(cell * y_grid_offset)
    last_pos = len(window)-1
    for pos, frame_idx in enumerate(window):
        left = gd.home_x+(cell*frame_idx)
        # entries at either end of the window get a different inside color
        at_end = pos in (0, last_pos)
        inner = settings.subidx_end_color if at_end else color
        gd.draw.rectangle(xy=(left, top, left + cell, top + cell), fill=inner, outline=color, width=BorderWidth.CONTEXT)


def draw_context(window: list[int], gd: GridDisplay):
    """Draw ``window`` on grid row 1 using the configured context color."""
    row_color = gd.vs.context_color
    draw_subidxs(window=window, gd=gd, y_grid_offset=1, color=row_color)


def draw_view(window: list[int], gd: GridDisplay):
    """Draw ``window`` on grid row 2 using the configured view color."""
    row_color = gd.vs.view_color
    draw_subidxs(window=window, gd=gd, y_grid_offset=2, color=row_color)


def generate_context_visualization(context_opts: ContextOptionsGroup, model: ModelPatcher, sampler_name: str=None, scheduler: str=None,
                                   width=1440, height=200, video_length=32,
                                   steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None):
    """Render images showing which frame indexes each context/view window covers per step.

    One image is produced for every (sigma, context window, view window)
    combination. Each image has a title line, a row of numbered index squares,
    a row for the context window and, when views are active, a row for the view.

    NOTE(review): block nesting below was reconstructed from a listing whose
    indentation had been stripped - confirm against the repository file.

    Args:
        context_opts: Context options group to visualize. It is cloned first and
            every later call in this function goes to the clone.
        model: Model patcher; ``model.model`` is passed to
            ``initialize_timesteps`` and ``model.model_options`` to the KSampler.
        sampler_name: Forwarded to ``comfy.samplers.KSampler`` when ``sigmas`` is None.
        scheduler: Forwarded to ``comfy.samplers.KSampler`` when ``sigmas`` is None.
        width: Image width handed to ``VisualizeSettings``.
        height: Accepted but not used by this function.
        video_length: Number of frame indexes shown in the index row.
        steps: Forwarded to the KSampler; also the total shown in the title.
            When None it is replaced by the number of sigmas iterated.
        start_step: Optional first step; trims sigmas computed here and offsets
            the step number shown in the title.
        end_step: Optional last step; trims sigmas computed here.
        sigmas: Precomputed sigmas. When given, no KSampler is created.
        force_full_denoise: When sigmas are trimmed by ``end_step``, set the
            last kept sigma to 0.
        denoise: Forwarded to the KSampler when ``sigmas`` is None.

    Returns:
        The stacked frames as a float32 tensor with dim 1 moved to the end.
        NOTE(review): values are only cast to float32 here, not rescaled -
        confirm the value range consumers expect.
    """
    context_opts = context_opts.clone()
    vs = VisualizeSettings(width, video_length)
    all_imgs = []

    # no sigmas supplied: derive them from the sampler settings
    if sigmas is None:
        sampler = comfy.samplers.KSampler(
            model=model, steps=steps, device="cpu", sampler=sampler_name, scheduler=scheduler,
            denoise=denoise, model_options=model.model_options,
        )
        sigmas = sampler.sigmas
        # trim to end_step only when it falls before the final sigma
        if end_step is not None and end_step < (len(sigmas) - 1):
            sigmas = sigmas[:end_step + 1]
            if force_full_denoise:
                sigmas[-1] = 0
        # drop leading sigmas only when start_step falls before the final sigma
        if start_step is not None:
            if start_step < (len(sigmas) - 1):
                sigmas = sigmas[start_step:]
    # remove last sigma, as sampling uses pairs of sigmas at a time (fence post problem)
    sigmas = sigmas[:-1]

    context_opts.reset()
    context_opts.initialize_timesteps(model.model)

    if start_step is None:
        start_step = 0 # use this in case start_step is provided, to display accurate step
    if steps is None:
        steps = len(sigmas)

    for i, t in enumerate(sigmas):
        # make context_opts reflect current step/sigma
        context_opts.prepare_current_context([t])
        context_opts.step = start_step+i

        # check if context should even be active in this case
        context_active = True
        if video_length < context_opts.context_length:
            context_active = False
        elif video_length == context_opts.context_length and not context_opts.use_on_equal_length:
            context_active = False

        if context_active:
            context_windows = get_context_windows(num_frames=video_length, opts=context_opts)
        else:
            # context inactive: a single window covering every frame index
            context_windows = [list(range(video_length))]
        # -1 matches no index, so nothing is highlighted until the first window sets it
        start_idx = -1
        for j,window in enumerate(context_windows):
            repeat_count = 0
            view_windows = []
            # one frame per context window unless views split it further
            total_repeats = 1
            view_options = context_opts.view_options
            if view_options is not None:
                view_active = True
                if len(window) < view_options.context_length:
                    view_active = False
                # NOTE(review): this compares video_length while the check above and the
                # get_context_windows call below use len(window) - confirm this is intended
                elif video_length == view_options.context_length and not view_options.use_on_equal_length:
                    view_active = False
                if view_active:
                    view_windows = get_context_windows(num_frames=len(window), opts=view_options)
                    total_repeats = len(view_windows)
            while total_repeats > repeat_count:
                # create new frame
                frame: Image = Image.new(mode="RGB", size=(vs.img_width, vs.img_height), color=vs.background_color)
                draw = ImageDraw.Draw(frame)
                # home_y of one grid cell leaves the top cell row free for the title
                gd = GridDisplay(draw=draw, vs=vs, home_x=0, home_y=vs.grid)
                # if views present, do view stuff
                if len(view_windows) > 0:
                    # view indexes are positions inside the context window; map them back to frame indexes
                    converted_view = [window[x] for x in view_windows[repeat_count]]
                    draw_view(window=converted_view, gd=gd)
                # draw context_type + current step
                title_str = f"{context_opts.context_schedule} - Step {context_opts.step+1}/{steps} (Context {j+1}/{len(context_windows)})"
                if len(view_windows) > 0:
                    title_str = f"{title_str} (View {repeat_count+1}/{len(view_windows)})"
                # subtracting the home offsets cancels the shift draw_text adds, placing the title at (0, 0)
                draw_text(text=title_str, font=vs.title_font, gd=gd, x=0-gd.home_x, y=0-gd.home_y, centered=False)
                # draw first row (total length, white)
                if j == 0:
                    # the highlighted index is taken from the first context window of this step
                    start_idx = window[0]
                draw_first_grid_row(total_length=video_length, gd=gd, start_idx=start_idx)
                # draw context row
                draw_context(window=window, gd=gd)
                # save image + iterate repeat_count
                img: Tensor = vs.pil_to_tensor(frame)
                all_imgs.append(img)
                repeat_count += 1

    images = torch.stack(all_imgs)
    # move dim 1 of the stacked frames to the last position, then cast to float32
    images = images.movedim(1, -1).to(torch.float32)
    return images
11 changes: 8 additions & 3 deletions animatediff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
NoisedImageInjectionNode, NoisedImageInjectOptionsNode)
from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode)
from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode,
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsInt)
StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode,
VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom)
from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode,
WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode,
WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode)
Expand Down Expand Up @@ -58,7 +59,9 @@
"ADE_ViewsOnlyContextOptions": ViewAsContextOptionsNode,
"ADE_BatchedContextOptions": BatchedContextOptionsNode,
"ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy
#"ADE_VisualizeContextOptions": VisualizeContextOptionsInt,
"ADE_VisualizeContextOptionsK": VisualizeContextOptionsK,
"ADE_VisualizeContextOptionsKAdv": VisualizeContextOptionsKAdv,
"ADE_VisualizeContextOptionsSCustom": VisualizeContextOptionsSCustom,
# View Opts
"ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode,
"ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode,
Expand Down Expand Up @@ -180,7 +183,9 @@
"ADE_ViewsOnlyContextOptions": "Context Options◆Views Only [VRAM⇈] 🎭🅐🅓",
"ADE_BatchedContextOptions": "Context Options◆Batched [Non-AD] 🎭🅐🅓",
"ADE_AnimateDiffUniformContextOptions": "Context Options◆Looped Uniform 🎭🅐🅓", # Legacy
"ADE_VisualizeContextOptions": "Visualize Context Options 🎭🅐🅓",
"ADE_VisualizeContextOptionsK": "Visualize Context Options (K.) 🎭🅐🅓",
"ADE_VisualizeContextOptionsKAdv": "Visualize Context Options (K.Adv.) 🎭🅐🅓",
"ADE_VisualizeContextOptionsSCustom": "Visualize Context Options (S.Cus.) 🎭🅐🅓",
# View Opts
"ADE_StandardStaticViewOptions": "View Options◆Standard Static 🎭🅐🅓",
"ADE_StandardUniformViewOptions": "View Options◆Standard Uniform 🎭🅐🅓",
Expand Down
76 changes: 70 additions & 6 deletions animatediff/nodes_context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch
from torch import Tensor

import comfy.samplers
from comfy.model_patcher import ModelPatcher

from .context import ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules
from .utils_model import BIGMAX
from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules,
generate_context_visualization)
from .utils_model import BIGMAX, MAX_RESOLUTION


LENGTH_MAX = 128 # keep an eye on these max values;
Expand Down Expand Up @@ -353,16 +355,20 @@ def create_options(self, view_length: int, view_overlap: int, view_stride: int,
return (view_options,)


class VisualizeContextOptionsInt:
class VisualizeContextOptionsKAdv:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"context_opts": ("CONTEXT_OPTIONS",),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
},
"optional": {
"visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}),
"latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
"steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}),
"start_step": ("INT", {"min": 0, "max": BIGMAX, "default": 0}),
"end_step": ("INT", {"min": 1, "max": BIGMAX, "default": 20}),
}
Expand All @@ -372,7 +378,65 @@ def INPUT_TYPES(s):
CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
FUNCTION = "visualize"

def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup,
latents_length=32, start_step=0, end_step=20):
images = torch.zeros((latents_length, 256, 256, 3))
def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str,
              visual_width: int=1440, latents_length=32, steps=20, start_step=0, end_step=20):
    """Build the context-options visualization for an explicit step range.

    Args:
        model: Model passed through to ``generate_context_visualization``.
        context_opts: Context options group to visualize.
        sampler_name: Sampler name passed through.
        scheduler: Scheduler name passed through.
        visual_width: Width of the generated images. Declared as an optional
            input, so it has a default here (the same value as the input's
            declared default) instead of being a required argument.
        latents_length: Passed as ``video_length``.
        steps: Passed through as ``steps``.
        start_step: Passed through as ``start_step``.
        end_step: Passed through as ``end_step``.

    Returns:
        A 1-tuple holding the generated images.
    """
    images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length,
                                            sampler_name=sampler_name, scheduler=scheduler,
                                            steps=steps, start_step=start_step, end_step=end_step)
    return (images,)


class VisualizeContextOptionsK:
    """Node that visualizes context options from sampler name, scheduler, steps and denoise."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: model/context/sampler are required, sizing and step settings optional."""
        return {
            "required": {
                "model": ("MODEL",),
                "context_opts": ("CONTEXT_OPTIONS",),
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
            },
            "optional": {
                "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}),
                "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
                "steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}),
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
    FUNCTION = "visualize"

    def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str,
                  visual_width: int=1440, latents_length=32, steps=20, denoise=1.0):
        """Build the context-options visualization.

        ``visual_width`` is declared as an optional input, so it carries a
        default here (the same value as the input's declared default) rather
        than being a required argument. ``latents_length`` is passed on as
        ``video_length``; the remaining arguments are passed through unchanged.

        Returns:
            A 1-tuple holding the generated images.
        """
        images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length,
                                                sampler_name=sampler_name, scheduler=scheduler,
                                                steps=steps, denoise=denoise)
        return (images,)


class VisualizeContextOptionsSCustom:
    """Node that visualizes context options from an explicitly supplied sigma schedule."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: model, context options and sigmas are required, sizing optional."""
        return {
            "required": {
                "model": ("MODEL",),
                "context_opts": ("CONTEXT_OPTIONS",),
                "sigmas": ("SIGMAS", ),
            },
            "optional": {
                "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}),
                "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "Animate Diff 🎭🅐🅓/context opts/visualize"
    FUNCTION = "visualize"

    def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas,
                  visual_width: int=1440, latents_length=32):
        """Build the context-options visualization for the given ``sigmas``.

        ``visual_width`` is declared as an optional input, so it carries a
        default here (the same value as the input's declared default) rather
        than being a required argument. ``latents_length`` is passed on as
        ``video_length`` and ``sigmas`` is passed through unchanged.

        Returns:
            A 1-tuple holding the generated images.
        """
        images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length,
                                                sigmas=sigmas)
        return (images,)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-animatediff-evolved"
description = "Improved AnimateDiff integration for ComfyUI."
version = "1.0.9"
version = "1.0.10"
license = "LICENSE"
dependencies = []

Expand Down

0 comments on commit 1b660e5

Please sign in to comment.