forked from WSJUSA/Comfyui-StableSR
-
Notifications
You must be signed in to change notification settings - Fork 10
/
nodes.py
206 lines (162 loc) · 6.51 KB
/
nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
import torch
from torch import Tensor
import comfy.model_management
import comfy.sample
import folder_paths
from .modules.colorfix import adain_color_fix, wavelet_color_fix
from .modules.spade import SPADELayers
from .modules.struct_cond import EncoderUNetModelWT, build_unetwt
from .modules.util import pil2tensor, tensor2pil
# Resolve the directory that holds StableSR checkpoints and register it with
# ComfyUI's folder registry under the "stablesr" key.
model_path = folder_paths.models_dir
folder_name = "stablesr"
folder_path = (
    # A path already registered for this key (e.g. a custom one set in
    # extra_model_paths.yaml) takes precedence; only its first entry is used.
    folder_paths.folder_names_and_paths[folder_name][0][0]
    if folder_name in folder_paths.folder_names_and_paths
    # Otherwise fall back to the default <models_dir>/stablesr location.
    else os.path.join(model_path, "stablesr")
)
# (Re)register the key with that single directory and the supported
# checkpoint file extensions.
folder_paths.folder_names_and_paths["stablesr"] = (
    [folder_path],
    folder_paths.supported_pt_extensions,
)
class StableSRColorFix:
    """ComfyUI node that runs a color-fix routine on an image.

    The node takes two images and the name of the routine to use, and
    returns a single image.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "fix_color"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input specification (all inputs are required)."""
        color_fix_modes = ["Wavelet", "AdaIN"]
        required = {
            "image": ("IMAGE",),
            "color_map_image": ("IMAGE",),
            "color_fix": (color_fix_modes,),
        }
        return {"required": required}

    def fix_color(self, image, color_map_image, color_fix):
        """Apply the selected color-fix routine and return the result.

        Args:
            image: Image tensor passed as the first argument of the routine.
            color_map_image: Image tensor passed as the second argument
                (presumably the color reference -- confirm in
                ``modules.colorfix``).
            color_fix: ``"Wavelet"`` selects ``wavelet_color_fix``; any other
                value selects ``adain_color_fix``.

        Returns:
            A one-element tuple holding the resulting image tensor.
        """
        if color_fix == "Wavelet":
            color_fix_func = wavelet_color_fix
        else:
            color_fix_func = adain_color_fix

        # The color-fix routines are fed PIL images, so convert on the way in
        # and back to a tensor on the way out.
        target_pil = tensor2pil(image)
        reference_pil = tensor2pil(color_map_image)
        fixed_pil = color_fix_func(target_pil, reference_pil)
        return (pil2tensor(fixed_pil),)
# Keep a handle on the real sampling entry point before it is replaced below.
original_sample = comfy.sample.sample

# Latent most recently seen by hook_sample; StableSR reads it when it has to
# pick up the latent image automatically.
SAMPLE_X = None


def hook_sample(*args, **kwargs):
    """Remember the latent passed to ``comfy.sample.sample``, then delegate.

    The latent is taken from the ninth positional argument when at least nine
    are given, otherwise from the ``latent_image`` keyword argument. When
    neither is present ``SAMPLE_X`` keeps its previous value. All arguments
    are forwarded unchanged and the original function's result is returned.
    """
    global SAMPLE_X
    latent_position = 8  # zero-based index of the latent among positional args
    if len(args) > latent_position:
        SAMPLE_X = args[latent_position]
    elif "latent_image" in kwargs:
        SAMPLE_X = kwargs["latent_image"]
    return original_sample(*args, **kwargs)


