Skip to content

Commit

Permalink
Support dynamic caching from Comfy-WaveSpeed nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Jan 31, 2025
1 parent a5f8f31 commit b742e10
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 32 deletions.
1 change: 1 addition & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CheckpointInput:
v_prediction_zsnr: bool = False
rescale_cfg: float = 0.7
self_attention_guidance: bool = False
dynamic_caching: bool = False


@dataclass
Expand Down
7 changes: 4 additions & 3 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,12 @@ def from_list(data: list):


class ClientFeatures(NamedTuple):
    """Capability flags reported by a connected backend.

    Defaults describe the most permissive configuration; a concrete client
    overrides fields according to what its server actually supports.
    """

    # Whether IP-Adapter based reference images are available.
    ip_adapter: bool = True
    # Whether prompt translation is available.
    translation: bool = True
    # Translation packages offered by the server.
    # NOTE(review): class-level [] is a single shared list across all
    # instances that omit the field — safe only if never mutated in place.
    languages: list[TranslationPackage] = []
    # Maximum payload size for uploads in bytes; 0 means unlimited/unknown.
    max_upload_size: int = 0
    max_control_layers: int = 1000
    # True when the Comfy-WaveSpeed "ApplyFBCacheOnModel" node is installed.
    wave_speed: bool = False


class Client(ABC):
Expand Down
1 change: 1 addition & 0 deletions ai_diffusion/cloud_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def performance_settings(self):
batch_size=clamp(settings.batch_size, 4, 8),
resolution_multiplier=settings.resolution_multiplier,
max_pixel_count=clamp(settings.max_pixel_count, 1, 8),
dynamic_caching=False,
)

@property
Expand Down
27 changes: 10 additions & 17 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def __init__(self, url):
self._requests = RequestManager()
self._id = str(uuid.uuid4())
self._active: Optional[JobInfo] = None
self._features: ClientFeatures = ClientFeatures()
self._supported_archs: dict[Arch, list[ResourceId]] = {}
self._supported_languages: list[TranslationPackage] = []
self._messages = asyncio.Queue()
self._queue = asyncio.Queue()
self._jobs: deque[JobInfo] = deque()
Expand All @@ -107,7 +107,6 @@ async def connect(url=default_url, access_token=""):

# Retrieve system info
client.device_info = DeviceInfo.parse(await client._get("system_stats"))
client._supported_languages = await _list_languages(client)

# Try to establish websockets connection
wsurl = websocket_url(client.url)
Expand All @@ -124,6 +123,13 @@ async def connect(url=default_url, access_token=""):
if len(missing) > 0:
raise MissingResources(missing)

client._features = ClientFeatures(
ip_adapter=True,
translation=True,
languages=await _list_languages(client),
wave_speed="ApplyFBCacheOnModel" in nodes,
)

# Check for required and optional model resources
models = client.models
models.node_inputs = {name: nodes[name]["input"] for name in nodes}
Expand Down Expand Up @@ -460,28 +466,15 @@ def missing_resources(self):

@property
def features(self):
return ClientFeatures(
ip_adapter=True, translation=True, languages=self._supported_languages
)

@property
def supports_ip_adapter(self):
return True

@property
def supports_translation(self):
return True

@property
def supported_languages(self):
return self._supported_languages
return self._features

@property
def performance_settings(self):
return PerformanceSettings(
batch_size=settings.batch_size,
resolution_multiplier=settings.resolution_multiplier,
max_pixel_count=settings.max_pixel_count,
dynamic_caching=settings.dynamic_caching and self.features.wave_speed,
)

async def upload_loras(self, work: WorkflowInput, local_job_id: str):
Expand Down
12 changes: 12 additions & 0 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,18 @@ def estimate_pose(self, image: Output, resolution: int):
mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx"
return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls)

