Style transfer using Style Aligned #370
-
Hi! Loving this library! I'm reading through the blog post on how to use Style Aligned and it seems pretty straightforward. Unfortunately it doesn't show how to do style transfer. I guess that part was skipped on purpose, as it's not that straightforward to implement, but would it be possible to share some tips on how this could be implemented in Refiners?
-
Hi @holwech, thanks for the kind words! The Style Transfer / DDIM Inversion part was indeed skipped on purpose, since we only made a quick draft of it and didn't have time to polish it. Essentially, DDIM inversion computes the trajectory of an encoded image in latent space back to a noisy latent (using the UNet denoiser, it is literally the reverse trajectory of DDIM). We use it in tandem with StyleAligned by overwriting the first element of the batch with the latent on the trajectory, right before giving the batch to the UNet model for denoising. Here are some links I had lying around that might help you understand the process:
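In the meantime, to make the batching trick concrete, here is a toy sketch of that loop structure in plain PyTorch. Nothing in it is Refiners API: dummy_denoise and trajectory are hypothetical placeholders standing in for the SDXL denoising step and the output of a DDIM inversion.

import torch

# `trajectory` stands in for the output of a DDIM inversion: one latent per
# step, ordered from noisiest to cleanest (random placeholder data here).
num_steps, batch_size = 4, 3
trajectory = [torch.randn(4, 8, 8) for _ in range(num_steps + 1)]

def dummy_denoise(x: torch.Tensor, step: int) -> torch.Tensor:
    # placeholder for "UNet prediction + scheduler update" on the whole batch
    return 0.9 * x

x = torch.randn(batch_size, 4, 8, 8)
x[0] = trajectory[0]  # seed the reference slot with the noisiest inverted latent
for step in range(num_steps):
    x = dummy_denoise(x, step)
    # re-pin the reference slot to the inversion trajectory, so the shared
    # attention keeps seeing the style image at every denoising step
    x[0] = trajectory[step + 1]

The snippets below do exactly this, but with StableDiffusion_XL, StyleAlignAdapter and a real ddim_invert.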
I also remember @chloedia implementing the DDIM inversion. I also found this old snippet I wrote (I can't guarantee it still works, it's pretty outdated):

from pathlib import Path
import torch
from PIL import Image
from rich.progress import track
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.ddim_inversion import ddim_invert
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_align import StyleAlignAdapter
HUB_PATH = Path("/mnt/ssd2/hub/finegrain/tests-weights")
DEVICE = torch.device("cuda")
DTYPE = torch.float16
SEED = 2
CONDITION_SCALE = 10.0
INVERSION_CONDITION_SCALE = 2.0
NUM_INFERENCE_STEPS = 30
OFFSET = 0
sd = StableDiffusion_XL(device=DEVICE, dtype=DTYPE)
sd.clip_text_encoder.load_from_safetensors(HUB_PATH / "DoubleCLIPTextEncoder.safetensors")
sd.lda.load_from_safetensors(HUB_PATH / "sdxl-lda.safetensors")
sd.unet.load_from_safetensors(HUB_PATH / "sdxl-unet.safetensors")
SHARED_ATTENTION = True
SHARED_NORM = False
style_align_adapter = StyleAlignAdapter(
    target=sd.unet,
    shared_attention=SHARED_ATTENTION,
    shared_norm=SHARED_NORM,
)
style_align_adapter.inject()
PROMPTS = [
    "Man laying in a bed, medieval painting.",
    "A man working on a laptop, medieval painting.",
    "A man eats pizza, medieval painting.",
    "A woman playing on saxophone, medieval painting.",
]
NEGATIVE_PROMPTS = [""] * len(PROMPTS)
with no_grad():
    # compute clip text embeddings
    unconds: list[torch.Tensor] = []
    conds: list[torch.Tensor] = []
    pooled_unconds: list[torch.Tensor] = []
    pooled_conds: list[torch.Tensor] = []
    for prompt, negative_prompt in zip(PROMPTS, NEGATIVE_PROMPTS):
        clip_text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding(
            text=prompt,
            negative_text=negative_prompt,
        )
        uncond, cond = clip_text_embedding.chunk(2)
        pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)
        unconds.append(uncond)
        conds.append(cond)
        pooled_unconds.append(pooled_uncond)
        pooled_conds.append(pooled_cond)
    uncond = torch.cat(unconds, dim=0)
    cond = torch.cat(conds, dim=0)
    pooled_uncond = torch.cat(pooled_unconds, dim=0)
    pooled_cond = torch.cat(pooled_conds, dim=0)
    clip_text_embedding = torch.cat((uncond, cond), dim=0)
    pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0)
    time_ids = sd.default_time_ids.repeat(len(PROMPTS), 1)

    # compute reverse ddim latents
    image_path = "/home/laurent/github.com/google/style-aligned/example_image/medieval-bed.jpeg"
    image = Image.open(image_path).resize((1024, 1024))
    prompt = PROMPTS[0]
    ddim_inverse_latents = ddim_invert(
        sd=sd,
        image=image,
        prompt=prompt,
        condition_scale=INVERSION_CONDITION_SCALE,
    )

    # generate random latents
    manual_seed(seed=SEED)
    x = torch.randn(len(PROMPTS), 4, 128, 128, device=sd.device, dtype=sd.dtype)

    # replace the first latent with the inverse ddim latent
    x[0] = ddim_inverse_latents[0]

    for step in track(sd.steps, description="Stepping"):
        x = sd(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=CONDITION_SCALE,
        )
        # replace the first latent with the inverse ddim latent
        x[0] = ddim_inverse_latents[step + OFFSET + 1]

    # decode latents to images
    sd.lda.to(dtype=torch.float32)
    predicted_images = [sd.lda.decode_latents(x[i : i + 1, ...].to(dtype=torch.float32)) for i in range(x.shape[0])]
    sd.lda.to(dtype=torch.float16)

    # save images to disk
    merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
    for i in range(len(predicted_images)):
        merged_image.paste(predicted_images[i], (i * 1024, 0))
    merged_image.save("inverse_ddim.png")

And here is the ddim_invert helper the snippet above imports (same caveat, it's outdated and may not match the current API):

from itertools import pairwise
import torch
from PIL import Image
from torch import Tensor
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
@no_grad()
def ddim_invert(
    sd: StableDiffusion_XL,
    image: Image.Image,
    prompt: str,
    time_ids: Tensor | None = None,
    condition_scale: float = 2.0,
) -> list[Tensor]:
    """Invert the latent of an image using the DDIM algorithm."""
    # encode the image into a latent
    sd.lda.to(dtype=torch.float32)
    latent = sd.lda.encode_image(image).to(dtype=sd.dtype)
    sd.lda.to(dtype=torch.float16)

    # encode the prompt into text embeddings
    clip_text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding(text=prompt)

    # set default values for time_ids
    time_ids = time_ids if time_ids is not None else sd.default_time_ids

    # create the list containing all the latents of the trajectory
    ddim_trajectory: list[Tensor] = [latent]

    # build timestep pairs
    timestep_pairs = list(pairwise(sd.scheduler.timesteps.unsqueeze(1).flip(0)))
    timestep_pairs.insert(0, (torch.tensor([1], device=sd.device), torch.tensor([1], device=sd.device)))

    # inversion loop
    for current_timestep, next_timestep in timestep_pairs:
        # set unet contexts
        sd.set_unet_context(
            time_ids=time_ids,
            timestep=next_timestep,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
        )
        # predict noise (inspired from sd.forward)
        latents = torch.cat(tensors=(latent, latent))  # cfg
        unconditional_prediction, conditional_prediction = sd.unet(latents).chunk(2)
        noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
        # step to the next, noisier latent (inspired from sd.scheduler.__call__)
        current_alpha_cumprod = sd.scheduler.cumulative_scale_factors[current_timestep - 1] ** 2
        next_alpha_cumprod = sd.scheduler.cumulative_scale_factors[next_timestep - 1] ** 2
        predicted_x = (latent - (1 - current_alpha_cumprod).sqrt() * noise) / current_alpha_cumprod.sqrt()
        direction = (1 - next_alpha_cumprod).sqrt() * noise
        latent = next_alpha_cumprod.sqrt() * predicted_x + direction
        # add latent to the trajectory
        ddim_trajectory.append(latent)

    # return the trajectory from noisiest to cleanest
    ddim_trajectory.reverse()
    return ddim_trajectory
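For reference, the update at the end of the inversion loop above is just the standard DDIM step rearranged to move from timestep t to the next, noisier timestep t+1, with ᾱ the squared cumulative scale factor and ε_θ the CFG-combined noise prediction:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta}{\sqrt{\bar{\alpha}_t}}, \qquad x_{t+1} = \sqrt{\bar{\alpha}_{t+1}}\,\hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t+1}}\,\epsilon_\theta$$

Running this for every timestep pair yields the list of latents (noisiest first, after the reverse()) that the first script re-injects into x[0] at each denoising step.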