diff --git a/__init__.py b/__init__.py
index 1380392..15b1e10 100644
--- a/__init__.py
+++ b/__init__.py
@@ -30,14 +30,21 @@ NODE_DISPLAY_NAME_MAPPINGS = {}
 
 nodes = ["base", "lazy", "tools"]
+optional_nodes = ["attnmask"]
 
 if importlib.util.find_spec("comfy.hooks"):
-    nodes.append("hooks")
+    nodes.extend(["hooks"])
 else:
-    log.warning(
-        "Your ComfyUI version is too old, can't import comfy.hooks for PCEncodeSchedule and PCLoraHooksFromSchedule. Update your installation."
-    )
+    log.error("Your ComfyUI version is too old, can't import comfy.hooks. Update your installation.")
 
 for node in nodes:
     mod = importlib.import_module(f".prompt_control.nodes_{node}", package=__name__)
     NODE_CLASS_MAPPINGS.update(mod.NODE_CLASS_MAPPINGS)
     NODE_DISPLAY_NAME_MAPPINGS.update(mod.NODE_DISPLAY_NAME_MAPPINGS)
+
+for node in optional_nodes:
+    try:
+        mod = importlib.import_module(f".prompt_control.nodes_{node}", package=__name__)
+        NODE_CLASS_MAPPINGS.update(mod.NODE_CLASS_MAPPINGS)
+        NODE_DISPLAY_NAME_MAPPINGS.update(mod.NODE_DISPLAY_NAME_MAPPINGS)
+    except ImportError:
+        log.info(f"Could not import optional node {node}; continuing anyway")
diff --git a/doc/syntax.md b/doc/syntax.md
index 3208573..f98cbad 100644
--- a/doc/syntax.md
+++ b/doc/syntax.md
@@ -195,3 +195,15 @@ The order of the `FEATHER` and `MASK` calls doesn't matter; you can have `FEATHE
 ## Miscellaneous
 
 - `<emb:xyz>` is alternative syntax for `embedding:xyz` to work around a syntax conflict with `[embedding:xyz:0.5]` which is parsed as a schedule that switches from `embedding` to `xyz`.
+
+# Experimental features
+
+Experimental features are unstable and may disappear or break without warning.
+
+## Attention masking
+
+Use `ATTN()` in combination with `MASK()` or `IMASK()` to enable attention masking. It is currently slow, only works with SDXL, and requires a recent enough version of ComfyUI.
+
+## TE_WEIGHT
+
+For models using multiple text encoders, you can set weights per TE using the syntax `TE_WEIGHT(clipname=weight, clipname2=weight2, ...)`, where `clipname` is one of `g`, `l`, or `t5xxl`. For example, with SDXL, try `TE_WEIGHT(g=0.25, l=0.75)`. The weights are applied as a multiplier to the TE output.
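For reference, a sketch of how the documented syntax might be combined in a single prompt segment. The prompt text and the `MASK()` coordinates below are illustrative assumptions, not taken from this patch; only `ATTN()` and `TE_WEIGHT()` come from the documentation above.

```python
# Illustrative prompt string only (assumed MASK() coordinates).
# ATTN() switches the segment's MASK() from conditioning-area masking to the
# experimental attention masking; TE_WEIGHT() scales each text encoder's output.
prompt = "TE_WEIGHT(g=0.25, l=0.75) a castle on a hill ATTN() MASK(0 0.5, 0 1)"
```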
diff --git a/prompt_control/nodes_attnmask.py b/prompt_control/nodes_attnmask.py
new file mode 100644
index 0000000..80f25a5
--- /dev/null
+++ b/prompt_control/nodes_attnmask.py
@@ -0,0 +1,79 @@
+import logging
+
+log = logging.getLogger("comfyui-prompt-control")
+
+from comfy.hooks import TransformerOptionsHook, HookGroup, EnumHookScope
+from comfy.ldm.modules.attention import optimized_attention
+import torch.nn.functional as F
+import torch
+from math import sqrt
+
+
+class MaskedAttn2:
+    def __init__(self, mask):
+        self.mask = mask
+
+    def __call__(self, q, k, v, extra_options):
+        mask = self.mask
+        orig_shape = extra_options["original_shape"]
+        _, _, oh, ow = orig_shape
+        seq_len = q.shape[1]
+        mask_h = oh / sqrt(oh * ow / seq_len)
+        mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
+        mask_w = seq_len // mask_h
+        r = optimized_attention(q, k, v, extra_options["n_heads"])
+        mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="nearest").squeeze(1)
+        mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, r.shape[2])
+
+        return mask * r
+
+
+def create_attention_hook(mask):
+    attn_replacements = {}
+    mask = mask.detach().to(device="cuda", dtype=torch.float16)
+
+    masked_attention = MaskedAttn2(mask)
+
+    for id in [4, 5, 7, 8]:  # id of input_blocks that have cross attention
+        block_indices = range(2) if id in [4, 5] else range(10)  # transformer_depth
+        for index in block_indices:
+            k = ("input", id, index)
+            attn_replacements[k] = masked_attention
+    for id in range(6):  # id of output_blocks that have cross attention
+        block_indices = range(2) if id in [3, 4, 5] else range(10)  # transformer_depth
+        for index in block_indices:
+            k = ("output", id, index)
+            attn_replacements[k] = masked_attention
+    for index in range(10):
+        k = ("middle", 1, index)
+        attn_replacements[k] = masked_attention
+
+    hook = TransformerOptionsHook(
+        transformers_dict={"patches_replace": {"attn2": attn_replacements}}, hook_scope=EnumHookScope.HookedOnly
+    )
+    group = HookGroup()
+    group.add(hook)
+
+    return group
+
+
+class AttentionMaskHookExperimental:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {"mask": ("MASK",)},
+        }
+
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "promptcontrol/_testing"
+    FUNCTION = "apply"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Experimental attention masking hook. For testing only"
+
+    def apply(self, mask):
+        return (create_attention_hook(mask),)
+
+
+NODE_CLASS_MAPPINGS = {"AttentionMaskHookExperimental": AttentionMaskHookExperimental}
+
+NODE_DISPLAY_NAME_MAPPINGS = {}
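The mask-resizing arithmetic in `MaskedAttn2.__call__` can be sanity-checked in isolation. A minimal sketch, assuming an SDXL generation at 1024x1024 (a 128x128 latent) and an attention layer whose query has 1024 tokens; the concrete numbers are an assumed example, not part of the patch:

```python
import torch
import torch.nn.functional as F
from math import sqrt

# Assumed example: original_shape = (B, C, 128, 128) latent, and a deep SDXL
# attention layer whose query has seq_len = 1024 tokens (a 32x32 feature map).
oh, ow, seq_len = 128, 128, 1024
mask = torch.ones(1, 512, 512)  # any (B, H, W) mask; the input size is arbitrary

# Same computation as MaskedAttn2: derive a (mask_h, mask_w) grid covering
# seq_len tokens, then resize and flatten the mask to one value per token.
mask_h = oh / sqrt(oh * ow / seq_len)                     # 128 / sqrt(16) = 32.0
mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)  # 32; no remainder correction needed
mask_w = seq_len // mask_h                                # 32

resized = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="nearest").squeeze(1)
per_token = resized.view(resized.shape[0], -1, 1)         # (1, 1024, 1)
print(per_token.shape)                                     # torch.Size([1, 1024, 1])
```

In the hook itself the flattened mask is then repeated across the feature dimension and multiplied elementwise with the output of `optimized_attention`.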
diff --git a/prompt_control/prompts.py b/prompt_control/prompts.py
index 9e666e9..1ec6674 100644
--- a/prompt_control/prompts.py
+++ b/prompt_control/prompts.py
@@ -9,6 +9,21 @@
 from .cutoff import process_cuts
 from .parser import parse_cuts
 
+try:
+    from .nodes_attnmask import create_attention_hook
+    from comfy.hooks import set_hooks_for_conditioning
+
+    def set_cond_attnmask(cond, mask):
+        hook = create_attention_hook(mask)
+        return set_hooks_for_conditioning(cond, hooks=hook)
+
+except ImportError:
+
+    def set_cond_attnmask(cond, mask):
+        log.info("Attention masking is not available")
+        return cond
+
+
 log = logging.getLogger("comfyui-prompt-control")
 
 AVAILABLE_STYLES = ["comfy", "perp", "A1111", "compel", "comfy++", "down_weight"]
@@ -430,6 +445,11 @@ def weight(t):
     # TODO: is this still needed?
     # scale = sum(abs(weight(p)[0]) for p in prompts if not ("AREA(" in p or "MASK(" in p))
     for prompt in prompts:
+        attn = False
+        if "ATTN()" in prompt:
+            prompt = prompt.replace("ATTN()", "")
+            attn = True
+            log.info("Using attention masking for prompt segment")
         prompt, mask, mask_weight = get_mask(prompt, mask_size, masks)
         w, opts, prompt = weight(prompt)
         text, noise_w, generator = get_noise(text)
@@ -452,6 +472,11 @@ def weight(t):
             settings["start_percent"] = start_pct
             settings["end_percent"] = end_pct
             x = encode_prompt_segment(clip, prompt, settings, style, normalization)
+            if attn and mask is not None:
+                mask = settings.pop("mask")
+                strength = settings.pop("mask_strength")
+                x = set_cond_attnmask(x, mask * strength)
+
             conds.extend(x)
 
     return conds
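If the hook needs to be attached outside the scheduling parser (for example from another custom node), the same wiring as `set_cond_attnmask` can be reused. A minimal sketch, assuming the package is importable as `prompt_control` and that a recent ComfyUI provides `comfy.hooks.set_hooks_for_conditioning` (as the patch itself relies on); `apply_attention_mask` is a hypothetical helper name:

```python
from comfy.hooks import set_hooks_for_conditioning  # requires a recent ComfyUI
from prompt_control.nodes_attnmask import create_attention_hook  # assumed import path

def apply_attention_mask(conditioning, mask, strength=1.0):
    """Attach the experimental masked-attention hook to an encoded conditioning list,
    mirroring what the ATTN() branch above does with mask * mask_strength."""
    hooks = create_attention_hook(mask * strength)
    return set_hooks_for_conditioning(conditioning, hooks=hooks)
```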