def apply_first_block_cache(self, model: Output, arch: Arch):
    """Wrap the model with Comfy-WaveSpeed's First Block Cache node.

    Re-uses residuals from earlier sampling steps to skip work. SDXL-like
    architectures tolerate a larger difference threshold than others.
    """
    # Threshold controls how similar a step must be to reuse cached output.
    threshold = 0.2 if arch.is_sdxl_like else 0.12
    return self.add(
        "ApplyFBCacheOnModel",
        1,
        model=model,
        object_to_patch="diffusion_model",
        residual_diff_threshold=threshold,
        start=0.0,
        end=1.0,
        max_consecutive_cache_hits=-1,  # -1: no limit on consecutive cache hits
    )

def create_hook_lora(self, loras: list[tuple[str, float]]):
key = "CreateHookLora" + str(loras)
hooks = self._cache.get(key, None)
Expand Down
9 changes: 8 additions & 1 deletion ai_diffusion/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ class CustomNode(NamedTuple):
"https://github.com/city96/ComfyUI-GGUF",
"5875c52f59baca3a9372d68c43a3775e21846fe0",
["UnetLoaderGGUF", "DualCLIPLoaderGGUF"],
)
),
CustomNode(
"WaveSpeed",
"Comfy-WaveSpeed",
"https://github.com/chengzeyi/Comfy-WaveSpeed",
"a9caacb0706c5fbe5fbc8718081f7c3e3e348ebd",
["ApplyFBCacheOnModel"],
),
]


Expand Down
30 changes: 22 additions & 8 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass, asdict
from dataclasses import dataclass
import os
import json
from enum import Enum
from pathlib import Path
from typing import Optional, Any
from typing import NamedTuple, Optional, Any
from PyQt5.QtCore import QObject, pyqtSignal

from .util import is_macos, is_windows, user_data_dir, client_logger as log
Expand Down Expand Up @@ -55,11 +55,18 @@ class PerformancePreset(Enum):
custom = _("Custom")


class PerformancePresetSettings(NamedTuple):
    # Values a built-in performance preset writes into the user settings.
    # Intentionally a subset of PerformanceSettings: presets do not touch
    # options like dynamic_caching, which the user controls independently.
    batch_size: int = 4
    resolution_multiplier: float = 1.0
    max_pixel_count: int = 6  # in megapixels


@dataclass
class PerformanceSettings:
    """Resolved performance options passed along to workflow generation."""

    batch_size: int = 4
    resolution_multiplier: float = 1.0
    max_pixel_count: int = 6  # in megapixels
    # Use Comfy-WaveSpeed First Block Cache when the server supports it.
    dynamic_caching: bool = False


class Setting:
Expand Down Expand Up @@ -253,28 +260,35 @@ class Settings(QObject):
_("Maximum resolution to generate images at, in megapixels (FullHD ~ 2MP, 4k ~ 8MP)."),
)

dynamic_caching: bool
_dynamic_caching = Setting(
_("Dynamic Caching"),
False,
_("Re-use outputs of previous steps (First Block Cache) to speed up generation."),
)