# Install the hook in place of comfy.sample.sample.
comfy.sample.sample = hook_sample
class StableSR:
    """Callable wrapper that runs a model function with StableSR conditioning.

    An instance is called as ``wrapper(model_function, params)``. On each call
    it computes the structure condition from the stored latent image, hooks
    the SPADE layers into the diffusion model for the duration of the wrapped
    call, and returns the wrapped function's result.

    Args:
        stable_sr_model_path: Path to the StableSR checkpoint file.
        dtype: Data type the StableSR modules are converted to.
        device: Device the StableSR modules are initially moved to.
    """

    def __init__(self, stable_sr_model_path, dtype, device):
        # ``comfy.utils`` is not imported at module level; import it here so
        # loading does not depend on some other module having imported it.
        import comfy.utils

        print(f"[StableSR] in StableSR init - dtype: {dtype}, device: {device}")
        state_dict = comfy.utils.load_torch_file(stable_sr_model_path)
        self.struct_cond_model: EncoderUNetModelWT = build_unetwt(
            use_fp16=dtype == torch.float16
        )
        self.spade_layers: SPADELayers = SPADELayers()
        self.struct_cond_model.load_from_dict(state_dict)
        self.spade_layers.load_from_dict(state_dict)
        # The checkpoint dict is not used after the two loads above.
        del state_dict
        self.dtype = dtype
        self.struct_cond_model.apply(lambda x: x.to(dtype=dtype, device=device))
        self.spade_layers.apply(lambda x: x.to(dtype=dtype, device=device))
        self.latent_image: Tensor = None  # set via set_latent_image()
        self.set_image_hooks = {}
        self.struct_cond: Tensor = None  # recomputed on every __call__
        self.auto_set_latent = False
        self.last_t = 0.0

    def set_latent_image(self, latent_image):
        """Store the latent image the structure condition is computed from."""
        self.latent_image = latent_image

    def set_auto_set_latent(self, auto_set_latent):
        """Enable or disable picking up the latent from the sampling hook."""
        self.auto_set_latent = auto_set_latent

    def __call__(self, model_function, params):
        """Run ``model_function`` with the SPADE layers hooked in.

        Args:
            model_function: Bound method of the model; its ``__self__`` must
                provide ``model_sampling``, ``process_latent_in`` and
                ``diffusion_model``.
            params: Mapping with the keys ``"input"``, ``"timestep"`` and
                ``"c"`` (keyword arguments for ``model_function``).

        Returns:
            Whatever ``model_function(input, timestep, **c)`` returns.

        Raises:
            RuntimeError: If no latent image has been set or captured.
        """
        # Unpack the packed arguments.
        input_x = params.get("input")
        timestep = params.get("timestep")
        c = params.get("c")

        model = model_function.__self__
        t = model.model_sampling.timestep(timestep)

        if self.auto_set_latent:
            tt = float(t[0])
            # A timestep that is not lower than the previous one is taken as
            # the cue to (re)load the latent captured by the sampling hook.
            if self.last_t <= tt:
                self.set_latent_image(model.process_latent_in(SAMPLE_X))
            self.last_t = tt

        if self.latent_image is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                "[StableSR] No latent image is available for conditioning"
            )

        # Bring the latent image to the device of the model input.
        device = input_x.device
        latent_image = self.latent_image.to(dtype=self.dtype, device=device)

        # Ensure all module layers live on the same device as the unet.
        # This fixes the issue when the user runs with --medvram or --lowvram.
        self.spade_layers.to(device)
        self.struct_cond_model.to(device)

        self.struct_cond = None  # release the previous one to mitigate VRAM peak
        self.struct_cond = self.struct_cond_model(
            latent_image, t[: latent_image.shape[0]]
        )

        self.spade_layers.hook(model.diffusion_model, lambda: self.struct_cond)
        try:
            return model_function(input_x, timestep, **c)
        finally:
            # Always detach the hooks, even if the model call raises, so the
            # diffusion model is never left hooked.
            self.spade_layers.unhook()

    def to(self, device):
        """Move the StableSR modules to ``device`` and return ``self``.

        Any argument that is not a ``torch.device`` is ignored.
        """
        if isinstance(device, torch.device):
            self.struct_cond_model.apply(lambda x: x.to(device=device))
            self.spade_layers.apply(lambda x: x.to(device=device))
        return self
class ApplyStableSRUpscaler:
    """ComfyUI node that installs a StableSR wrapper on a clone of a model."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input specification.

        ``stablesr_model`` offers the files listed for the "stablesr" folder
        key; ``latent_image`` is optional.
        """
        return {
            "required": {
                "model": ("MODEL",),
                "stablesr_model": (folder_paths.get_filename_list("stablesr"),),
            },
            "optional": {
                "latent_image": ("LATENT",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_stable_sr_upscaler"
    CATEGORY = "image/upscaling"

    def apply_stable_sr_upscaler(self, model, stablesr_model, latent_image=None):
        """Build a StableSR wrapper and attach it to a clone of ``model``.

        Args:
            model: Model object; it is cloned, the input is not modified here.
            stablesr_model: File name of the StableSR checkpoint within the
                "stablesr" model folder.
            latent_image: Optional latent dict with a ``"samples"`` entry.
                When omitted, the wrapper is told to pick up the latent
                automatically.

        Returns:
            A one-element tuple holding the cloned model with the wrapper set.

        Raises:
            Exception: If the checkpoint cannot be resolved to an existing file.
        """
        stablesr_model_path = folder_paths.get_full_path("stablesr", stablesr_model)
        # The lookup can yield None for an unknown file; os.path.isfile(None)
        # would raise TypeError instead of the intended error below.
        if stablesr_model_path is None or not os.path.isfile(stablesr_model_path):
            raise Exception("[StableSR] Invalid StableSR model reference")

        # The wrapper is built on the CPU.
        upscaler = StableSR(
            stablesr_model_path, dtype=comfy.model_management.unet_dtype(), device="cpu"
        )

        if latent_image is not None:
            latent_image = model.model.process_latent_in(latent_image["samples"])
            upscaler.set_latent_image(latent_image)
        else:
            upscaler.set_auto_set_latent(True)

        model_sr = model.clone()
        model_sr.set_model_unet_function_wrapper(upscaler)
        return (model_sr,)
# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "StableSRColorFix": StableSRColorFix,
    "ApplyStableSRUpscaler": ApplyStableSRUpscaler,
}

# Node identifier -> display name; every node is displayed under its own
# identifier.
NODE_DISPLAY_NAME_MAPPINGS = {node_id: node_id for node_id in NODE_CLASS_MAPPINGS}