From 294161afe678ab5565c561f863a689706a85c1d8 Mon Sep 17 00:00:00 2001 From: gameltb Date: Wed, 25 Oct 2023 20:40:34 +0800 Subject: [PATCH] update node --- README.md | 10 +- __init__.py | 7 - examples/stablesr_w_color_fix.json | 729 ++++++++++++++--------------- modules/stablesr.py | 301 ------------ modules/struct_cond.py | 6 +- nodes.py | 151 ++++-- 6 files changed, 460 insertions(+), 744 deletions(-) delete mode 100644 modules/stablesr.py diff --git a/README.md b/README.md index 20db030..19bb7ea 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,8 @@ -# pre-comfyui-stablsr -This is a development respository for debugging migration of StableSR to Comfyui +# comfyui-stablsr -There is a key bug the unet hook into Comfyui. It manifests itself as a error: mat1 and mat2 must have the same dtype. Currently I do not know how to solve configuring the diffusion model to resolve this issue. I have posted this code in hopes of finding some help from a diffusion expert to resolve it. +Put the StableSR webui_786v_139.ckpt model into Comyfui/models/stablesr/ +Put the StableSR stablesr_768v_000139.ckpt model into Comyfui/models/checkpoints/ -Put the StableSR webui_786v_139.ckpt model into Comyfui/models/stablesr/ - -Download the ckpt from HuggingFace https://huggingface.co/Iceclear/StableSR/blob/main/webui_768v_139.ckpt +Download the ckpt from HuggingFace https://huggingface.co/Iceclear/StableSR/ There is a setup json in /examples/ to load the workflow into Comfyui diff --git a/__init__.py b/__init__.py index 10701d5..3a1591d 100644 --- a/__init__.py +++ b/__init__.py @@ -4,13 +4,6 @@ @nickname: StableSR @description: This module enables StableSR in Comgfyui. Ported work of sd-webui-stablesr. Original work for Auotmaatic1111 version of this module and StableSR credit to LIightChaser and Jianyi Wang. """ -import folder_paths -import os -import sys - -modules_path = os.path.join(os.path.dirname(__file__), "modules") - -sys.path.append(modules_path) from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS diff --git a/examples/stablesr_w_color_fix.json b/examples/stablesr_w_color_fix.json index b67516b..bc06ccf 100644 --- a/examples/stablesr_w_color_fix.json +++ b/examples/stablesr_w_color_fix.json @@ -1,28 +1,26 @@ { - "last_node_id": 28, - "last_link_id": 54, + "last_node_id": 35, + "last_link_id": 74, "nodes": [ { - "id": 7, + "id": 14, "type": "CLIPTextEncode", "pos": [ - 454, - 470 + 57, + 505 ], "size": { - "0": 425.27801513671875, - "1": 180.6060791015625 + "0": 400, + "1": 200 }, - "flags": { - "collapsed": false - }, - "order": 4, + "flags": {}, + "order": 3, "mode": 0, "inputs": [ { "name": "clip", "type": "CLIP", - "link": 5 + "link": 19 } ], "outputs": [ @@ -30,79 +28,38 @@ "name": "CONDITIONING", "type": "CONDITIONING", "links": [ - 25 - ], - "slot_index": 0 - } - ], - "title": "CLIP Text Encode (Negative)", - "properties": { - "Node name for S&R": "CLIPTextEncode" - }, - "widgets_values": [ - "text, watermark" - ], - "color": "#222", - "bgcolor": "#000" - }, - { - "id": 18, - "type": "LoadImage", - "pos": [ - 989, - 183 - ], - "size": { - "0": 321, - "1": 337 - }, - "flags": {}, - "order": 0, - "mode": 0, - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 50 + 65 ], "shape": 3, "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": null, - "shape": 3 } ], "properties": { - "Node name for S&R": "LoadImage" + "Node name for S&R": "CLIPTextEncode" }, "widgets_values": [ - "P1050788_Color_1_512.png", - "image" + "(masterpiece), (best quality), (realistic),(very clear)" ] }, { - "id": 6, + "id": 15, "type": "CLIPTextEncode", "pos": [ - 453, - 256 + 56, + 777 ], "size": { - "0": 422.84503173828125, - "1": 164.31304931640625 + "0": 400, + "1": 200 }, "flags": {}, - "order": 3, + "order": 4, "mode": 0, "inputs": [ { "name": "clip", "type": "CLIP", - "link": 3 + "link": 22 } ], "outputs": [ @@ -110,88 +67,82 @@ "name": "CONDITIONING", "type": "CONDITIONING", "links": [ - 24 + 66 ], + "shape": 3, "slot_index": 0 } ], - "title": "CLIP Text Encode (Positive)", "properties": { "Node name for S&R": "CLIPTextEncode" }, "widgets_values": [ - "best quality" + "3d, cartoon, anime, sketches, (worst quality), (low quality)" ] }, { - "id": 4, - "type": "CheckpointLoaderSimple", + "id": 31, + "type": "ApplyStableSRUpscaler", "pos": [ - 32, - 284 + 550, + 236 ], "size": { "0": 315, - "1": 98 + "1": 78 }, "flags": {}, - "order": 1, + "order": 6, "mode": 0, - "outputs": [ + "inputs": [ { - "name": "MODEL", + "name": "model", "type": "MODEL", - "links": [ - 23 - ], - "slot_index": 0 + "link": 63 }, { - "name": "CLIP", - "type": "CLIP", - "links": [ - 3, - 5, - 26 - ], - "slot_index": 1 - }, + "name": "latent_image", + "type": "LATENT", + "link": 67 + } + ], + "outputs": [ { - "name": "VAE", - "type": "VAE", + "name": "MODEL", + "type": "MODEL", "links": [ - 27 + 64 ], - "slot_index": 2 + "shape": 3, + "slot_index": 0 } ], "properties": { - "Node name for S&R": "CheckpointLoaderSimple" + "Node name for S&R": "ApplyStableSRUpscaler" }, "widgets_values": [ - "v2-1_768-ema-pruned.ckpt" + "webui_768v_139.ckpt" ] }, { - "id": 24, + "id": 26, "type": "PreviewImage", "pos": [ - 2067, - 260 + 1247, + 498 ], "size": { - "0": 326, - "1": 335 + "0": 426.760009765625, + "1": 541.3356323242188 }, "flags": {}, - "order": 12, + "order": 10, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", - "link": 38, - "slot_index": 0 + "link": 41 } ], "properties": { @@ -200,180 +151,203 @@ }, { "id": 13, - "type": "PreviewImage", + "type": "VAEEncode", "pos": [ - 1703, - 885 + 593, + 110 ], "size": { - "0": 332, - "1": 337 + "0": 210, + "1": 46 }, "flags": {}, - "order": 11, + "order": 5, "mode": 0, "inputs": [ { - "name": "images", + "name": "pixels", "type": "IMAGE", - "link": 16 + "link": 56 + }, + { + "name": "vae", + "type": "VAE", + "link": 15 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 67, + 68, + 70 + ], + "shape": 3, + "slot_index": 0 } ], "properties": { - "Node name for S&R": "PreviewImage" + "Node name for S&R": "VAEEncode" } }, { - "id": 8, - "type": "VAEDecode", + "id": 4, + "type": "CheckpointLoaderSimple", "pos": [ - 1703, - 792 + -406, + 490 ], "size": { - "0": 210, - "1": 46 - }, - "flags": { - "collapsed": false + "0": 315, + "1": 98 }, - "order": 8, + "flags": {}, + "order": 1, "mode": 0, - "inputs": [ + "outputs": [ { - "name": "samples", - "type": "LATENT", - "link": 30, + "name": "MODEL", + "type": "MODEL", + "links": [ + 63 + ], "slot_index": 0 }, { - "name": "vae", - "type": "VAE", - "link": 29, + "name": "CLIP", + "type": "CLIP", + "links": [ + 19, + 22 + ], "slot_index": 1 - } - ], - "outputs": [ + }, { - "name": "IMAGE", - "type": "IMAGE", + "name": "VAE", + "type": "VAE", "links": [ - 16 + 8, + 15, + 71 ], - "slot_index": 0 + "slot_index": 2 } ], "properties": { - "Node name for S&R": "VAEDecode" - } + "Node name for S&R": "CheckpointLoaderSimple" + }, + "widgets_values": [ + "stablesr_768v_000139.ckpt" + ] }, { - "id": 20, - "type": "PreviewImage", + "id": 33, + "type": "VAEDecode", "pos": [ - 1703, - 259 + 1467, + 79 ], "size": { - "0": 326, - "1": 335 + "0": 210, + "1": 46 }, "flags": {}, - "order": 9, + "order": 7, "mode": 0, "inputs": [ { - "name": "images", + "name": "samples", + "type": "LATENT", + "link": 70 + }, + { + "name": "vae", + "type": "VAE", + "link": 71 + } + ], + "outputs": [ + { + "name": "IMAGE", "type": "IMAGE", - "link": 53 + "links": [ + 72 + ], + "shape": 3, + "slot_index": 0 } ], "properties": { - "Node name for S&R": "PreviewImage" + "Node name for S&R": "VAEDecode" } }, { - "id": 17, - "type": "ToBasicPipe", + "id": 8, + "type": "VAEDecode", "pos": [ - 541, - 95 + 1468, + 181 ], "size": { - "0": 241.79998779296875, - "1": 106 + "0": 210, + "1": 46 }, "flags": {}, - "order": 5, + "order": 9, "mode": 0, "inputs": [ { - "name": "model", - "type": "MODEL", - "link": 23 - }, - { - "name": "clip", - "type": "CLIP", - "link": 26 + "name": "samples", + "type": "LATENT", + "link": 60 }, { "name": "vae", "type": "VAE", - "link": 27 - }, - { - "name": "positive", - "type": "CONDITIONING", - "link": 24 - }, - { - "name": "negative", - "type": "CONDITIONING", - "link": 25 + "link": 8 } ], "outputs": [ { - "name": "basic_pipe", - "type": "BASIC_PIPE", + "name": "IMAGE", + "type": "IMAGE", "links": [ - 31, - 54 + 41, + 73 ], - "shape": 3, "slot_index": 0 } ], "properties": { - "Node name for S&R": "ToBasicPipe" + "Node name for S&R": "VAEDecode" } }, { - "id": 25, - "type": "ColorFix", + "id": 34, + "type": "StableSRColorFix", "pos": [ - 2072, - 133 + 1732, + 118 ], "size": { "0": 315, "1": 78 }, "flags": {}, - "order": 10, + "order": 11, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", - "link": 51 + "link": 73 }, { "name": "color_map_image", "type": "IMAGE", - "link": 52, - "slot_index": 1 + "link": 72 } ], "outputs": [ @@ -381,347 +355,340 @@ "name": "IMAGE", "type": "IMAGE", "links": [ - 38 + 74 ], "shape": 3, "slot_index": 0 } ], "properties": { - "Node name for S&R": "ColorFix" + "Node name for S&R": "StableSRColorFix" }, "widgets_values": [ "Wavelet" ] }, { - "id": 5, - "type": "EmptyLatentImage", + "id": 30, + "type": "KSampler", "pos": [ - 904, - 860 + 1012, + 124 ], "size": { "0": 315, - "1": 106 + "1": 262 }, "flags": {}, - "order": 2, + "order": 8, "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 64 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 65 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 66 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 68, + "slot_index": 3 + } + ], "outputs": [ { "name": "LATENT", "type": "LATENT", - "links": [], + "links": [ + 60 + ], + "shape": 3, "slot_index": 0 } ], "properties": { - "Node name for S&R": "EmptyLatentImage" + "Node name for S&R": "KSampler" }, "widgets_values": [ - 768, - 768, + 175840193994180, + "randomize", + 20, + 8, + "euler", + "normal", 1 ] }, { - "id": 22, - "type": "ImpactKSamplerBasicPipe", + "id": 29, + "type": "ImageScaleBy", "pos": [ - 1348, - 795 + 197, + 47 ], "size": { "0": 315, - "1": 242 + "1": 82 }, "flags": {}, - "order": 6, + "order": 2, "mode": 0, "inputs": [ { - "name": "basic_pipe", - "type": "BASIC_PIPE", - "link": 31 - }, - { - "name": "latent_image", - "type": "LATENT", - "link": null + "name": "image", + "type": "IMAGE", + "link": 54 } ], "outputs": [ { - "name": "BASIC_PIPE", - "type": "BASIC_PIPE", - "links": null, - "shape": 3 - }, - { - "name": "LATENT", - "type": "LATENT", - "links": [ - 30 - ], - "shape": 3 - }, - { - "name": "VAE", - "type": "VAE", + "name": "IMAGE", + "type": "IMAGE", "links": [ - 29 + 56 ], - "shape": 3 + "shape": 3, + "slot_index": 0 } ], "properties": { - "Node name for S&R": "ImpactKSamplerBasicPipe" + "Node name for S&R": "ImageScaleBy" }, "widgets_values": [ - 4, - "fixed", - 20, - 8, - "euler", - "normal", - 1 + "lanczos", + 1.9999993896484363 ] }, { - "id": 28, - "type": "StableSRUpscalerPipe", + "id": 35, + "type": "PreviewImage", "pos": [ - 1333, - 184 + 1740, + 502 + ], + "size": [ + 423.54519164947305, + 530.33793439514 ], - "size": { - "0": 342.5999755859375, - "1": 338 - }, "flags": {}, - "order": 7, + "order": 12, "mode": 0, "inputs": [ { - "name": "image", + "name": "images", "type": "IMAGE", - "link": 50 - }, - { - "name": "basic_pipe", - "type": "BASIC_PIPE", - "link": 54, - "slot_index": 1 - }, - { - "name": "pk_hook_opt", - "type": "PK_HOOK", - "link": null + "link": 74 } ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 12, + "type": "LoadImage", + "pos": [ + -382, + -80 + ], + "size": { + "0": 453.4217529296875, + "1": 469.52587890625 + }, + "flags": {}, + "order": 0, + "mode": 0, "outputs": [ { - "name": "stablesr_image", + "name": "IMAGE", "type": "IMAGE", "links": [ - 51, - 53 + 54 ], "shape": 3, "slot_index": 0 }, { - "name": "color_map_image", - "type": "IMAGE", - "links": [ - 52 - ], - "shape": 3, - "slot_index": 1 + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 } ], "properties": { - "Node name for S&R": "StableSRUpscalerPipe" + "Node name for S&R": "LoadImage" }, "widgets_values": [ - 1.5, - 4, - "fixed", - 20, - 8, - "euler", - "normal", - 1, - true, - "webui_768v_139.ckpt" - ], - "color": "#323", - "bgcolor": "#535" + "1111.jpg", + "image" + ] } ], "links": [ [ - 3, + 8, 4, + 2, + 8, 1, - 6, + "VAE" + ], + [ + 15, + 4, + 2, + 13, + 1, + "VAE" + ], + [ + 19, + 4, + 1, + 14, 0, "CLIP" ], [ - 5, + 22, 4, 1, - 7, + 15, 0, "CLIP" ], [ - 16, + 41, 8, 0, - 13, + 26, 0, "IMAGE" ], [ - 23, - 4, + 54, + 12, 0, - 17, + 29, 0, - "MODEL" + "IMAGE" ], [ - 24, - 6, + 56, + 29, 0, - 17, - 3, - "CONDITIONING" - ], - [ - 25, - 7, + 13, 0, - 17, - 4, - "CONDITIONING" + "IMAGE" ], [ - 26, - 4, - 1, - 17, - 1, - "CLIP" + 60, + 30, + 0, + 8, + 0, + "LATENT" ], [ - 27, + 63, 4, - 2, - 17, - 2, - "VAE" + 0, + 31, + 0, + "MODEL" ], [ - 29, - 22, - 2, - 8, - 1, - "VAE" + 64, + 31, + 0, + 30, + 0, + "MODEL" ], [ + 65, + 14, + 0, 30, - 22, 1, - 8, - 0, - "LATENT" + "CONDITIONING" ], [ - 31, - 17, - 0, - 22, + 66, + 15, 0, - "BASIC_PIPE" + 30, + 2, + "CONDITIONING" ], [ - 38, - 25, - 0, - 24, + 67, + 13, 0, - "IMAGE" + 31, + 1, + "LATENT" ], [ - 50, - 18, - 0, - 28, + 68, + 13, 0, - "IMAGE" + 30, + 3, + "LATENT" ], [ - 51, - 28, + 70, + 13, 0, - 25, + 33, 0, - "IMAGE" + "LATENT" ], [ - 52, - 28, + 71, + 4, + 2, + 33, 1, - 25, + "VAE" + ], + [ + 72, + 33, + 0, + 34, 1, "IMAGE" ], [ - 53, - 28, + 73, + 8, 0, - 20, + 34, 0, "IMAGE" ], [ - 54, - 17, + 74, + 34, 0, - 28, - 1, - "BASIC_PIPE" + 35, + 0, + "IMAGE" ] ], - "groups": [ - { - "title": "StableSR Proof", - "bounding": [ - 1, - 1, - 2434, - 679 - ], - "color": "#693d6c", - "locked": false - }, - { - "title": "Standard Comfyui Txt2Img Proof", - "bounding": [ - 858, - 704, - 1216, - 558 - ], - "color": "#3f789e", - "locked": false - } - ], + "groups": [], "config": {}, "extra": {}, "version": 0.4 diff --git a/modules/stablesr.py b/modules/stablesr.py deleted file mode 100644 index c894055..0000000 --- a/modules/stablesr.py +++ /dev/null @@ -1,301 +0,0 @@ -''' -# -------------------------------------------------------------------------------- -# -# StableSR for Comfyui -# Migrationed from sd-webui-stablesr for Automatic1111 WebUI -# -# Introducing state-of-the super-resolution method: StableSR! -# Techniques is originally proposed by Jianyi Wang et, al. -# -# Project Page: https://iceclear.github.io/projects/stablesr/ -# Official Repo: https://github.com/IceClear/StableSR -# Paper: https://arxiv.org/abs/2305.07015 -# -# @original author: Jianyi Wang et, al. -# @migration: LI YI, Will James -# @organization: Nanyang Technological University - Singapore -# @date: 2023-09-20 -# @license: -# S-Lab License 1.0 (see LICENSE file) -# CC BY-NC-SA 4.0 (required by NVIDIA SPADE module) -# -# @disclaimer: -# All code in this extension is for research purpose only. -# The commercial use of the code & checkpoint is strictly prohibited. -# -# -------------------------------------------------------------------------------- -# -# IMPORTANT NOTICE FOR OUTCOME IMAGES: -# - Please be aware that the CC BY-NC-SA 4.0 license in SPADE module -# also prohibits the commercial use of outcome images. -# - Jianyi Wang may change the SPADE module to a commercial-friendly one. -# If you want to use the outcome images for commercial purposes, please -# contact Jianyi Wang for more information. -# -# Please give LI YI's repo and also Jianyi's repo a star if you like this project! -# -# -------------------------------------------------------------------------------- -''' - -import os -import torch -import numpy as np -import PIL.Image as Image - -import folder_paths -import nodes -import comfy.utils -import comfy.model_management - -# TODO might delete this in clean up -from comfy.model_patcher import ModelPatcher - -from torch import Tensor -from ldm.modules.diffusionmodules.openaimodel import UNetModel - -from spade import SPADELayers -from struct_cond import EncoderUNetModelWT, build_unetwt -from util import pil2tensor, tensor2pil - -FORWARD_CACHE_NAME = 'org_forward_stablesr' - -class StableSR: - ''' - Initializes a StableSR model. - - Args: - path: The path to the StableSR checkpoint file. - dtype: The data type of the model. If not specified, the default data type will be used. - device: The device to run the model on. If not specified, the default device will be used. - ''' - - def __init__(self, path, dtype, device): - print(f"[StbaleSR] in StableSR init - dtype: {dtype}, device: {device}") - - state_dict = comfy.utils.load_torch_file(path) - - self.struct_cond_model: EncoderUNetModelWT = build_unetwt() - self.spade_layers: SPADELayers = SPADELayers() - self.struct_cond_model.load_from_dict(state_dict) - self.spade_layers.load_from_dict(state_dict) - del state_dict - - self.struct_cond_model.apply(lambda x: x.to(dtype=dtype, device=device)) - self.spade_layers.apply(lambda x: x.to(dtype=dtype, device=device)) - self.latent_image: Tensor = None - self.set_image_hooks = {} - self.struct_cond: Tensor = None - - def set_latent_image(self, latent_image): - self.latent_image = latent_image["samples"] - for hook in self.set_image_hooks.values(): - hook(latent_image) - - ''' - # attempt to use Comfyui ModelPatcher.set_model_unet_function_wrapper() - # hasn't been successful due to timestep complexity - def sr_unet_forward(self, model_function, args_dict): - try: - # explode packed args - input_x = args_dict.get("input") - timestep_ = args_dict.get("timestep") - c = args_dict.get("c") - cond_or_uncond = args_dict.get("cond_or_uncond") - - # set latent image to device - device = comfy.model_management.get_torch_device() - latent_image = self.latent_image["samples"] - latent_image = latent_image.to(device) - - timestep_ = timestep_.to(torch.float32) - - # Ensure the device of all modules layers is the same as the unet - # This will fix the issue when user use --medvram or --lowvram - self.spade_layers.to(device) - self.struct_cond_model.to(device) - - #timestep_ = timestep_.to(device) - self.struct_cond = None # mitigate vram peak - self.struct_cond = self.struct_cond_model(latent_image, timestep_[:latent_image.shape[0]]) - - # Call the model_function with the provided arguments - result = model_function(input_x, timestep_, **c) - - # Return the result - return result - except Exception as e: - print(f"[StbaleSR] Error in sr_unet_forward: {str(e)}") - raise e - - def sr_hook(self, sd_model) - # try set forward handler using ModelPatcher.set_model_unet_function_wrapper() - #sd_model.set_model_unet_function_wrapper(self.sr_unet_forward) - - ''' - - def hook(self, unet: UNetModel): - # hook unet to set the struct_cond - if not hasattr(unet, FORWARD_CACHE_NAME): - setattr(unet, FORWARD_CACHE_NAME, unet.forward) - - print(f"[StbaleSR] in StableSR hook - unet dtype: {unet.dtype}") - - def unet_forward(x, timesteps=None, context=None, y=None,**kwargs): - # debug print the dtypes going in - print(f'[StableSR] in unet_forward()') - print(f"[StbaleSR] in StableSR hook unet_forward - dtype timesteps: {timesteps.dtype}") - print(f"[StbaleSR] in StableSR hook unet_forward - dtype latent_image: {self.latent_image.dtype}") - - self.latent_image = self.latent_image.to(x.device) - - # Ensure the device of all modules layers is the same as the unet - # This will fix the issue when user use --medvram or --lowvram - self.spade_layers.to(x.device) - self.struct_cond_model.to(x.device) - timesteps = timesteps.to(x.device) - self.struct_cond = None # mitigate vram peak - self.struct_cond = self.struct_cond_model(self.latent_image, timesteps[:self.latent_image.shape[0]]) - return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs) - - unet.forward = unet_forward - - # set the spade_layers on unet - self.spade_layers.hook(unet, lambda: self.struct_cond) - - ''' - # TODO migrate unhook - def unhook(self, unet: UNetModel): - # clean up cache - self.latent_image = None - self.struct_cond = None - self.set_image_hooks = {} - # unhook unet forward - if hasattr(unet, FORWARD_CACHE_NAME): - unet.forward = getattr(unet, FORWARD_CACHE_NAME) - delattr(unet, FORWARD_CACHE_NAME) - - # unhook spade layers - self.spade_layers.unhook() - ''' - -class StableSRScript(): - params = None - - def __init__(self, upscale_factor, seed, steps, cfg, sampler_name, scheduler, denoise, pure_noise, basic_pipe, model, - hook_opt=None) -> None: - self.params = upscale_factor, seed, steps, cfg, sampler_name, scheduler, denoise, pure_noise, basic_pipe, model - self.hook = hook_opt - self.stablesr_model_path = None - self.get_stablesr_model_path(model) - self.stablesr_module: StableSR = None - self.init_latent = None - - def get_stablesr_model_path(self, model): - if self.stablesr_model_path is None: - file_path = folder_paths.get_full_path("stablesr", model) - if os.path.isfile(file_path): - # save tha absolute path - self.stablesr_model_path = file_path - else: - print(f'[StableSR] Invalid StableSR model reference') - return self.stablesr_model_path - - def upscale_tensor_as_pil(self, image, scale_factor, save_temp_prefix=None): - # Convert the PyTorch tensor to a PIL Image object. - pil_image = tensor2pil(image) - - w = int(pil_image.width * scale_factor) - h = int(pil_image.height * scale_factor) - - # if the target width is not dividable by 8, then round it up - if w % 8 != 0: - w = w + 8 - w % 8 - # if the target height is not dividable by 8, then round it up - if h % 8 != 0: - h = h + 8 - h % 8 - - # Resize the PIL Image object using Lanczos interpolation. - resized_image = pil_image.resize((w, h), Image.LANCZOS) - - resized_tensor = pil2tensor(resized_image) - - return resized_tensor - - def to_latent_image_with_vae(self, pixels, vae): - x = pixels.shape[1] - y = pixels.shape[2] - if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:, :x, :y, :] - t = vae.encode(pixels[:, :, :, :3]) - return {"samples": t} - - # sampler wrapper - def sample(self, image) -> Image: - upscale_factor, seed, steps, cfg, sampler_name, scheduler, denoise, pure_noise, basic_pipe, model = self.params - sd_model, clip, vae, positive, negative = basic_pipe - - # initial upscale on pixels to get target size and color map - upscaled_image = self.upscale_tensor_as_pil(image, upscale_factor) - - # get the initial upscaled latent image - self.init_latent = self.to_latent_image_with_vae(upscaled_image, vae) - - # get the device - device = comfy.model_management.get_torch_device() - - # get dtype from sd model - dtype = sd_model.model_dtype() - - # load StableSR - if self.stablesr_module is None: - self.stablesr_module = StableSR(self.stablesr_model_path, dtype, device) - - # set latent image on stablesr - self.stablesr_module.set_latent_image(self.init_latent) - - # get the stablediffusion unet referrence from the nested BaseModel instance - unet: UNetModel = sd_model.model.diffusion_model - - # hook unet forwards - self.stablesr_module.hook(unet) - - # get an empty latent for ksampler, it will generate a random tensor from the seed - empty_latent = {"samples": torch.zeros(self.init_latent["samples"].shape)} - - # run ksampler - print('[StableSR] Target image size: {}x{}'.format(upscaled_image.shape[2], upscaled_image.shape[1])) - refined_latent = \ - nodes.common_ksampler(sd_model, seed, steps, cfg, sampler_name, scheduler, positive, negative, empty_latent, denoise)[0] - - ''' - # TODO migrate variable noise - if pure_noise: - # NOTE: use txt2img instead of img2img sampling - samples = sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) - else: - if p.initial_noise_multiplier != 1.0: - p.extra_generation_params["Noise multiplier"] =p.initial_noise_multiplier - x *= p.initial_noise_multiplier - samples = sampler.sample_img2img(p, p.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) - - # TODO migrate mask - if p.mask is not None: - print("[StableSR] trace - in sample_custom() - p.mask is applied") - - samples = samples * p.nmask + p.init_latent * p.mask - del x - devices.torch_gc() - ''' - - # decode latent - refined_image = vae.decode(refined_latent['samples']) # final sr image - no color correction - color_map_image = upscaled_image # pretty name - return refined_image, color_map_image - - ''' - # TODO migrate unhook - self.stablesr_model.unhook(unet) - # in --medvram and --lowvram mode, we send the model back to the initial device - self.stablesr_model.struct_cond_model.to(device=first_param.device) - self.stablesr_model.spade_layers.to(device=first_param.device) - ''' \ No newline at end of file diff --git a/modules/struct_cond.py b/modules/struct_cond.py index e1c9c1f..0237279 100644 --- a/modules/struct_cond.py +++ b/modules/struct_cond.py @@ -18,7 +18,7 @@ ) # NOTE only change in file for Comyfui -from attn import sr_get_attn_func as get_attn_func +from .attn import sr_get_attn_func as get_attn_func attn_func = None @@ -347,8 +347,8 @@ def build_unetwt() -> EncoderUNetModelWT: model = build_unetwt() model.load_from_dict(state_dict) model = model.cuda() - test_latent = torch.randn(1, 4, 64, 64).half().cuda() + test_latent = torch.zeros(1, 4, 64, 64).half().cuda() test_timesteps = torch.tensor([0]).half().cuda() with torch.no_grad(): test_result = model(test_latent, test_timesteps) - print(test_result.keys()) \ No newline at end of file + print(test_result) \ No newline at end of file diff --git a/nodes.py b/nodes.py index 3b1d445..ba26d5e 100644 --- a/nodes.py +++ b/nodes.py @@ -1,30 +1,31 @@ +from .modules.struct_cond import EncoderUNetModelWT, build_unetwt +from .modules.spade import SPADELayers +from .modules.util import pil2tensor, tensor2pil +from .modules.colorfix import adain_color_fix, wavelet_color_fix + import os -import comfy -import numpy as np -import PIL.Image as Image +import comfy.samplers +from torch import Tensor import torch import folder_paths model_path = folder_paths.models_dir folder_name = "stablesr" -folder_path = os.path.join(model_path, "stablesr") #set a default path for the common comfyui model path +folder_path = os.path.join(model_path, "stablesr") # set a default path for the common comfyui model path if folder_name in folder_paths.folder_names_and_paths: - folder_path = folder_paths.folder_names_and_paths[folder_name][0][0] #if a custom path was set in extra_model_paths.yaml then use it + folder_path = folder_paths.folder_names_and_paths[folder_name][0][0] # if a custom path was set in extra_model_paths.yaml then use it folder_paths.folder_names_and_paths["stablesr"] = ([folder_path], folder_paths.supported_pt_extensions) -import stablesr -from colorfix import adain_color_fix, wavelet_color_fix -from util import pil2tensor, tensor2pil -class ColorFix: +class StableSRColorFix: @classmethod def INPUT_TYPES(s): return {"required": { - "image": ("IMAGE", ), - "color_map_image": ("IMAGE", ), - "color_fix": (["Wavelet", "AdaIN",],), - }, - } + "image": ("IMAGE", ), + "color_map_image": ("IMAGE", ), + "color_fix": (["Wavelet", "AdaIN",],), + }, + } RETURN_TYPES = ("IMAGE",) FUNCTION = "fix_color" @@ -40,44 +41,102 @@ def fix_color(self, image, color_map_image, color_fix): except Exception as e: print(f'[StableSR] Error fix_color: {e}') - -class StableSRUpscalerPipe: + +class StableSR: + ''' + Initializes a StableSR model. + + Args: + path: The path to the StableSR checkpoint file. + dtype: The data type of the model. If not specified, the default data type will be used. + device: The device to run the model on. If not specified, the default device will be used. + ''' + + def __init__(self, stable_sr_model_path, dtype, device): + print(f"[StbaleSR] in StableSR init - dtype: {dtype}, device: {device}") + state_dict = comfy.utils.load_torch_file(stable_sr_model_path) + + self.struct_cond_model: EncoderUNetModelWT = build_unetwt() + self.spade_layers: SPADELayers = SPADELayers() + self.struct_cond_model.load_from_dict(state_dict) + self.spade_layers.load_from_dict(state_dict) + del state_dict + + self.struct_cond_model.apply(lambda x: x.to(dtype=dtype, device=device)) + self.spade_layers.apply(lambda x: x.to(dtype=dtype, device=device)) + self.latent_image: Tensor = None + self.set_image_hooks = {} + self.struct_cond: Tensor = None + + def set_latent_image(self, latent_image): + self.latent_image = latent_image["samples"] + for hook in self.set_image_hooks.values(): + hook(latent_image) + + def __call__(self, model_function, params): + # explode packed args + input_x = params.get("input") + timestep = params.get("timestep") + c = params.get("c") + + # set latent image to device + device = input_x.device + latent_image = self.latent_image.to(device) + + # Ensure the device of all modules layers is the same as the unet + # This will fix the issue when user use --medvram or --lowvram + self.spade_layers.to(device) + self.struct_cond_model.to(device) + + self.struct_cond = None # mitigate vram peak + self.struct_cond = self.struct_cond_model(latent_image, timestep[:latent_image.shape[0]]) + + self.spade_layers.hook(model_function.__self__.diffusion_model, lambda: self.struct_cond) + + # Call the model_function with the provided arguments + result = model_function(input_x, timestep, **c) + + self.spade_layers.unhook() + + # Return the result + return result + +class ApplyStableSRUpscaler: @classmethod def INPUT_TYPES(s): return {"required": { - "image": ("IMAGE", ), - "upscale_factor": ("FLOAT", {"default": 1.5, "min": 1, "max": 10000, "step": 0.1}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), - "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), - "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "pure_noise": ("BOOLEAN", {"label_on": "enabled", "label_off": "disabled"}), - "basic_pipe": ("BASIC_PIPE",), - "stablesr_model": (folder_paths.get_filename_list("stablesr"), ), - }, - "optional": { - "pk_hook_opt": ("PK_HOOK", ), - } - } - - RETURN_TYPES = ("IMAGE","IMAGE", ) - RETURN_NAMES = ("stablesr_image","color_map_image", ) - FUNCTION = "doit_pipe" + "model": ("MODEL", ), + "latent_image": ("LATENT", ), + "stablesr_model": (folder_paths.get_filename_list("stablesr"), ), + } + } + + RETURN_TYPES = ("MODEL",) + + FUNCTION = "apply_stable_sr_upscaler" CATEGORY = "image/upscaling" - def doit_pipe(self, image, upscale_factor, seed, steps, cfg, sampler_name, scheduler, denoise, pure_noise, basic_pipe, stablesr_model, pk_hook_opt=None): - upscaler = stablesr.StableSRScript(upscale_factor, seed, steps, cfg, sampler_name, scheduler, denoise, pure_noise, basic_pipe, stablesr_model, pk_hook_opt) - upscale_image, color_map_image = upscaler.sample(image) - return (upscale_image, color_map_image, ) - + def apply_stable_sr_upscaler(self, model, latent_image, stablesr_model): + latent_image = {"samples": latent_image["samples"] * 0.18215} + + stablesr_model_path = folder_paths.get_full_path("stablesr", stablesr_model) + if not os.path.isfile(stablesr_model_path): + raise Exception(f'[StableSR] Invalid StableSR model reference') + + upscaler = StableSR(stablesr_model_path,dtype=torch.float32,device="cpu") + upscaler.set_latent_image(latent_image) + + model_sr = model.clone() + model_sr.set_model_unet_function_wrapper(upscaler) + return (model_sr, ) + + NODE_CLASS_MAPPINGS = { - "ColorFix": ColorFix, - "StableSRUpscalerPipe": StableSRUpscalerPipe, + "StableSRColorFix": StableSRColorFix, + "ApplyStableSRUpscaler": ApplyStableSRUpscaler } NODE_DISPLAY_NAME_MAPPINGS = { - "ColorFix": "ColorFix", - "StableSRUpscalerPipe": "StableSRUpscaler (pipe)", -} \ No newline at end of file + "StableSRColorFix": "StableSRColorFix", + "ApplyStableSRUpscaler": "ApplyStableSRUpscaler" +}