_performance_presets = {
PerformancePreset.cpu: PerformanceSettings(
PerformancePreset.cpu: PerformancePresetSettings(
batch_size=1,
resolution_multiplier=1.0,
max_pixel_count=2,
),
PerformancePreset.low: PerformanceSettings(
PerformancePreset.low: PerformancePresetSettings(
batch_size=2,
resolution_multiplier=1.0,
max_pixel_count=2,
),
PerformancePreset.medium: PerformanceSettings(
PerformancePreset.medium: PerformancePresetSettings(
batch_size=4,
resolution_multiplier=1.0,
max_pixel_count=6,
),
PerformancePreset.high: PerformanceSettings(
PerformancePreset.high: PerformancePresetSettings(
batch_size=6,
resolution_multiplier=1.0,
max_pixel_count=8,
),
PerformancePreset.cloud: PerformanceSettings(
PerformancePreset.cloud: PerformancePresetSettings(
batch_size=8,
resolution_multiplier=1.0,
max_pixel_count=6,
Expand Down Expand Up @@ -355,7 +369,7 @@ def load(self, path: Optional[Path] = None):

def apply_performance_preset(self, preset: PerformancePreset):
    """Copy a built-in preset's values into the settings store.

    The custom and auto presets carry no fixed values and leave the
    current settings untouched.
    """
    if preset not in [PerformancePreset.custom, PerformancePreset.auto]:
        # Presets are NamedTuples, so use _asdict() — dataclasses.asdict
        # would raise TypeError on a non-dataclass instance.
        for k, v in self._performance_presets[preset]._asdict().items():
            self._values[k] = v

def _migrate_legacy_settings(self, path: Path):
Expand Down
18 changes: 15 additions & 3 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,10 @@ def __init__(self):
self._max_pixel_count.value_changed.connect(self.write)
advanced_layout.addWidget(self._max_pixel_count)

self._dynamic_caching = SwitchSetting(Settings._dynamic_caching, parent=self)
self._dynamic_caching.value_changed.connect(self.write)
self._layout.addWidget(self._dynamic_caching)

self._layout.addStretch()

def _change_performance_preset(self, index):
Expand All @@ -603,13 +607,19 @@ def _change_performance_preset(self, index):
if not is_custom:
self.read()

def update_device_info(self):
def update_client_info(self):
    """Refresh device label and dynamic-caching availability from the client.

    Does nothing unless a server connection is currently established.
    """
    if root.connection.state is not ConnectionState.connected:
        return
    client = root.connection.client
    device = client.device_info
    self._device_info.setText(
        _("Device") + f": [{device.type.upper()}] {device.name} ({device.vram} GB)"
    )
    # The switch is only usable when the Comfy-WaveSpeed node is installed;
    # otherwise explain via tooltip why it is disabled.
    has_wave_speed = client.features.wave_speed
    self._dynamic_caching.enabled = has_wave_speed
    if has_wave_speed:
        self._dynamic_caching.setToolTip("")
    else:
        self._dynamic_caching.setToolTip(
            _("The {node_name} node is not installed.").format(node_name="Comfy-WaveSpeed")
        )

def _read(self):
self._history_size.value = settings.history_size
Expand All @@ -622,7 +632,8 @@ def _read(self):
)
self._resolution_multiplier.value = settings.resolution_multiplier
self._max_pixel_count.value = settings.max_pixel_count
self.update_device_info()
self._dynamic_caching.value = settings.dynamic_caching
self.update_client_info()

def _write(self):
settings.history_size = self._history_size.value
Expand All @@ -633,6 +644,7 @@ def _write(self):
settings.performance_preset = list(PerformancePreset)[
self._performance_preset.currentIndex()
]
settings.dynamic_caching = self._dynamic_caching.value


class AboutSettings(SettingsTab):
Expand Down Expand Up @@ -897,7 +909,7 @@ def _update_connection(self):
self.connection.update_server_status()
if root.connection.state == ConnectionState.connected:
self.interface.update_translation(root.connection.client)
self.performance.update_device_info()
self.performance.update_client_info()

def _open_settings_folder(self):
QDesktopServices.openUrl(QUrl.fromLocalFile(str(util.user_data_dir)))
Expand Down
4 changes: 4 additions & 0 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
if vae is None:
vae = w.load_vae(models.for_arch(arch).vae)

if checkpoint.dynamic_caching and (arch in [Arch.flux, Arch.sd3] or arch.is_sdxl_like):
model = w.apply_first_block_cache(model, arch)

for lora in checkpoint.loras:
model, clip = w.load_lora(model, clip, lora.name, lora.strength, lora.strength)

Expand Down Expand Up @@ -1279,6 +1282,7 @@ def prepare(
i.models = style.get_models(models.checkpoints.keys())
i.conditioning.positive += _collect_lora_triggers(i.models.loras, files)
i.models.loras = unique(i.models.loras + extra_loras, key=lambda l: l.name)
i.models.dynamic_caching = perf.dynamic_caching
arch = i.models.version = resolve_arch(style, models)

_check_server_has_models(i.models, i.conditioning.regions, models, files, style.name)
Expand Down

0 comments on commit b742e10

Please sign in to comment.