Skip to content

Commit

Permalink
enable offload
Browse files Browse the repository at this point in the history
  • Loading branch information
gameltb committed Nov 5, 2023
1 parent 5290d12 commit 5cca381
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,23 @@ def fix_color(self, image, color_map_image, color_fix):
except Exception as e:
print(f'[StableSR] Error fix_color: {e}')


original_sample = comfy.sample.sample
SAMPLE_X = None


def hook_sample(*args, **kwargs):
global SAMPLE_X
if len(args) >=9 :
if len(args) >= 9:
SAMPLE_X = args[8]
elif "latent_image" in kwargs:
SAMPLE_X = kwargs["latent_image"]
return original_sample(*args, **kwargs)


comfy.sample.sample = hook_sample


class StableSR:
'''
Initializes a StableSR model.
Expand Down Expand Up @@ -124,6 +128,12 @@ def __call__(self, model_function, params):
# Return the result
return result

def to(self, device):
if type(device) == torch.device:
self.struct_cond_model.apply(lambda x: x.to(device=device))
self.spade_layers.apply(lambda x: x.to(device=device))
return self


class ApplyStableSRUpscaler:
@classmethod
Expand Down

0 comments on commit 5cca381

Please sign in to comment.