Skip to content

Commit

Permalink
fix for new comfyui
Browse files Browse the repository at this point in the history
  • Loading branch information
gameltb committed Jan 4, 2024
1 parent 5cca381 commit b174afa
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 59 deletions.
7 changes: 5 additions & 2 deletions modules/spade.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
"""

import re

import torch
import torch.nn as nn

from ldm.modules.diffusionmodules.util import normalization, checkpoint
from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
from comfy.ldm.modules.diffusionmodules.util import checkpoint

from .util import normalization


class SPADE(nn.Module):
Expand Down
27 changes: 9 additions & 18 deletions modules/struct_cond.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
import math
import torch
import torch.nn as nn

from ldm.modules.diffusionmodules.openaimodel import (
TimestepEmbedSequential,
ResBlock,
Downsample,
)

from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
timestep_embedding,
checkpoint,
normalization,
zero_module,
)
from comfy.ldm.modules.diffusionmodules.openaimodel import (
Downsample, ResBlock, TimestepEmbedSequential)
from comfy.ldm.modules.diffusionmodules.util import (checkpoint,
timestep_embedding,
zero_module)

# NOTE only change in file for Comyfui
from .attn import sr_get_attn_func as get_attn_func
from .util import conv_nd, linear, normalization

attn_func = None

Expand Down Expand Up @@ -271,7 +262,7 @@ def forward(self, x, timesteps):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels).to(x.dtype))

result_list = []
results = {}
Expand Down Expand Up @@ -303,7 +294,7 @@ def load_from_dict(self, state_dict):
self.load_state_dict(filtered_dict)


def build_unetwt() -> EncoderUNetModelWT:
def build_unetwt(use_fp16=False) -> EncoderUNetModelWT:
"""
Build a model from a state dict.
:param state_dict: a dict of parameters.
Expand All @@ -323,7 +314,7 @@ def build_unetwt() -> EncoderUNetModelWT:
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
use_fp16=use_fp16,
num_heads=4,
num_head_channels=-1,
num_heads_upsample=-1,
Expand Down
28 changes: 26 additions & 2 deletions modules/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
import torch
import numpy as np
import PIL.Image as Image
import torch


def pil2tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def tensor2pil(image):
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
return Image.fromarray(
np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
)


def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return torch.nn.Conv1d(*args, **kwargs)
elif dims == 2:
return torch.nn.Conv2d(*args, **kwargs)
elif dims == 3:
return torch.nn.Conv3d(*args, **kwargs)


def linear(*args, **kwargs):
return torch.nn.Linear(*args, **kwargs)


def normalization(channels):
return torch.nn.GroupNorm(32, channels)
105 changes: 68 additions & 37 deletions nodes.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,65 @@
from .modules.struct_cond import EncoderUNetModelWT, build_unetwt
from .modules.spade import SPADELayers
from .modules.util import pil2tensor, tensor2pil
from .modules.colorfix import adain_color_fix, wavelet_color_fix

import os
from torch import Tensor

import torch
import comfy.sample
from torch import Tensor

import comfy.model_management
import comfy.sample
import folder_paths

from .modules.colorfix import adain_color_fix, wavelet_color_fix
from .modules.spade import SPADELayers
from .modules.struct_cond import EncoderUNetModelWT, build_unetwt
from .modules.util import pil2tensor, tensor2pil

model_path = folder_paths.models_dir
folder_name = "stablesr"
folder_path = os.path.join(model_path, "stablesr") # set a default path for the common comfyui model path
folder_path = os.path.join(
model_path, "stablesr"
) # set a default path for the common comfyui model path
if folder_name in folder_paths.folder_names_and_paths:
folder_path = folder_paths.folder_names_and_paths[folder_name][0][0] # if a custom path was set in extra_model_paths.yaml then use it
folder_paths.folder_names_and_paths["stablesr"] = ([folder_path], folder_paths.supported_pt_extensions)
folder_path = folder_paths.folder_names_and_paths[folder_name][0][
0
] # if a custom path was set in extra_model_paths.yaml then use it
folder_paths.folder_names_and_paths["stablesr"] = (
[folder_path],
folder_paths.supported_pt_extensions,
)


