adjust setup, fix sparse controlnet loading and other issues, add sparse controlnet to frontend/backend
painebenjamin committed Dec 17, 2023
1 parent 5ad63a4 commit 8bf352c
Showing 14 changed files with 478 additions and 216 deletions.
2 changes: 2 additions & 0 deletions src/js/controller/sidebar/99-invoke.mjs
@@ -69,6 +69,8 @@ class InvokeButtonController extends Controller {
formattedState["anchor"] = datum.anchor;
formattedState["opacity"] = datum.opacity;
formattedState["visibility"] = datum.visibility;
formattedState["frame"] = isEmpty(datum.startFrame) ? 0 : datum.startFrame - 1;

if (datum.imagePrompt) {
formattedState["ip_adapter_scale"] = datum.imagePromptScale;
}
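
The new "frame" entry converts the editor's one-indexed starting frame into the zero-indexed frame offset passed to the backend; an empty value means the layer is active from the first frame. A minimal Python sketch of the same rule (the helper name is illustrative, not part of the commit):

from typing import Optional

def ui_frame_to_offset(start_frame: Optional[int]) -> int:
    # Mirrors the ternary above: an empty startFrame yields 0; otherwise the
    # one-indexed UI value is shifted down to a zero-indexed offset.
    return 0 if start_frame is None else start_frame - 1

assert ui_frame_to_offset(None) == 0  # unset: active from the first frame
assert ui_frame_to_offset(1) == 0     # UI frame 1 is backend offset 0
assert ui_frame_to_offset(25) == 24
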
11 changes: 11 additions & 0 deletions src/js/forms/enfugue/image-editor/image.mjs
@@ -100,6 +100,17 @@ class ImageEditorImageNodeOptionsFormView extends FormView {
"controlnetUnits": {
"class": ControlNetUnitsInputView
}
},
"Animation Options": {
"startFrame": {
"label": "Starting Frame",
"class": NumberInputView,
"config": {
"tooltip": "When using animation, this controls what frame to begin this layer's influence, starting from one.",
"min": 1,
"step": 1
}
}
}
};

7 changes: 5 additions & 2 deletions src/js/forms/input/enfugue/engine.mjs
@@ -222,7 +222,8 @@ class ControlNetInputView extends SelectInputView {
"depth": "Depth Detection (MiDaS)",
"normal": "Normal Detection (Estimate)",
"pose": "Pose Detection (DWPose/OpenPose)",
"qr": "QR Code"
"qr": "QR Code",
"sparse-rgb": "Sparse RGB",
};

/**
@@ -239,7 +240,9 @@
"<strong>Depth</strong>: This uses Intel's MiDaS model to estimate monocular depth from a single image. This uses a greyscale image showing the distance from the camera to any given object.<br />" +
"<strong>Normal</strong>: Normal maps are similar to depth maps, but instead of using a greyscale depth, three sets of distance data is encoded into red, green and blue channels.<br />" +
"<strong>DWPose/OpenPose</strong>: OpenPose is an AI model from the Carnegie Mellon University's Perceptual Computing Lab detects human limb, face and digit poses from an image, and DWPose is a faster and more accurate model built on top of OpenPose. Using this data, you can generate different people in the same pose.<br />" +
"<strong>QR Code</strong> is a specialized control network designed to generate images from QR codes that are scannable QR codes themselves.";
"<strong>QR Code</strong> is a specialized control network designed to generate images from QR codes that are scannable QR codes themselves.<br />" +
"<strong>Sparse RGB</strong> is a ControlNet designed for generating videos given one or more images as frames along a timeline. However it can also be used for image generation as a general-purpose reference ControlNet.";

};

/**
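
To make the Sparse RGB description concrete: sparse-control conditioning is typically assembled by placing each condition image at its slot on the timeline, leaving every other frame zero, and marking conditioned frames in a mask channel. The sketch below is a hedged illustration of that layout; the RGB-plus-mask channel count matches the "conditioning_channels": 4 setting used for sparse-rgb later in this commit, but the exact tensor layout inside enfugue may differ.

import torch

num_frames, height, width = 16, 64, 64
condition = torch.zeros(1, 4, num_frames, height, width)  # RGB + mask channel

# One-indexed keyframes as entered in the UI; the tensors are placeholder images.
keyframes = {1: torch.rand(3, height, width), 9: torch.rand(3, height, width)}

for frame, image in keyframes.items():
    condition[0, :3, frame - 1] = image  # RGB condition at its timeline slot
    condition[0, 3, frame - 1] = 1.0     # mask: this frame carries a condition
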
src/python/enfugue/diffusion/animate/diff/sparse_controlnet.py
@@ -1,4 +1,3 @@
from __future__ import annotations
# type: ignore
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
@@ -15,6 +14,7 @@
# limitations under the License.
#
# Changes were made to this source code by Yuwei Guo.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -27,11 +27,8 @@
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin

from einops import repeat, rearrange
from enfugue.diffusion.animate.diff.resnet import InflatedConv3d
from enfugue.diffusion.animate.diff.unet_blocks import (
CrossAttnDownBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
get_down_block,
)
80 changes: 25 additions & 55 deletions src/python/enfugue/diffusion/animate/pipeline.py
@@ -21,7 +21,6 @@
from enfugue.diffusion.pipeline import EnfugueStableDiffusionPipeline
from enfugue.diffusion.util.torch_util import load_state_dict

from enfugue.diffusion.animate.diff.sparse_controlnet import SparseControlNetModel # type: ignore[attr-defined]
from enfugue.diffusion.animate.diff.unet import UNet3DConditionModel as AnimateDiffUNet # type: ignore[attr-defined]
from enfugue.diffusion.animate.diffxl.unet import UNet3DConditionModel as AnimateDiffXLUNet # type: ignore[attr-defined]
from enfugue.diffusion.animate.hotshot.unet import UNet3DConditionModel as HotshotUNet # type: ignore[attr-defined]
@@ -475,14 +474,14 @@ def load_diff_state_dict(
raise ValueError(f"Unknown AnimateDiff version {animate_diff_mm_version}")

motion_cache_dir = cache_dir
if not os.path.exists(os.path.join(cache_dir, os.path.basename(motion_module))):
if not os.path.exists(os.path.join(cache_dir, os.path.basename(motion_module))): # type: ignore[arg-type]
if motion_dir is not None:
motion_cache_dir = motion_dir

if task_callback is not None and not os.path.exists(os.path.join(motion_cache_dir, os.path.basename(motion_module))):
if task_callback is not None and not os.path.exists(os.path.join(motion_cache_dir, os.path.basename(motion_module))): # type: ignore[arg-type]
task_callback(f"Downloading {motion_module}")

motion_module = check_download_to_dir(motion_module, motion_cache_dir)
motion_module = check_download_to_dir(motion_module, motion_cache_dir) # type: ignore[arg-type]

if isinstance(motion_module, dict):
logger.debug(f"Loading AnimateDiff motion module with truncate length '{position_encoding_truncate_length}' and scale length '{position_encoding_scale_length}'")
@@ -493,6 +492,7 @@ def load_diff_state_dict(

state_dict = load_state_dict(motion_module) # type: ignore[assignment]

state_dict.pop("animatediff_config", "")
if position_encoding_truncate_length is not None or position_encoding_scale_length is not None:
for key in state_dict:
if key.endswith(".pe"):
@@ -539,7 +539,7 @@ def load_diff_xl_state_dict(

logger.debug(f"Loading AnimateDiff motion module {motion_module} with truncate length '{position_encoding_truncate_length}' and scale length '{position_encoding_scale_length}'")
state_dict = load_state_dict(motion_module) # type: ignore[assignment,arg-type]

state_dict.pop("animatediff_config", "")
if position_encoding_truncate_length is not None or position_encoding_scale_length is not None:
for key in state_dict:
if key.endswith(".pe"):
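
The two new state_dict.pop("animatediff_config", "") calls above drop a metadata entry that some AnimateDiff motion-module checkpoints store alongside their weights; left in place, the non-tensor value would break the subsequent module load. A minimal sketch of the pattern (the checkpoint contents are illustrative):

import torch

# Hypothetical motion-module checkpoint: weights plus a metadata entry.
state_dict = {
    "down_blocks.0.motion_modules.0.pe": torch.zeros(1, 24, 320),
    "animatediff_config": {"motion_module_type": "Vanilla"},  # not a tensor
}

# Popping with a default is a no-op for checkpoints that lack the key,
# and leaves only tensors behind for load_state_dict(...).
state_dict.pop("animatediff_config", "")
assert all(isinstance(value, torch.Tensor) for value in state_dict.values())
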
@@ -617,61 +617,31 @@ def create_diff_xl_unet(
)
return model

def get_sparse_controlnet(
self,
controlnet: Literal["sparse-rgb", "sparse-scribble"],
cache_dir: str,
motion_dir: Optional[str]=None,
task_callback: Optional[Callable[[str], None]]=None,
) -> SparseControlNetModel:
@classmethod
def get_sparse_controlnet_config(cls, use_simplified_embedding: bool) -> Dict[str, Any]:
"""
Loads a sparse controlnet from the UNet
Gets configuration for the sparse controlnet.
"""
if controlnet == "sparse-rgb":
controlnet_path = CONTROLNET_SPARSE_RGB
elif controlnet == "sparse-scribble":
controlnet_path = CONTROLNET_SPARSE_SCRIBBLE
else:
raise ValueError(f"Unknown ControlNet {controlnet}")

use_simplified_embedding = controlnet == "sparse-rgb"
sparse_controlnet_config = {
"set_noisy_sample_input_to_zero": True,
"use_simplified_condition_embedding": use_simplified_embedding,
"conditioning_channels": 4 if use_simplified_embedding else 3,
"use_motion_module": True,
"motion_module_resolutions": [1,2,4,8],
"motion_module_mid_block": False,
"motion_module_type": "Vanilla",
"motion_module_kwargs": {
"num_attention_heads": 8,
"num_transformer_block": 1,
"attention_block_types": ["Temporal_Self"],
"temporal_position_encoding": True,
"temporal_position_encoding_max_len": 32,
"temporal_attention_dim_div": 1
config = EnfugueStableDiffusionPipeline.get_sparse_controlnet_config(use_simplified_embedding)

return {
**config,
**{
"use_motion_module": True,
"motion_module_resolutions": [1,2,4,8],
"motion_module_mid_block": False,
"motion_module_type": "Vanilla",
"motion_module_kwargs": {
"num_attention_heads": 8,
"num_transformer_block": 1,
"attention_block_types": ["Temporal_Self"],
"temporal_position_encoding": True,
"temporal_position_encoding_max_len": 32,
"temporal_attention_dim_div": 1
}
}
}

# Prepare UNet
self.unet.config.num_attention_heads = 8
self.unet.config.projection_class_embeddings_input_dim = None
# Create model
controlnet_model = SparseControlNetModel.from_unet(
self.unet,
controlnet_additional_kwargs=sparse_controlnet_config
)

if task_callback is not None and not os.path.exists(os.path.join(cache_dir, os.path.basename(controlnet_path))):
task_callback(f"Downloading {controlnet_path}")

controlnet_module = check_download_to_dir(controlnet_path, cache_dir)
controlnet_state_dict = load_state_dict(controlnet_module)
if "controlnet" in controlnet_state_dict:
controlnet_state_dict = controlnet_state_dict["controlnet"]
controlnet_model.load_state_dict(controlnet_state_dict)
return controlnet_model

def load_motion_module_weights(
self,
cache_dir: str,
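
The refactor above replaces the pipeline-local get_sparse_controlnet method with a get_sparse_controlnet_config classmethod: the animation pipeline now overlays its motion-module settings on the base pipeline's config, with model construction and weight loading presumably handled by the base class elsewhere in this commit (not shown in the rendered hunks). A minimal sketch of that override-and-merge pattern, using an abbreviated subset of the diff's config keys (class names are illustrative):

from typing import Any, Dict

class BasePipeline:
    @classmethod
    def get_sparse_controlnet_config(cls, use_simplified_embedding: bool) -> Dict[str, Any]:
        # Settings shared by all pipelines (subset of the keys in the diff).
        return {
            "set_noisy_sample_input_to_zero": True,
            "use_simplified_condition_embedding": use_simplified_embedding,
            "conditioning_channels": 4 if use_simplified_embedding else 3,
        }

class AnimationPipeline(BasePipeline):
    @classmethod
    def get_sparse_controlnet_config(cls, use_simplified_embedding: bool) -> Dict[str, Any]:
        # Later keys win in a dict merge, so animation settings extend the base.
        config = super().get_sparse_controlnet_config(use_simplified_embedding)
        return {**config, "use_motion_module": True, "motion_module_type": "Vanilla"}

assert AnimationPipeline.get_sparse_controlnet_config(True)["conditioning_channels"] == 4
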
4 changes: 4 additions & 0 deletions src/python/enfugue/diffusion/constants.py
@@ -72,6 +72,7 @@
"MOTION_LORA_LITERAL",
"LCM_LORA_DEFAULT",
"LCM_LORA_XL",
"SPARSE_CONTROLNET_ADAPTER_LORA",
"SCHEDULER_LITERAL",
"DEVICE_LITERAL",
"PIPELINE_SWITCH_MODE_LITERAL",
@@ -348,6 +349,7 @@
MOTION_LORA_ZOOM_OUT = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_ZoomOut.ckpt"
LCM_LORA_DEFAULT = "https://huggingface.co/latent-consistency/lcm-lora-sdv1-5/resolve/main/pytorch_lora_weights.safetensors?filename=lcm-lora-sdv1-5.safetensors"
LCM_LORA_XL = "https://huggingface.co/latent-consistency/lcm-lora-sdxl/resolve/main/pytorch_lora_weights.safetensors?filename=lcm-lora-sdxl.safetensors"
SPARSE_CONTROLNET_ADAPTER_LORA = "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt"

MultiModelType = Union[str, List[str]]
WeightedMultiModelType = Union[
@@ -362,6 +364,8 @@ class ImageDict(TypedDict):
image: Union[str, Image, List[Image]]
skip_frames: NotRequired[Optional[int]]
divide_frames: NotRequired[Optional[int]]
start_frame: NotRequired[Optional[int]]
end_frame: NotRequired[Optional[int]]
fit: NotRequired[Optional[IMAGE_FIT_LITERAL]]
anchor: NotRequired[Optional[IMAGE_ANCHOR_LITERAL]]
invert: NotRequired[bool]
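
With the new start_frame and end_frame keys, a layer's image can be scheduled onto a sub-range of the animation timeline rather than applying to every frame. A hedged example of the extended ImageDict (the path and frame values are illustrative):

from enfugue.diffusion.constants import ImageDict

layer: ImageDict = {
    "image": "/path/to/keyframe.png",
    "start_frame": 5,     # first frame of this layer's influence
    "end_frame": 20,      # last frame of this layer's influence
    "skip_frames": None,
    "divide_frames": None,
}
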