Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle clip_skip and lora scale in lpw pipelines #72

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions multigen/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,19 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
deprecate,
logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -227,8 +223,7 @@ def get_unweighted_text_embeddings(
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
text_embedding = prompt_embeds[0]
else:
prompt_embeds = pipe.text_encoder(
text_input_chunk.to(pipe.device), output_hidden_states=True)
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
Expand Down Expand Up @@ -372,11 +367,7 @@ def get_weighted_text_embeddings(

# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
if uncond_prompt is not None:
Expand All @@ -385,7 +376,7 @@ def get_weighted_text_embeddings(
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip
clip_skip=clip_skip,
)
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)

Expand Down Expand Up @@ -454,7 +445,11 @@ def preprocess_mask(mask, batch_size, scale_factor=8):


class StableDiffusionLongPromptWeightingPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
LoraLoaderMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
Expand Down Expand Up @@ -590,6 +585,8 @@ def _encode_prompt(
max_embeddings_multiples=3,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Expand Down Expand Up @@ -639,6 +636,7 @@ def _encode_prompt(
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
if prompt_embeds is None:
prompt_embeds = prompt_embeds1
Expand Down Expand Up @@ -832,6 +830,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip: Optional[int] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -907,6 +906,9 @@ def __call__(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -945,6 +947,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
Expand All @@ -956,6 +959,8 @@ def __call__(
max_embeddings_multiples,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
dtype = prompt_embeds.dtype

Expand Down Expand Up @@ -1086,6 +1091,7 @@ def text2img(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip=None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -1143,6 +1149,9 @@ def text2img(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -1177,6 +1186,7 @@ def text2img(
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
clip_skip=clip_skip,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
)
Expand Down
33 changes: 19 additions & 14 deletions multigen/lpw_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,31 @@

from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


if is_invisible_watermark_available():
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
Expand Down Expand Up @@ -263,7 +265,7 @@ def get_weighted_text_embeddings_sdxl(
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[int] = None
lora_scale: Optional[int] = None,
):
"""
This function can process long prompt with weights, no length limitation
Expand Down Expand Up @@ -580,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
StableDiffusionMixin,
FromSingleFileMixin,
IPAdapterMixin,
LoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
):
r"""
Expand All @@ -592,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings

Args:
Expand Down Expand Up @@ -774,7 +776,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
self._lora_scale = lora_scale

if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1643,7 +1645,9 @@ def __call__(
image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 3. Encode input prompt
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
lora_scale = (
self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
)

negative_prompt = negative_prompt if negative_prompt is not None else ""

Expand All @@ -1658,6 +1662,7 @@ def __call__(
neg_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
dtype = prompt_embeds.dtype

Expand Down
Loading