Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adapt existing tests for flux models #84

Merged
merged 7 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions multigen/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def weightshare_copy(pipe):
obj = getattr(copy, key)
if hasattr(obj, 'load_state_dict'):
obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
# some buffers might not be transfered from pipe to copy
copy.to(pipe.device)
return copy


Expand Down
42 changes: 33 additions & 9 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, model_id: str,
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe_passed = pipe is not None
self.pipe = pipe
self._scheduler = None
self._hypernets = []
Expand All @@ -125,7 +126,8 @@ def __init__(self, model_id: str,
if mt != model_type:
raise RuntimeError(f"passed model type {self.model_type} doesn't match actual type {mt}")

self._initialize_pipe(device, offload_device)
if not pipe_passed:
self._initialize_pipe(device, offload_device)
self.lpw = lpw
self._loras = []

Expand Down Expand Up @@ -155,6 +157,7 @@ def _get_model_type(self):
def _initialize_pipe(self, device, offload_device):
# sometimes text encoder is on a different device
# if self.pipe.device != device:
logging.debug(f"initialising pipe to device {device}: offload_device {offload_device}")
self.pipe.to(device)
# self.pipe.enable_attention_slicing()
# self.pipe.enable_vae_slicing()
Expand All @@ -164,6 +167,7 @@ def _initialize_pipe(self, device, offload_device):
if self.model_type == ModelType.FLUX:
if offload_device is not None:
self.pipe.enable_sequential_cpu_offload(offload_device)
logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
else:
try:
import xformers
Expand All @@ -172,19 +176,21 @@ def _initialize_pipe(self, device, offload_device):
logging.warning("xformers not found, can't use efficient attention")

def _load_pipeline(self, sd_pipe_class, model_type, args):
logging.debug(f"loading pipeline from {self._model_id} with {args}")
if sd_pipe_class is None:
if self._model_id.endswith('.safetensors'):
if model_type is None:
raise RuntimeError(f"model_type is not specified for safetensors file {self._model_id}")
pipe_class = self._class if model_type == ModelType.SD else self._classxl
return pipe_class.from_single_file(self._model_id, **args)
result = pipe_class.from_single_file(self._model_id, **args)
else:
return self._autopipeline.from_pretrained(self._model_id, **args)
result = self._autopipeline.from_pretrained(self._model_id, **args)
else:
if self._model_id.endswith('.safetensors'):
return sd_pipe_class.from_single_file(self._model_id, **args)
result = sd_pipe_class.from_single_file(self._model_id, **args)
else:
return sd_pipe_class.from_pretrained(self._model_id, **args)
result = sd_pipe_class.from_pretrained(self._model_id, **args)
return result

@property
def scheduler(self):
Expand Down Expand Up @@ -724,7 +730,7 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
if model_id.endswith('.safetensors'):
if self.model_type is None:
raise RuntimeError(f"model type is not specified for safetensors file {model_id}")
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None))
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None), args.get('torch_dtype', None))
super().__init__(model_id=model_id, pipe=pipe, controlnet=cnets, model_type=model_type, **args)
else:
super().__init__(model_id=model_id, pipe=pipe, controlnet=cnets, model_type=model_type, **args)
Expand All @@ -738,22 +744,26 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
else:
raise RuntimeError(f"Unexpected model type {type(self.pipe)}")
self.model_type = t_model_type
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None))
logging.debug(f"from_pipe source dtype {self.pipe.dtype}")
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None), self.pipe.dtype)
prev_dtype = self.pipe.dtype
if self.model_type == ModelType.SDXL:
self.pipe = self._classxl.from_pipe(self.pipe, controlnet=cnets)
elif self.model_type == ModelType.FLUX:
self.pipe = self._classflux.from_pipe(self.pipe, controlnet=cnets[0])
else:
self.pipe = self._class.from_pipe(self.pipe, controlnet=cnets)
logging.debug(f"after from_pipe result dtype {self.pipe.dtype}")
for cnet in cnets:
cnet.to(self.pipe.dtype)
cnet.to(prev_dtype)
logging.debug(f'moving cnet {id(cnet)} to self.pipe.dtype {prev_dtype}')
if 'offload_device' not in args:
cnet.to(self.pipe.device)
else:
# don't load anything, just reuse pipe
super().__init__(model_id=model_id, pipe=pipe, **args)