class StableSRColorFix:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE", ),
"color_map_image": ("IMAGE", ),
"color_fix": (["Wavelet", "AdaIN",],),
},
return {
"required": {
"image": ("IMAGE",),
"color_map_image": ("IMAGE",),
"color_fix": (
[
"Wavelet",
"AdaIN",
],
),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "fix_color"
CATEGORY = "image"

def fix_color(self, image, color_map_image, color_fix):
print(f'[StableSR] fix_color')
print(f"[StableSR] fix_color")
try:
color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix
result_image = color_fix_func(tensor2pil(image), tensor2pil(color_map_image))
color_fix_func = (
wavelet_color_fix if color_fix == "Wavelet" else adain_color_fix
)
result_image = color_fix_func(
tensor2pil(image), tensor2pil(color_map_image)
)
refined_image = pil2tensor(result_image)
return (refined_image, )
return (refined_image,)
except Exception as e:
print(f'[StableSR] Error fix_color: {e}')
print(f"[StableSR] Error fix_color: {e}")


original_sample = comfy.sample.sample
Expand All @@ -59,33 +79,36 @@ def hook_sample(*args, **kwargs):


class StableSR:
'''
"""
Initializes a StableSR model.
Args:
path: The path to the StableSR checkpoint file.
dtype: The data type of the model. If not specified, the default data type will be used.
device: The device to run the model on. If not specified, the default device will be used.
'''
"""

def __init__(self, stable_sr_model_path, dtype, device):
print(f"[StbaleSR] in StableSR init - dtype: {dtype}, device: {device}")
state_dict = comfy.utils.load_torch_file(stable_sr_model_path)

self.struct_cond_model: EncoderUNetModelWT = build_unetwt()
self.struct_cond_model: EncoderUNetModelWT = build_unetwt(
use_fp16=dtype == torch.float16
)
self.spade_layers: SPADELayers = SPADELayers()
self.struct_cond_model.load_from_dict(state_dict)
self.spade_layers.load_from_dict(state_dict)
del state_dict

self.dtype = dtype
self.struct_cond_model.apply(lambda x: x.to(dtype=dtype, device=device))
self.spade_layers.apply(lambda x: x.to(dtype=dtype, device=device))
self.latent_image: Tensor = None
self.set_image_hooks = {}
self.struct_cond: Tensor = None

self.auto_set_latent = False
self.last_t = 0.
self.last_t = 0.0

def set_latent_image(self, latent_image):
self.latent_image = latent_image
Expand All @@ -99,26 +122,32 @@ def __call__(self, model_function, params):
timestep = params.get("timestep")
c = params.get("c")

t = model_function.__self__.model_sampling.timestep(timestep)

if self.auto_set_latent:
tt = float(timestep[0])
tt = float(t[0])
if self.last_t <= tt:
latent_image = model_function.__self__.process_latent_in(SAMPLE_X)
self.set_latent_image(latent_image)
self.last_t = tt

# set latent image to device
device = input_x.device
latent_image = self.latent_image.to(device)
latent_image = self.latent_image.to(dtype=self.dtype, device=device)

# Ensure the device of all modules layers is the same as the unet
# This will fix the issue when user use --medvram or --lowvram
self.spade_layers.to(device)
self.struct_cond_model.to(device)

self.struct_cond = None # mitigate vram peak
self.struct_cond = self.struct_cond_model(latent_image, timestep[:latent_image.shape[0]])
self.struct_cond = self.struct_cond_model(
latent_image, t[: latent_image.shape[0]]
)

self.spade_layers.hook(model_function.__self__.diffusion_model, lambda: self.struct_cond)
self.spade_layers.hook(
model_function.__self__.diffusion_model, lambda: self.struct_cond
)

# Call the model_function with the provided arguments
result = model_function(input_x, timestep, **c)
Expand All @@ -140,11 +169,11 @@ class ApplyStableSRUpscaler:
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", ),
"stablesr_model": (folder_paths.get_filename_list("stablesr"), ),
"model": ("MODEL",),
"stablesr_model": (folder_paths.get_filename_list("stablesr"),),
},
"optional": {
"latent_image": ("LATENT", ),
"latent_image": ("LATENT",),
},
}

Expand All @@ -153,12 +182,14 @@ def INPUT_TYPES(s):
FUNCTION = "apply_stable_sr_upscaler"
CATEGORY = "image/upscaling"

def apply_stable_sr_upscaler(self, model, stablesr_model, latent_image=None):
def apply_stable_sr_upscaler(self, model, stablesr_model, latent_image=None):
stablesr_model_path = folder_paths.get_full_path("stablesr", stablesr_model)
if not os.path.isfile(stablesr_model_path):
raise Exception(f'[StableSR] Invalid StableSR model reference')
raise Exception(f"[StableSR] Invalid StableSR model reference")

upscaler = StableSR(stablesr_model_path, dtype=torch.float32, device="cpu")
upscaler = StableSR(
stablesr_model_path, dtype=comfy.model_management.unet_dtype(), device="cpu"
)
if latent_image != None:
latent_image = model.model.process_latent_in(latent_image["samples"])
upscaler.set_latent_image(latent_image)
Expand All @@ -167,15 +198,15 @@ def apply_stable_sr_upscaler(self, model, stablesr_model, latent_image=None):

model_sr = model.clone()
model_sr.set_model_unet_function_wrapper(upscaler)
return (model_sr, )
return (model_sr,)


NODE_CLASS_MAPPINGS = {
"StableSRColorFix": StableSRColorFix,
"ApplyStableSRUpscaler": ApplyStableSRUpscaler
"ApplyStableSRUpscaler": ApplyStableSRUpscaler,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"StableSRColorFix": "StableSRColorFix",
"ApplyStableSRUpscaler": "ApplyStableSRUpscaler"
"ApplyStableSRUpscaler": "ApplyStableSRUpscaler",
}

0 comments on commit b174afa

Please sign in to comment.