def _load_cnets(self, cnets, cnet_ids, offload_device=None):
def _load_cnets(self, cnets, cnet_ids, offload_device=None, dtype=None):
if self.model_type == ModelType.FLUX:
ControlNet = FluxControlNetModel
else:
Expand All @@ -773,9 +783,18 @@ def _load_cnets(self, cnets, cnet_ids, offload_device=None):
else:
cnets.append(ControlNet.from_pretrained(c, torch_dtype=torch_dtype))
if offload_device is not None:
# controlnet should be on the same device where main model is working
dev = torch.device('cuda', offload_device)
logging.debug(f'moving cnets to offload device {dev}')
for cnet in cnets:
cnet.to(dev)
else:
logging.debug('offload device is None')
for cnet in cnets:
logging.debug(f"cnet dtype {cnet.dtype}")
if dtype is not None:
logging.debug(f"changing to {dtype}")
cnet.to(dtype)
return cnets

def get_cmodels(self):
Expand Down Expand Up @@ -832,6 +851,8 @@ def setup(self, fimage, width=None, height=None,
self._input_image = [image]
if cscales is None:
cscales = [self.get_default_cond_scales()[c] for c in self.ctypes]
if self.model_type == ModelType.FLUX and hasattr(cscales, '__len__'):
cscales = cscales[0] # multiple controlnets are not yet supported
self.pipe_params.update({
"width": image.size[0] if width is None else width,
"height": image.size[1] if height is None else height,
Expand Down Expand Up @@ -905,6 +926,9 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
Additional arguments passed to the Cond2ImPipe constructor.
"""
super().__init__(model_id=model_id, pipe=pipe, ctypes=ctypes, model_type=model_type, **args)
logging.debug("CIm2Im backend pipe was constructed")
logging.debug(f"self.pipe.dtype = {self.pipe.dtype}")
logging.debug(f"self.pipe.controlnet.dtype = {self.pipe.controlnet.dtype}")
self.processor = None
self.body_estimation = None
self.draw_bodypose = None
Expand Down
8 changes: 7 additions & 1 deletion multigen/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from . import util
from .prompting import Cfgen
import logging


class GenSession:
Expand Down Expand Up @@ -84,12 +85,14 @@ def gen_sess(self, add_count = 0, save_img=True,
# collecting images to return if requested or images are not saved
if not save_img or force_collect:
images = []
logging.info(f"add count = {add_count}")
jk = 0
for inputs in self.confg:
self.last_index = self.confg.count - 1
self.last_conf = {**inputs}
# TODO: multiple inputs?
inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])

logging.debug("start generation")
image = self.pipe.gen(inputs)
if save_img:
self.last_img_name = self.get_last_file_prefix() + ".png"
Expand All @@ -103,5 +106,8 @@ def gen_sess(self, add_count = 0, save_img=True,
if save_img and not drop_cfg:
self.save_last_conf()
if callback is not None:
logging.debug("call callback after generation")
callback()
jk += 1
logging.debug(f"done iteration {jk}")
return images
9 changes: 5 additions & 4 deletions multigen/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
if model_type == ModelType.SDXL:
cls = pipe_class._classxl
elif model_type == ModelType.FLUX:
cls = pipe_class._flux
# use offload by default for now
cls = pipe_class._classflux
if device.type == 'cuda':
offload_device = device.index
device = torch.device('cpu')
device = torch.device('cpu', 0)
else:
cls = pipe_class._class
pipeline = self._loader.load_pipeline(cls, model_id, torch_dtype=torch.bfloat16,
device=device)
self.logger.debug(f'requested {cls} {model_id} on device {device}, got {pipeline.device}')
assert pipeline.device == device
pipe = pipe_class(model_id, pipe=pipeline, device=device, offload_device=offload_device)
if offload_device is None:
assert pipeline.device == device
Expand Down Expand Up @@ -164,7 +164,8 @@ def _update(sess, job, gs):
data['finish_callback']()
except (RuntimeError, TypeError, NotImplementedError) as e:
self.logger.error("error in generation", exc_info=e)
self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
if hasattr(pipe.pipe, '_offload_gpu_id'):
self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
if 'finish_callback' in data:
data['finish_callback']("Can't generate image due to error")
except Exception as e:
Expand Down
Loading
Loading