From 799ab2370fcac5e0d0ff518f45d79e6c06b097fa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 May 2023 08:04:14 -0700 Subject: [PATCH 01/72] initial commit --- .../pipelines/consistency_models/__init__.py | 0 .../pipeline_consistency_models.py | 131 ++++++++++++++++++ .../pipelines/consistency_models/__init__.py | 0 .../test_consistency_models.py | 26 ++++ 4 files changed, 157 insertions(+) create mode 100644 src/diffusers/pipelines/consistency_models/__init__.py create mode 100644 src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py create mode 100644 tests/pipelines/consistency_models/__init__.py create mode 100644 tests/pipelines/consistency_models/test_consistency_models.py diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py new file mode 100644 index 000000000000..a244203db131 --- /dev/null +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -0,0 +1,131 @@ +import inspect +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + +class ConsistencyModelPipeline(DiffusionPipeline): + r""" + TODO + """ + def __init__(self, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers) -> None: + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + ) + + # Need to handle boundary conditions (e.g. c_skip, c_out, etc.) somewhere. + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def add_noise_to_input( + self, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + step: int = 0 + ): + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + TODO Args: + """ + pass + + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + img_size = img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + device = self.device + + # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I) + sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma + + # 2. Set timesteps + self.scheduler.set_timesteps(num_inference_steps) + # TODO: should schedulers always have sigmas? I think the original code always uses sigmas + # self.scheduler.set_sigmas(num_inference_steps) + + # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 4. Denoising loop + # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(self.scheduler.timesteps): + # TODO: handle class labels? + model_output = self.unet(sample, t) + + sample = self.scheduler.step(model_output, t, sample, **extra_step_kwargs).prev_sample + + # TODO: need to handle karras sigma stuff here? + + # TODO: need to support callbacks? + + # 5. Post-process image sample + sample = sample.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + # TODO: Offload to cpu? + + return ImagePipelineOutput(images=sample) + + + + diff --git a/tests/pipelines/consistency_models/__init__.py b/tests/pipelines/consistency_models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py new file mode 100644 index 000000000000..ee829c24f846 --- /dev/null +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -0,0 +1,26 @@ +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers.utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + +class ConsistencyModelPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): + pass + +@slow +@require_torch_gpu +class ConsistencyModelPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() \ No newline at end of file From 63b7f01be914d80d89cf5b1c4ebe7f1b53d244e5 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 May 2023 11:31:55 -0700 Subject: [PATCH 02/72] Improve consistency models sampling implementation. --- .../pipelines/consistency_models/__init__.py | 1 + .../pipeline_consistency_models.py | 158 ++++++++++++++---- 2 files changed, 126 insertions(+), 33 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py index e69de29bb2d1..52cfd0ba939b 100644 --- a/src/diffusers/pipelines/consistency_models/__init__.py +++ b/src/diffusers/pipelines/consistency_models/__init__.py @@ -0,0 +1 @@ +from .pipeline_consistency_models import ConsistencyModelPipeline \ No newline at end of file diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index a244203db131..faf7e20880fc 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable import torch @@ -8,6 +8,17 @@ from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + class ConsistencyModelPipeline(DiffusionPipeline): r""" TODO @@ -40,30 +51,76 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + def get_scalings(self, sigma, sigma_data: float = 0.5): + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out, c_in + + def get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data: float = 0.5): + c_skip = sigma_data**2 / ( + (sigma - sigma_min) ** 2 + sigma_data**2 + ) + c_out = ( + (sigma - sigma_min) + * sigma_data + / (sigma**2 + sigma_data**2) ** 0.5 + ) + c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out, c_in + + def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised=True): + """ + Run the consistency model forward...? + """ + c_skip, c_out, c_in = [ + append_dims(x, x_t.ndim) + for x in self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data) + ] + rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44) + model_output = self.unet(c_in * x_t, rescaled_t).sample + denoised = c_out * model_output + c_skip * x_t + if clip_denoised: + denoised = denoised.clamp(-1, 1) + return model_output, denoised + + def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + def add_noise_to_input( - self, - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - step: int = 0 - ): - """ - Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a - higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. - TODO Args: - """ - pass + self, + sample: torch.FloatTensor, + sigma_hat: float, + sigma_min: float, + sigma_max: float, + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + ): + # Clamp sigma_hat + sigma_hat = sigma_hat.clamp(min=sigma_min, max=sigma_max) + # sample z ~ N(0, s_noise^2 * I) + z = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device) + + # tau = sigma_hat; eps = sigma_min + sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z) + + return sample_hat @torch.no_grad() def __call__( self, batch_size: int = 1, - num_inference_steps: int = 2000, + num_inference_steps: int = 40, + clip_denoised: bool = True, + sigma_data: float = 0.5, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - **kwargs, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, ): r""" Args: @@ -87,33 +144,72 @@ def __call__( img_size = img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) device = self.device + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + scheduler_has_sigma_min = hasattr(self.scheduler, "sigma_min") + assert scheduler_has_sigma_min or scheduler_is_in_sigma_space, "Scheduler needs to have sigmas" # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma - # 2. Set timesteps + # 2. Set timesteps and get sigmas self.scheduler.set_timesteps(num_inference_steps) - # TODO: should schedulers always have sigmas? I think the original code always uses sigmas - # self.scheduler.set_sigmas(num_inference_steps) + timesteps = self.scheduler.timesteps # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 4. Denoising loop - # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(self.scheduler.timesteps): - # TODO: handle class labels? - model_output = self.unet(sample, t) - - sample = self.scheduler.step(model_output, t, sample, **extra_step_kwargs).prev_sample - - # TODO: need to handle karras sigma stuff here? - - # TODO: need to support callbacks? + if scheduler_has_sigma_min: + # 4.1 Scheduler which can add noise to input (e.g. KarrasVeScheduler) + sigma_min = self.scheduler.sigma_min + sigma_max = self.scheduler.sigma_max + s_noise = self.scheduler.s_noise + sigmas = self.scheduler.schedule + + # First evaluate the consistency model. This will be the output sample if num_inference_steps == 1 + sigma = sigmas[timesteps[0]] + _, sample = self.denoise(sample, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised) + + # If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler) + # Alternate adding noise and evaluating the consistency model + for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])): + sigma = sigmas[t] + sigma_prev = sigmas[t - 1] + if hasattr(self.scheduler, "add_noise_to_input"): + sample_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0] + else: + sample_hat = self.add_noise_to_input(sample, sigma, sigma_prev, sigma_min, sigma_max, s_noise=s_noise, generator=generator) + + _, sample = self.denoise(sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised) + else: + # 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler) + sigma_min = self.scheduler.sigmas[-1] + # TODO: warmup steps logic correct? + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + step_idx = self.scheduler.index_for_timestep(t) + sigma = self.scheduler.sigmas[step_idx] + # TODO: handle class labels? + model_output, denoised = self.denoise( + sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + ) + + # Karras-style schedulers already convert to a ODE derivative inside step() + sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample + + # TODO: need to handle karras sigma stuff here? + + # TODO: differs from callback support in original code + # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459 + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) # 5. Post-process image sample - sample = sample.clamp(0, 1) + sample = (sample / 2 + 0.5).clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": @@ -125,7 +221,3 @@ def __call__( # TODO: Offload to cpu? return ImagePipelineOutput(images=sample) - - - - From f2e53da3d5557e0e4fed3e0e642e5099b3ee8981 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 May 2023 18:19:57 -0700 Subject: [PATCH 03/72] Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling. --- .../pipeline_consistency_models.py | 118 ++++++---- .../scheduling_consistency_models.py | 206 ++++++++++++++++++ 2 files changed, 279 insertions(+), 45 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_consistency_models.py diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index faf7e20880fc..5f608626a8bc 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -51,6 +51,43 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + def get_sigma_min_max_from_scheduler(self): + # Get sigma_min, sigma_max in original sigma space, not Karras sigma space + # (e.g. not exponentiated by 1 / rho) + if hasattr(self.scheduler, "sigma_min"): + sigma_min = self.scheduler.sigma_min + sigma_max = self.scheduler.sigma_max + elif hasattr(self.scheduler, "sigmas"): + # Karras-style scheduler e.g. (EulerDiscreteScheduler, HeunDiscreteScheduler) + # Get sigma_min, sigma_max before they're converted into Karras sigma space by set_timesteps + # TODO: Karras schedulers are inconsistent about how they initialize sigmas in __init__ + # For example, EulerDiscreteScheduler gets sigmas in original sigma space, but HeunDiscreteScheduler + # initializes it through set_timesteps, which potentially leaves the sigmas in Karras sigma space. + # TODO: For example, in EulerDiscreteScheduler, a value of 0.0 is appended to the sigmas whern initialized + # in __init__. But wouldn't we usually want sigma_min to be a small positive number, following the + # consistency models paper? + # See e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L13 + sigma_min = self.scheduler.sigmas[-1].item() + sigma_max = self.scheduler.sigmas[0].item() + else: + raise ValueError( + f"Scheduler {self.scheduler.__class__} does not have sigma_min or sigma_max." + ) + return sigma_min, sigma_max + + def get_sigmas_from_scheduler(self): + if hasattr(self.scheduler, "sigmas"): + # e.g. HeunDiscreteScheduler + sigmas = self.scheduler.sigmas + elif hasattr(self.scheduler, "schedule"): + # e.g. KarrasVeScheduler + sigmas = self.scheduler.schedule + else: + raise ValueError( + f"Scheduler {self.scheduler.__class__} does not have sigmas." + ) + return sigmas + def get_scalings(self, sigma, sigma_data: float = 0.5): c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 @@ -58,6 +95,8 @@ def get_scalings(self, sigma, sigma_data: float = 0.5): return c_skip, c_out, c_in def get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data: float = 0.5): + # sigma_min should be in original sigma space, not in karras sigma space + # (e.g. not exponentiated by 1 / rho) c_skip = sigma_data**2 / ( (sigma - sigma_min) ** 2 + sigma_data**2 ) @@ -73,6 +112,8 @@ def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised= """ Run the consistency model forward...? """ + # sigma_min should be in original sigma space, not in karras sigma space + # (e.g. not exponentiated by 1 / rho) c_skip, c_out, c_in = [ append_dims(x, x_t.ndim) for x in self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data) @@ -88,26 +129,6 @@ def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / append_dims(sigma, x.ndim) - def add_noise_to_input( - self, - sample: torch.FloatTensor, - sigma_hat: float, - sigma_min: float, - sigma_max: float, - s_noise: float = 1.0, - generator: Optional[torch.Generator] = None, - ): - # Clamp sigma_hat - sigma_hat = sigma_hat.clamp(min=sigma_min, max=sigma_max) - - # sample z ~ N(0, s_noise^2 * I) - z = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device) - - # tau = sigma_hat; eps = sigma_min - sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z) - - return sample_hat - @torch.no_grad() def __call__( self, @@ -144,46 +165,49 @@ def __call__( img_size = img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) device = self.device - scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") - scheduler_has_sigma_min = hasattr(self.scheduler, "sigma_min") - assert scheduler_has_sigma_min or scheduler_is_in_sigma_space, "Scheduler needs to have sigmas" # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma # 2. Set timesteps and get sigmas + # Get sigma_min, sigma_max in original sigma space (not Karras sigma space) + sigma_min, sigma_max = self.get_sigma_min_max_from_scheduler() self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps + + # Now get Karras sigma schedule (which I think the original implementation always uses) + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376 + sigmas = self.get_sigmas_from_scheduler() # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 4. Denoising loop - if scheduler_has_sigma_min: - # 4.1 Scheduler which can add noise to input (e.g. KarrasVeScheduler) - sigma_min = self.scheduler.sigma_min - sigma_max = self.scheduler.sigma_max - s_noise = self.scheduler.s_noise - sigmas = self.scheduler.schedule - + # TODO: hack, is there a better way to identify schedulers that implement the stochastic iterative sampling + # similar to stochastic_iterative_sampler in the original code? + if hasattr(self.scheduler, "add_noise_to_input"): + # 4.1 Consistency Model Stochastic Iterative Scheduler (multi-step sampling) # First evaluate the consistency model. This will be the output sample if num_inference_steps == 1 - sigma = sigmas[timesteps[0]] + # TODO: not all schedulers have an index_for_timestep method (e.g. KarrasVeScheduler) + step_idx = self.scheduler.index_for_timestep(timesteps[0]) + sigma = sigmas[step_idx] _, sample = self.denoise(sample, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised) # If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler) - # Alternate adding noise and evaluating the consistency model + # Alternate adding noise and evaluating the consistency model on the noised input for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])): - sigma = sigmas[t] - sigma_prev = sigmas[t - 1] - if hasattr(self.scheduler, "add_noise_to_input"): - sample_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0] - else: - sample_hat = self.add_noise_to_input(sample, sigma, sigma_prev, sigma_min, sigma_max, s_noise=s_noise, generator=generator) - - _, sample = self.denoise(sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised) - else: + step_idx = self.scheduler.index_for_timestep(t) + sigma = sigmas[step_idx] + sigma_prev = sigmas[step_idx - 1] + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0] + + model_output, denoised = self.denoise( + sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + ) + + sample = self.scheduler.step(denoised, sigma_hat, sigma_prev, sample_hat).prev_sample + elif hasattr(self.scheduler, "sigmas"): # 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler) - sigma_min = self.scheduler.sigmas[-1] # TODO: warmup steps logic correct? num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -191,6 +215,8 @@ def __call__( step_idx = self.scheduler.index_for_timestep(t) sigma = self.scheduler.sigmas[step_idx] # TODO: handle class labels? + # TODO: check shapes, might need equivalent of s_in in original code + # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510 model_output, denoised = self.denoise( sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised ) @@ -198,15 +224,17 @@ def __call__( # Karras-style schedulers already convert to a ODE derivative inside step() sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample - # TODO: need to handle karras sigma stuff here? - - # TODO: differs from callback support in original code + # Note: differs from callback support in original code # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459 # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, sample) + else: + raise ValueError( + f"Scheduler {self.scheduler.__class__} is not compatible with consistency models." + ) # 5. Post-process image sample sample = (sample / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py new file mode 100644 index 000000000000..8aa5ecd0d7ae --- /dev/null +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -0,0 +1,206 @@ +# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import SchedulerMixin + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +@dataclass +class CMStochasticIterativeSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Derivative of predicted original image sample (x_0). + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + # derivative: torch.FloatTensor + # pred_original_sample: Optional[torch.FloatTensor] = None + + +class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + """ + + @register_to_config + def __init__( + self, + sigma_data: float = 0.5, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + rho: float = 7.0, + s_noise: float = 1.0, + s_churn: float = 0.0, + s_min: float = 0.0, + s_max: float = float('inf'), + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + # setable values + self.num_inference_steps: int = None + self.timesteps: np.IntTensor = None + self.schedule: torch.FloatTensor = None # sigma(t_i) + + self.sigma_data = sigma_data + self.rho = rho + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + return indices.item() + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def get_sigmas_karras(self): + """Constructs the noise schedule of Karras et al. (2022).""" + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return append_zero(sigmas) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + # TODO: how should timesteps be set? the original code seems to either solely work in sigma space or have + # hardcoded timesteps (see e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L74) + # TODO: should add num_train_timesteps here??? + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + sigmas = self.get_sigmas_karras() + + self.timesteps = torch.from_numpy(timesteps).to(device) + self.sigmas = torch.tensor(sigmas, dtype=torch.float32, device=device) + + def add_noise(self, original_samples, noise, timesteps): + """Add noise for training.""" + raise NotImplementedError() + + def add_noise_to_input( + self, + sample: torch.FloatTensor, + sigma: float, + generator: Optional[torch.Generator] = None + ) -> Tuple[torch.FloatTensor, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + TODO Args: + """ + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max + + step_idx = (self.sigmas == sigma).nonzero().item() + sigma_hat = self.sigmas[step_idx + 1].clamp(min=sigma_min, max=sigma_max) + + # sample z ~ N(0, s_noise^2 * I) + z = self.config.s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device) + + # tau = sigma_hat, eps = sigma_min + sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z) + + return sample_hat, sigma_hat + + def step( + self, + model_output: torch.FloatTensor, + sigma_hat: float, + sigma_prev: float, + sample_hat: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor`): TODO + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class + KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). + Returns: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # Assume model output is the consistency model evaluated at sample_hat. + sample_prev = model_output + + if not return_dict: + return (sample_prev,) + + return CMStochasticIterativeSchedulerOutput( + prev_sample=sample_prev, + ) + From 7c1e81f83deac2746454920f469a9fc2250bbe2f Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Mon, 22 May 2023 12:10:30 +0530 Subject: [PATCH 04/72] Add Unet blocks for consistency models --- src/diffusers/models/unet_2d_blocks.py | 220 +++++++++++++++++++++++++ 1 file changed, 220 insertions(+) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 75d9eb3e03df..cad0676f4eab 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -90,6 +90,20 @@ def get_down_block( attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif down_block_type == "AttnDownsampleBlock2D": + return AttnDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") @@ -314,6 +328,20 @@ def get_up_block( attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif up_block_type == "AttnUpsampleBlock2D": + return AttnUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( num_layers=num_layers, @@ -762,6 +790,101 @@ def forward(self, hidden_states, temb=None, upsample_size=None): return hidden_states, output_states +class AttnDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, upsample_size=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + + class CrossAttnDownBlock2D(nn.Module): def __init__( self, @@ -1831,6 +1954,103 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si return hidden_states +class AttnUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + class CrossAttnUpBlock2D(nn.Module): def __init__( self, From 3a151bd1d96736ce5f57fc288075b4df7a08dbf9 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Mon, 22 May 2023 17:16:36 +0530 Subject: [PATCH 05/72] Add conversion script for Unet --- scripts/convert_consistency_to_diffusers.py | 163 ++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 scripts/convert_consistency_to_diffusers.py diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py new file mode 100644 index 000000000000..bd269c735dde --- /dev/null +++ b/scripts/convert_consistency_to_diffusers.py @@ -0,0 +1,163 @@ +import argparse +import io +from diffusers.models.unet_2d import UNet2DModel +import requests +import torch + +UNET_CONFIG = { + "sample_size": 64, + "in_channels": 3, + "out_channels": 3, + "layers_per_block" : 3, + "num_class_embeds": 1000, + "block_out_channels": [192, 192*2, 192*3, 192*4], + "attention_head_dim" : 64, + "down_block_types": ["ResnetDownsampleBlock2D", "AttnDownsampleBlock2D", "AttnDownsampleBlock2D", "AttnDownsampleBlock2D"], + "up_block_types": ["AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"], + "resnet_time_scale_shift" : "scale_shift" +} + + +def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False): + new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"] + new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"] + new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"] + new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"] + new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"] + new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"] + + if has_skip: + new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"] + new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"] + + + return new_checkpoint + +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) + + new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] + new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] + + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) + + return new_checkpoint + + +def con_pt_to_diffuser(checkpoint_path: str, output_path: str): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] + + new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] + + down_block_types = UNET_CONFIG["down_block_types"] + layers_per_block = UNET_CONFIG["layers_per_block"] + current_layer = 1 + + for (i,layer_type) in enumerate(down_block_types): + + if layer_type == "ResnetDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + elif layer_type == "AttnDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip) + new_prefix = f"down_blocks.{i}.attentions.{j}" + old_prefix = f"input_blocks.{current_layer}.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + if i!= len(down_block_types)-1: + new_prefix = f"down_blocks.{i}.downsamplers.0" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + # hardcoded the mid-block for now + new_prefix = f"mid_block.resnets.0" + old_prefix = f"middle_block.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = f"mid_block.attentions.0" + old_prefix = f"middle_block.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = f"mid_block.resnets.1" + old_prefix = f"middle_block.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + current_layer = 0 + up_block_types = UNET_CONFIG["up_block_types"] + + for (i, layer_type) in enumerate (up_block_types): + if layer_type == "ResnetUpsampleBlock2D": + for j in range(layers_per_block+1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + current_layer += 1 + elif layer_type == "AttnUpsampleBlock2D": + for j in range(layers_per_block+1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + new_prefix = f"up_blocks.{i}.attentions.{j}" + old_prefix = f"output_blocks.{current_layer}.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.2" + # print(new_prefix) + # print(old_prefix) + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + + args = parser.parse_args() + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, args.dump_path) + image_unet = UNet2DModel(**UNET_CONFIG) + # print(image_unet) + # exit() + image_unet.load_state_dict(converted_unet_ckpt) + image_unet.save_pretrained(args.dump_path) \ No newline at end of file From b6c5e15bab797e24587fd6fc9d4289d79f7ba43d Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Mon, 22 May 2023 17:54:12 +0530 Subject: [PATCH 06/72] Fix bug in new unet blocks --- src/diffusers/models/unet_2d_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cad0676f4eab..f94f178f11e7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -877,7 +877,7 @@ def forward(self, hidden_states, temb=None, upsample_size=None): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, temb) output_states += (hidden_states,) @@ -2046,7 +2046,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, temb) return hidden_states From 4e93f09f2b88c25edfbb4ab6cc57b7a9befd9f51 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Tue, 23 May 2023 11:31:29 +0530 Subject: [PATCH 07/72] Fix attention weight loading --- scripts/convert_consistency_to_diffusers.py | 37 ++++++++++++++------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index bd269c735dde..37a89bb86693 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -37,19 +37,31 @@ def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip= return new_checkpoint -def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix): - weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) - bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim=64): + c, _, _, _ = checkpoint[f"{old_prefix}.qkv.weight"].shape + n_heads = c // (attention_head_dim*3) + old_weights = checkpoint[f"{old_prefix}.qkv.weight"].reshape(n_heads, attention_head_dim*3, -1, 1, 1) + old_biases = checkpoint[f"{old_prefix}.qkv.bias"].reshape(n_heads, attention_head_dim*3, -1, 1, 1) + + weight_q, weight_k, weight_v = old_weights.chunk(3, dim=1) + weight_q = weight_q.reshape(n_heads*attention_head_dim, -1, 1, 1) + weight_k = weight_k.reshape(n_heads*attention_head_dim, -1, 1, 1) + weight_v = weight_v.reshape(n_heads*attention_head_dim, -1, 1, 1) + + bias_q, bias_k, bias_v = old_biases.chunk(3, dim=1) + bias_q = bias_q.reshape(n_heads*attention_head_dim, -1, 1, 1) + bias_k = bias_k.reshape(n_heads*attention_head_dim, -1, 1, 1) + bias_v = bias_v.reshape(n_heads*attention_head_dim, -1, 1, 1) new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] - new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.weight"] = torch.squeeze(weight_q) + new_checkpoint[f"{new_prefix}.to_q.bias"] = torch.squeeze(bias_q) + new_checkpoint[f"{new_prefix}.to_k.weight"] = torch.squeeze(weight_k) + new_checkpoint[f"{new_prefix}.to_k.bias"] = torch.squeeze(bias_k) + new_checkpoint[f"{new_prefix}.to_v.weight"] = torch.squeeze(weight_v) + new_checkpoint[f"{new_prefix}.to_v.bias"] = torch.squeeze(bias_v) new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) @@ -73,6 +85,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): down_block_types = UNET_CONFIG["down_block_types"] layers_per_block = UNET_CONFIG["layers_per_block"] + attention_head_dim = UNET_CONFIG["attention_head_dim"] current_layer = 1 for (i,layer_type) in enumerate(down_block_types): @@ -92,7 +105,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip) new_prefix = f"down_blocks.{i}.attentions.{j}" old_prefix = f"input_blocks.{current_layer}.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) current_layer += 1 if i!= len(down_block_types)-1: @@ -107,7 +120,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_prefix = f"mid_block.attentions.0" old_prefix = f"middle_block.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) new_prefix = f"mid_block.resnets.1" old_prefix = f"middle_block.2" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) @@ -129,7 +142,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) new_prefix = f"up_blocks.{i}.attentions.{j}" old_prefix = f"output_blocks.{current_layer}.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) current_layer += 1 new_prefix = f"up_blocks.{i}.upsamplers.0" From 9ae7669f58e2a4f74fe11dd969e0a6695d883558 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 25 May 2023 22:22:26 -0700 Subject: [PATCH 08/72] Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests. --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 1 + .../pipeline_consistency_models.py | 134 +++---- src/diffusers/schedulers/__init__.py | 1 + .../scheduling_consistency_models.py | 336 ++++++++++++++---- .../test_consistency_models.py | 184 +++++++++- 6 files changed, 496 insertions(+), 162 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9b3f8adad376..2f405aea22f0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -58,6 +58,7 @@ ) from .pipelines import ( AudioPipelineOutput, + ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, @@ -72,6 +73,7 @@ ScoreSdeVePipeline, ) from .schedulers import ( + CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMScheduler, DDPMScheduler, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9b44f4e5eb14..caa269eb364c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -16,6 +16,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 5f608626a8bc..2a4ffcd8394b 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -3,7 +3,7 @@ import torch -from ...models import UNet2DConditionModel +from ...models import UNet2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -21,20 +21,19 @@ def append_dims(x, target_dims): class ConsistencyModelPipeline(DiffusionPipeline): r""" - TODO + Sampling pipeline for consistency models. """ - def __init__(self, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers) -> None: + def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> None: super().__init__() self.register_modules( unet=unet, scheduler=scheduler, ) - - # Need to handle boundary conditions (e.g. c_skip, c_out, etc.) somewhere. - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler + def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -44,6 +43,12 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta + + accepts_sigma_min = "sigma_min" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_sigma_min: + # Assume accepting sigma_min always means scheduler also accepts sigma_max + extra_step_kwargs["sigma_min"] = sigma_min + extra_step_kwargs["sigma_max"] = sigma_max # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) @@ -51,50 +56,13 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def get_sigma_min_max_from_scheduler(self): - # Get sigma_min, sigma_max in original sigma space, not Karras sigma space - # (e.g. not exponentiated by 1 / rho) - if hasattr(self.scheduler, "sigma_min"): - sigma_min = self.scheduler.sigma_min - sigma_max = self.scheduler.sigma_max - elif hasattr(self.scheduler, "sigmas"): - # Karras-style scheduler e.g. (EulerDiscreteScheduler, HeunDiscreteScheduler) - # Get sigma_min, sigma_max before they're converted into Karras sigma space by set_timesteps - # TODO: Karras schedulers are inconsistent about how they initialize sigmas in __init__ - # For example, EulerDiscreteScheduler gets sigmas in original sigma space, but HeunDiscreteScheduler - # initializes it through set_timesteps, which potentially leaves the sigmas in Karras sigma space. - # TODO: For example, in EulerDiscreteScheduler, a value of 0.0 is appended to the sigmas whern initialized - # in __init__. But wouldn't we usually want sigma_min to be a small positive number, following the - # consistency models paper? - # See e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L13 - sigma_min = self.scheduler.sigmas[-1].item() - sigma_max = self.scheduler.sigmas[0].item() - else: - raise ValueError( - f"Scheduler {self.scheduler.__class__} does not have sigma_min or sigma_max." - ) - return sigma_min, sigma_max - - def get_sigmas_from_scheduler(self): - if hasattr(self.scheduler, "sigmas"): - # e.g. HeunDiscreteScheduler - sigmas = self.scheduler.sigmas - elif hasattr(self.scheduler, "schedule"): - # e.g. KarrasVeScheduler - sigmas = self.scheduler.schedule - else: - raise ValueError( - f"Scheduler {self.scheduler.__class__} does not have sigmas." - ) - return sigmas - def get_scalings(self, sigma, sigma_data: float = 0.5): c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out, c_in - def get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data: float = 0.5): + def get_scalings_for_boundary_condition(self, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5): # sigma_min should be in original sigma space, not in karras sigma space # (e.g. not exponentiated by 1 / rho) c_skip = sigma_data**2 / ( @@ -108,7 +76,7 @@ def get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data: float = 0. c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out, c_in - def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised=True): + def denoise(self, x_t, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5, clip_denoised=True): """ Run the consistency model forward...? """ @@ -116,7 +84,7 @@ def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised= # (e.g. not exponentiated by 1 / rho) c_skip, c_out, c_in = [ append_dims(x, x_t.ndim) - for x in self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data) + for x in self.get_scalings_for_boundary_condition(sigma, sigma_min=sigma_min, sigma_data=sigma_data) ] rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44) model_output = self.unet(c_in * x_t, rescaled_t).sample @@ -135,6 +103,8 @@ def __call__( batch_size: int = 1, num_inference_steps: int = 40, clip_denoised: bool = True, + sigma_min: float = 0.002, + sigma_max: float = 80.0, sigma_data: float = 0.5, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -169,59 +139,43 @@ def __call__( # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma - # 2. Set timesteps and get sigmas - # Get sigma_min, sigma_max in original sigma space (not Karras sigma space) - sigma_min, sigma_max = self.get_sigma_min_max_from_scheduler() + # 2. Set timesteps self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps # Now get Karras sigma schedule (which I think the original implementation always uses) # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376 - sigmas = self.get_sigmas_from_scheduler() + # TODO: how do we ensure that this in Karras sigma space rather than in "original" sigma space? + # 3. Get sigma schedule + assert hasattr(self.scheduler, "sigmas"), "Scheduler needs to operate in sigma space" + sigmas = self.scheduler.sigmas - # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 4. Denoising loop - # TODO: hack, is there a better way to identify schedulers that implement the stochastic iterative sampling - # similar to stochastic_iterative_sampler in the original code? - if hasattr(self.scheduler, "add_noise_to_input"): - # 4.1 Consistency Model Stochastic Iterative Scheduler (multi-step sampling) - # First evaluate the consistency model. This will be the output sample if num_inference_steps == 1 - # TODO: not all schedulers have an index_for_timestep method (e.g. KarrasVeScheduler) - step_idx = self.scheduler.index_for_timestep(timesteps[0]) - sigma = sigmas[step_idx] - _, sample = self.denoise(sample, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised) - - # If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler) - # Alternate adding noise and evaluating the consistency model on the noised input - for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])): - step_idx = self.scheduler.index_for_timestep(t) - sigma = sigmas[step_idx] - sigma_prev = sigmas[step_idx - 1] - sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0] - - model_output, denoised = self.denoise( - sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised - ) - - sample = self.scheduler.step(denoised, sigma_hat, sigma_prev, sample_hat).prev_sample - elif hasattr(self.scheduler, "sigmas"): - # 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler) - # TODO: warmup steps logic correct? + # 4. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, sigma_min, sigma_max) + + # 5. Denoising loop + if num_inference_steps == 1: + # Onestep sampling: simply evaluate the consistency model at the first sigma + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L643 + sigma = sigmas[0] + _, sample = self.denoise( + sample, sigma, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + ) + else: + # Multistep sampling or Karras sampler num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - step_idx = self.scheduler.index_for_timestep(t) - sigma = self.scheduler.sigmas[step_idx] - # TODO: handle class labels? + # Don't loop over last timestep + for i, t in enumerate(timesteps[:-1]): + sigma = sigmas[i] + # TODO: handle class labels # TODO: check shapes, might need equivalent of s_in in original code # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510 model_output, denoised = self.denoise( - sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + sample, sigma, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised ) - # Karras-style schedulers already convert to a ODE derivative inside step() + # Works for both Karras-style schedulers (e.g. Euler, Heun) and the CM multistep scheduler sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample # Note: differs from callback support in original code @@ -231,12 +185,8 @@ def __call__( progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, sample) - else: - raise ValueError( - f"Scheduler {self.scheduler.__class__} is not compatible with consistency models." - ) - # 5. Post-process image sample + # 6. Post-process image sample sample = (sample / 2 + 0.5).clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 05414e32fc9e..79e8b7e34ecf 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -28,6 +28,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddpm import DDPMScheduler diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 8aa5ecd0d7ae..88b0785ccfb5 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor +from ..utils import BaseOutput, logging, randn_tensor from .scheduling_utils import SchedulerMixin +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) @@ -48,55 +51,119 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput): # pred_original_sample: Optional[torch.FloatTensor] = None +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" https://arxiv.org/pdf/2303.01469 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. - A reasonable range is [0.2, 80]. + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `"epsilon"`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + interpolation_type (`str`, default `"linear"`, optional): + interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of + [`"linear"`, `"log_linear"`]. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ + order = 1 + @register_to_config def __init__( self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, sigma_data: float = 0.5, - sigma_min: float = 0.002, - sigma_max: float = 80.0, rho: float = 7.0, - s_noise: float = 1.0, - s_churn: float = 0.0, - s_min: float = 0.0, - s_max: float = float('inf'), + prediction_type: str = "sample", + interpolation_type: str = "linear", + use_karras_sigmas: Optional[bool] = True, ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + # standard deviation of the initial noise distribution - self.init_noise_sigma = sigma_max + self.init_noise_sigma = self.sigmas.max() # setable values - self.num_inference_steps: int = None - self.timesteps: np.IntTensor = None - self.schedule: torch.FloatTensor = None # sigma(t_i) - + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.sigma_data = sigma_data self.rho = rho + self.is_scale_input_called = False + self.use_karras_sigmas = use_karras_sigmas def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: @@ -116,40 +183,118 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = `torch.FloatTensor`: scaled input sample """ return sample - - def get_sigmas_karras(self): - """Constructs the noise schedule of Karras et al. (2022).""" - ramp = np.linspace(0, 1, self.num_inference_steps) - min_inv_rho = self.sigma_min ** (1 / self.rho) - max_inv_rho = self.sigma_max ** (1 / self.rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho - return append_zero(sigmas) def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps - # TODO: how should timesteps be set? the original code seems to either solely work in sigma space or have - # hardcoded timesteps (see e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L74) - # TODO: should add num_train_timesteps here??? - timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() - sigmas = self.get_sigmas_karras() - - self.timesteps = torch.from_numpy(timesteps).to(device) - self.sigmas = torch.tensor(sigmas, dtype=torch.float32, device=device) + + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.config.interpolation_type == "log_linear": + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() + else: + raise ValueError( + f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + # Copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Modified from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + # Use self.rho instead of hardcoded 7.0 for rho + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = self.rho + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas - def add_noise(self, original_samples, noise, timesteps): - """Add noise for training.""" - raise NotImplementedError() + # TODO: add_noise meant to be called during training (forward diffusion process) + # TODO: need to check if this corresponds to noise added during training + # TODO: may want multiple add_noise-type methods for CD, CT + # Copied from diffusers.schedulers.scheduling_euler_discrete.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples def add_noise_to_input( self, sample: torch.FloatTensor, sigma: float, + sigma_min: float = 0.002, + sigma_max: float = 80.0, generator: Optional[torch.Generator] = None ) -> Tuple[torch.FloatTensor, float]: """ @@ -157,9 +302,6 @@ def add_noise_to_input( higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. TODO Args: """ - sigma_min = self.config.sigma_min - sigma_max = self.config.sigma_max - step_idx = (self.sigmas == sigma).nonzero().item() sigma_hat = self.sigmas[step_idx + 1].clamp(min=sigma_min, max=sigma_max) @@ -170,13 +312,16 @@ def add_noise_to_input( sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z) return sample_hat, sigma_hat - + def step( self, model_output: torch.FloatTensor, - sigma_hat: float, - sigma_prev: float, - sample_hat: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: """ @@ -184,23 +329,78 @@ def step( process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class - KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class Returns: - [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: - [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. """ - # Assume model output is the consistency model evaluated at sample_hat. - sample_prev = model_output - if not return_dict: - return (sample_prev,) + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) - return CMStochasticIterativeSchedulerOutput( - prev_sample=sample_prev, + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + + # sample z ~ N(0, s_noise^2 * I) + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator ) + z = noise * s_noise + + sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Return noisy sample + # tau = sigma_hat, eps = sigma_min + prev_sample = pred_original_sample + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + + if not return_dict: + return (prev_sample,) + + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + def __len__(self): + return self.config.num_train_timesteps diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index ee829c24f846..f84778bff75c 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -5,17 +5,197 @@ import numpy as np import torch from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + UNet2DModel, +) from diffusers.utils import floats_tensor, load_image, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu +from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin class ConsistencyModelPipelineFastTests( PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase ): - pass + params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS + batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("ResnetDownsampleBlock2D", "AttnDownsampleBlock2D"), + up_block_types=("AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"), + ) + return model + + def get_dummy_components(self): + unet = self.dummy_uncond_unet + + # Default to CM multistep sampler + # TODO: need to determine most sensible settings for these args + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "num_inference_steps": 2, + "clip_denoised": True, + "sigma_min": 0.002, + "sigma_max": 80.0, + "sigma_data": 0.5, + "generator": generator, + "output_type": "numpy", + } + + return inputs + + def test_consistency_model_pipeline_multistep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_k_dpm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_uncond_unet + # TODO: get reasonable args for KDPM2DiscreteScheduler + scheduler = KDPM2DiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="linear" + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_uncond_unet + # TODO: get reasonable args for EulerDiscreteScheduler + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="linear" + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_k_euler_ancestral(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_uncond_unet + # TODO: get reasonable args for EulerAncestralDiscreteScheduler + scheduler = EulerAncestralDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="linear" + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_k_heun(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_uncond_unet + # TODO: get reasonable args for HeunDiscreteScheduler + scheduler = HeunDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="linear" + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + @slow @require_torch_gpu From 54b287e79aae7af17703e01821a427886ee4c034 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 25 May 2023 22:23:52 -0700 Subject: [PATCH 09/72] make style --- scripts/convert_consistency_to_diffusers.py | 84 +++++++++++-------- src/diffusers/models/unet_2d_blocks.py | 3 +- .../pipelines/consistency_models/__init__.py | 2 +- .../pipeline_consistency_models.py | 39 ++++----- .../scheduling_consistency_models.py | 34 ++++---- .../test_consistency_models.py | 43 ++++------ 6 files changed, 100 insertions(+), 105 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 37a89bb86693..bc96918554d2 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -1,20 +1,26 @@ import argparse -import io -from diffusers.models.unet_2d import UNet2DModel -import requests + import torch +from diffusers.models.unet_2d import UNet2DModel + + UNET_CONFIG = { "sample_size": 64, "in_channels": 3, "out_channels": 3, - "layers_per_block" : 3, + "layers_per_block": 3, "num_class_embeds": 1000, - "block_out_channels": [192, 192*2, 192*3, 192*4], - "attention_head_dim" : 64, - "down_block_types": ["ResnetDownsampleBlock2D", "AttnDownsampleBlock2D", "AttnDownsampleBlock2D", "AttnDownsampleBlock2D"], + "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownsampleBlock2D", + "AttnDownsampleBlock2D", + "AttnDownsampleBlock2D", + ], "up_block_types": ["AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"], - "resnet_time_scale_shift" : "scale_shift" + "resnet_time_scale_shift": "scale_shift", } @@ -34,24 +40,24 @@ def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip= new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"] new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"] - return new_checkpoint + def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim=64): c, _, _, _ = checkpoint[f"{old_prefix}.qkv.weight"].shape - n_heads = c // (attention_head_dim*3) - old_weights = checkpoint[f"{old_prefix}.qkv.weight"].reshape(n_heads, attention_head_dim*3, -1, 1, 1) - old_biases = checkpoint[f"{old_prefix}.qkv.bias"].reshape(n_heads, attention_head_dim*3, -1, 1, 1) + n_heads = c // (attention_head_dim * 3) + old_weights = checkpoint[f"{old_prefix}.qkv.weight"].reshape(n_heads, attention_head_dim * 3, -1, 1, 1) + old_biases = checkpoint[f"{old_prefix}.qkv.bias"].reshape(n_heads, attention_head_dim * 3, -1, 1, 1) weight_q, weight_k, weight_v = old_weights.chunk(3, dim=1) - weight_q = weight_q.reshape(n_heads*attention_head_dim, -1, 1, 1) - weight_k = weight_k.reshape(n_heads*attention_head_dim, -1, 1, 1) - weight_v = weight_v.reshape(n_heads*attention_head_dim, -1, 1, 1) + weight_q = weight_q.reshape(n_heads * attention_head_dim, -1, 1, 1) + weight_k = weight_k.reshape(n_heads * attention_head_dim, -1, 1, 1) + weight_v = weight_v.reshape(n_heads * attention_head_dim, -1, 1, 1) bias_q, bias_k, bias_v = old_biases.chunk(3, dim=1) - bias_q = bias_q.reshape(n_heads*attention_head_dim, -1, 1, 1) - bias_k = bias_k.reshape(n_heads*attention_head_dim, -1, 1, 1) - bias_v = bias_v.reshape(n_heads*attention_head_dim, -1, 1, 1) + bias_q = bias_q.reshape(n_heads * attention_head_dim, -1, 1, 1) + bias_k = bias_k.reshape(n_heads * attention_head_dim, -1, 1, 1) + bias_v = bias_v.reshape(n_heads * attention_head_dim, -1, 1, 1) new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] @@ -63,7 +69,9 @@ def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attent new_checkpoint[f"{new_prefix}.to_v.weight"] = torch.squeeze(weight_v) new_checkpoint[f"{new_prefix}.to_v.bias"] = torch.squeeze(bias_v) - new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( + checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + ) new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) return new_checkpoint @@ -88,15 +96,14 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): attention_head_dim = UNET_CONFIG["attention_head_dim"] current_layer = 1 - for (i,layer_type) in enumerate(down_block_types): - + for i, layer_type in enumerate(down_block_types): if layer_type == "ResnetDownsampleBlock2D": for j in range(layers_per_block): new_prefix = f"down_blocks.{i}.resnets.{j}" old_prefix = f"input_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) current_layer += 1 - + elif layer_type == "AttnDownsampleBlock2D": for j in range(layers_per_block): new_prefix = f"down_blocks.{i}.resnets.{j}" @@ -105,53 +112,56 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip) new_prefix = f"down_blocks.{i}.attentions.{j}" old_prefix = f"input_blocks.{current_layer}.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) current_layer += 1 - if i!= len(down_block_types)-1: + if i != len(down_block_types) - 1: new_prefix = f"down_blocks.{i}.downsamplers.0" old_prefix = f"input_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) current_layer += 1 # hardcoded the mid-block for now - new_prefix = f"mid_block.resnets.0" - old_prefix = f"middle_block.0" + new_prefix = "mid_block.resnets.0" + old_prefix = "middle_block.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) - new_prefix = f"mid_block.attentions.0" - old_prefix = f"middle_block.1" + new_prefix = "mid_block.attentions.0" + old_prefix = "middle_block.1" new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) - new_prefix = f"mid_block.resnets.1" - old_prefix = f"middle_block.2" + new_prefix = "mid_block.resnets.1" + old_prefix = "middle_block.2" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) current_layer = 0 up_block_types = UNET_CONFIG["up_block_types"] - for (i, layer_type) in enumerate (up_block_types): + for i, layer_type in enumerate(up_block_types): if layer_type == "ResnetUpsampleBlock2D": - for j in range(layers_per_block+1): + for j in range(layers_per_block + 1): new_prefix = f"up_blocks.{i}.resnets.{j}" old_prefix = f"output_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) current_layer += 1 elif layer_type == "AttnUpsampleBlock2D": - for j in range(layers_per_block+1): + for j in range(layers_per_block + 1): new_prefix = f"up_blocks.{i}.resnets.{j}" old_prefix = f"output_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) new_prefix = f"up_blocks.{i}.attentions.{j}" old_prefix = f"output_blocks.{current_layer}.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) current_layer += 1 - + new_prefix = f"up_blocks.{i}.upsamplers.0" old_prefix = f"output_blocks.{current_layer-1}.2" # print(new_prefix) # print(old_prefix) new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) - new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] @@ -173,4 +183,4 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): # print(image_unet) # exit() image_unet.load_state_dict(converted_unet_ckpt) - image_unet.save_pretrained(args.dump_path) \ No newline at end of file + image_unet.save_pretrained(args.dump_path) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f94f178f11e7..ff7b187ca40a 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -860,7 +860,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - down=True + down=True, ) ] ) @@ -884,7 +884,6 @@ def forward(self, hidden_states, temb=None, upsample_size=None): return hidden_states, output_states - class CrossAttnDownBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py index 52cfd0ba939b..fd78ddb3aae2 100644 --- a/src/diffusers/pipelines/consistency_models/__init__.py +++ b/src/diffusers/pipelines/consistency_models/__init__.py @@ -1 +1 @@ -from .pipeline_consistency_models import ConsistencyModelPipeline \ No newline at end of file +from .pipeline_consistency_models import ConsistencyModelPipeline diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 2a4ffcd8394b..53e90f97b226 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Tuple, Union, Callable +from typing import Callable, List, Optional, Union import torch @@ -13,9 +13,7 @@ def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] @@ -23,6 +21,7 @@ class ConsistencyModelPipeline(DiffusionPipeline): r""" Sampling pipeline for consistency models. """ + def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> None: super().__init__() @@ -30,7 +29,7 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> N unet=unet, scheduler=scheduler, ) - + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): @@ -43,7 +42,7 @@ def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta - + accepts_sigma_min = "sigma_min" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_sigma_min: # Assume accepting sigma_min always means scheduler also accepts sigma_max @@ -55,27 +54,21 @@ def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs - + def get_scalings(self, sigma, sigma_data: float = 0.5): c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out, c_in - + def get_scalings_for_boundary_condition(self, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5): # sigma_min should be in original sigma space, not in karras sigma space # (e.g. not exponentiated by 1 / rho) - c_skip = sigma_data**2 / ( - (sigma - sigma_min) ** 2 + sigma_data**2 - ) - c_out = ( - (sigma - sigma_min) - * sigma_data - / (sigma**2 + sigma_data**2) ** 0.5 - ) + c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) + c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out, c_in - + def denoise(self, x_t, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5, clip_denoised=True): """ Run the consistency model forward...? @@ -92,11 +85,11 @@ def denoise(self, x_t, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5, if clip_denoised: denoised = denoised.clamp(-1, 1) return model_output, denoised - + def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / append_dims(sigma, x.ndim) - + @torch.no_grad() def __call__( self, @@ -149,7 +142,7 @@ def __call__( # 3. Get sigma schedule assert hasattr(self.scheduler, "sigmas"), "Scheduler needs to operate in sigma space" sigmas = self.scheduler.sigmas - + # 4. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, sigma_min, sigma_max) @@ -185,17 +178,17 @@ def __call__( progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, sample) - + # 6. Post-process image sample sample = (sample / 2 + 0.5).clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": sample = self.numpy_to_pil(sample) - + if not return_dict: return (sample,) - + # TODO: Offload to cpu? return ImagePipelineOutput(images=sample) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 88b0785ccfb5..b261ddc32d49 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -35,6 +35,7 @@ def append_zero(x): class CMStochasticIterativeSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. + Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the @@ -55,9 +56,9 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. + (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the + cumulative product of (1-beta) up to that part of the diffusion process. + Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to @@ -82,7 +83,8 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. - [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" https://arxiv.org/pdf/2303.01469 + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. @@ -145,7 +147,7 @@ def __init__( self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -164,7 +166,7 @@ def __init__( self.rho = rho self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas - + def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -176,6 +178,7 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. + Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep @@ -187,6 +190,7 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. @@ -220,7 +224,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) else: self.timesteps = torch.from_numpy(timesteps).to(device=device) - + # Copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma @@ -244,7 +248,7 @@ def _sigma_to_t(self, sigma, log_sigmas): t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t - + # Modified from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras # Use self.rho instead of hardcoded 7.0 for rho def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: @@ -259,7 +263,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - + # TODO: add_noise meant to be called during training (forward diffusion process) # TODO: need to check if this corresponds to noise added during training # TODO: may want multiple add_noise-type methods for CD, CT @@ -295,12 +299,11 @@ def add_noise_to_input( sigma: float, sigma_min: float = 0.002, sigma_max: float = 80.0, - generator: Optional[torch.Generator] = None + generator: Optional[torch.Generator] = None, ) -> Tuple[torch.FloatTensor, float]: """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a - higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. - TODO Args: + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. TODO Args: """ step_idx = (self.sigmas == sigma).nonzero().item() sigma_hat = self.sigmas[step_idx + 1].clamp(min=sigma_min, max=sigma_max) @@ -312,7 +315,7 @@ def add_noise_to_input( sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z) return sample_hat, sigma_hat - + def step( self, model_output: torch.FloatTensor, @@ -327,6 +330,7 @@ def step( """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). + Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`float`): current timestep in the diffusion chain. @@ -401,6 +405,6 @@ def step( return (prev_sample,) return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) - + def __len__(self): return self.config.num_train_timesteps diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index f84778bff75c..daf83f71eaea 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -1,10 +1,8 @@ import gc -import random import unittest import numpy as np import torch -from PIL import Image from diffusers import ( CMStochasticIterativeScheduler, @@ -15,15 +13,14 @@ KDPM2DiscreteScheduler, UNet2DModel, ) -from diffusers.utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils import slow from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -class ConsistencyModelPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase -): + +class ConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS @@ -58,13 +55,13 @@ def get_dummy_components(self): } return components - + def get_dummy_inputs(self, device, seed=0): if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - + inputs = { "num_inference_steps": 2, "clip_denoised": True, @@ -76,7 +73,7 @@ def get_dummy_inputs(self, device, seed=0): } return inputs - + def test_consistency_model_pipeline_multistep(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -93,7 +90,7 @@ def test_consistency_model_pipeline_multistep(self): expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_onestep(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -111,14 +108,12 @@ def test_consistency_model_pipeline_onestep(self): expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_k_dpm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet # TODO: get reasonable args for KDPM2DiscreteScheduler - scheduler = KDPM2DiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="linear" - ) + scheduler = KDPM2DiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -132,14 +127,12 @@ def test_consistency_model_pipeline_k_dpm(self): expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_k_euler(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet # TODO: get reasonable args for EulerDiscreteScheduler - scheduler = EulerDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="linear" - ) + scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -153,14 +146,12 @@ def test_consistency_model_pipeline_k_euler(self): expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_k_euler_ancestral(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet # TODO: get reasonable args for EulerAncestralDiscreteScheduler - scheduler = EulerAncestralDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="linear" - ) + scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -174,14 +165,12 @@ def test_consistency_model_pipeline_k_euler_ancestral(self): expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_k_heun(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet # TODO: get reasonable args for HeunDiscreteScheduler - scheduler = HeunDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="linear" - ) + scheduler = HeunDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -203,4 +192,4 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase): def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() \ No newline at end of file + torch.cuda.empty_cache() From 067a9efd5476b679a9a05def3738e83eeee03eda Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Fri, 26 May 2023 16:47:05 +0530 Subject: [PATCH 10/72] Add initial training script --- examples/consistency_models/requirements.txt | 3 + .../train_consistency_distillation.py | 694 ++++++++++++++++++ 2 files changed, 697 insertions(+) create mode 100644 examples/consistency_models/requirements.txt create mode 100644 examples/consistency_models/train_consistency_distillation.py diff --git a/examples/consistency_models/requirements.txt b/examples/consistency_models/requirements.txt new file mode 100644 index 000000000000..f366720afd11 --- /dev/null +++ b/examples/consistency_models/requirements.txt @@ -0,0 +1,3 @@ +accelerate>=0.16.0 +torchvision +datasets diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py new file mode 100644 index 000000000000..a14d652c208b --- /dev/null +++ b/examples/consistency_models/train_consistency_distillation.py @@ -0,0 +1,694 @@ +import argparse +import inspect +import logging +import math +import os +from pathlib import Path +from typing import Optional + +import accelerate +import datasets +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from datasets import load_dataset +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +#Copied from examples/unconditional_image_generation/train_unconditional.py for now + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + if not isinstance(arr, torch.Tensor): + arr = torch.from_numpy(arr) + res = arr[timesteps].float().to(timesteps.device) + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that HF Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the UNet model to train, leave as None to use standard DDPM configuration.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="ddpm-model-64", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--resolution", + type=int, + default=64, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + default=False, + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.") + parser.add_argument( + "--save_model_epochs", type=int, default=10, help="How often to save the model during training." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer." + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.") + parser.add_argument( + "--use_ema", + action="store_true", + help="Whether to use Exponential Moving Average for the final model weights.", + ) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") + parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") + parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--hub_private_repo", action="store_true", help="Whether or not to create a private repository." + ) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + help=( + "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)" + " for experiment tracking and logging of model metrics and model checkpoints" + ), + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--prediction_type", + type=str, + default="epsilon", + choices=["epsilon", "sample"], + help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", + ) + parser.add_argument("--ddpm_num_steps", type=int, default=1000) + parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000) + parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.logger, + logging_dir=logging_dir, + project_config=accelerator_project_config, + ) + + if args.logger == "tensorboard": + if not is_tensorboard_available(): + raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") + + elif args.logger == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Initialize the model + if args.model_config_name_or_path is None: + model = UNet2DModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + else: + config = UNet2DModel.load_config(args.model_config_name_or_path) + model = UNet2DModel.from_config(config) + + # Create EMA for the model. + if args.use_ema: + ema_model = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + model_cls=UNet2DModel, + model_config=model.config, + ) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + model.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Initialize the scheduler + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + if accepts_prediction_type: + noise_scheduler = DDPMScheduler( + num_train_timesteps=args.ddpm_num_steps, + beta_schedule=args.ddpm_beta_schedule, + prediction_type=args.prediction_type, + ) + else: + noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) + + # Initialize the optimizer + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + split="train", + ) + else: + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets and DataLoaders creation. + augmentations = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def transform_images(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + return {"input": images} + + logger.info(f"Dataset size: {len(dataset)}") + + dataset.set_transform(transform_images) + train_dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + + # Initialize the learning rate scheduler + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs), + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_model.to(accelerator.device) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + max_train_steps = args.num_epochs * num_update_steps_per_epoch + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Train! + for epoch in range(first_epoch, args.num_epochs): + model.train() + progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + clean_images = batch["input"] + # Sample noise that we'll add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bsz = clean_images.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device + ).long() + + # Add noise to the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) + + with accelerator.accumulate(model): + # Predict the noise residual + model_output = model(noisy_images, timesteps).sample + + if args.prediction_type == "epsilon": + loss = F.mse_loss(model_output, noise) # this could have different weights! + elif args.prediction_type == "sample": + alpha_t = _extract_into_tensor( + noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) + ) + snr_weights = alpha_t / (1 - alpha_t) + loss = snr_weights * F.mse_loss( + model_output, clean_images, reduction="none" + ) # use SNR weighting from distillation paper + loss = loss.mean() + else: + raise ValueError(f"Unsupported prediction type: {args.prediction_type}") + + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_model.step(model.parameters()) + progress_bar.update(1) + global_step += 1 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + if args.use_ema: + logs["ema_decay"] = ema_model.cur_decay_value + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # Generate sample images for visual inspection + if accelerator.is_main_process: + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + unet = accelerator.unwrap_model(model) + + if args.use_ema: + ema_model.store(unet.parameters()) + ema_model.copy_to(unet.parameters()) + + pipeline = DDPMPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + + generator = torch.Generator(device=pipeline.device).manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline( + generator=generator, + batch_size=args.eval_batch_size, + num_inference_steps=args.ddpm_num_inference_steps, + output_type="numpy", + ).images + + if args.use_ema: + ema_model.restore(unet.parameters()) + + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + + if args.logger == "tensorboard": + if is_accelerate_version(">=", "0.17.0.dev0"): + tracker = accelerator.get_tracker("tensorboard", unwrap=True) + else: + tracker = accelerator.get_tracker("tensorboard") + tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch) + elif args.logger == "wandb": + # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files + accelerator.get_tracker("wandb").log( + {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, + step=global_step, + ) + + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + unet = accelerator.unwrap_model(model) + + if args.use_ema: + ema_model.store(unet.parameters()) + ema_model.copy_to(unet.parameters()) + + pipeline = DDPMPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + + pipeline.save_pretrained(args.output_dir) + + if args.use_ema: + ema_model.restore(unet.parameters()) + + if args.push_to_hub: + repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 5a27c2f467555879c80ba70cf63d5a3bf05b25a7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 18:44:38 -0700 Subject: [PATCH 11/72] Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints. --- tests/pipelines/consistency_models/test_consistency_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index daf83f71eaea..eca99ec11c89 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -35,6 +35,8 @@ def dummy_uncond_unet(self): out_channels=3, down_block_types=("ResnetDownsampleBlock2D", "AttnDownsampleBlock2D"), up_block_types=("AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"), + num_class_embeds=10, + resnet_time_scale_shift="scale_shift", ) return model From f2783a87f5322c14878ef0fa1cc6d980830f2139 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 18:45:54 -0700 Subject: [PATCH 12/72] Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script. --- scripts/convert_consistency_to_diffusers.py | 58 +++++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index bc96918554d2..78e998e1c577 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -4,6 +4,24 @@ from diffusers.models.unet_2d import UNet2DModel +TEST_UNET_CONFIG = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": 10, + "block_out_channels": [32, 64], + "attention_head_dim": 8, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownsampleBlock2D", + ], + "up_block_types": [ + "AttnUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", +} UNET_CONFIG = { "sample_size": 64, @@ -19,7 +37,12 @@ "AttnDownsampleBlock2D", "AttnDownsampleBlock2D", ], - "up_block_types": ["AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"], + "up_block_types": [ + "AttnUpsampleBlock2D", + "AttnUpsampleBlock2D", + "AttnUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], "resnet_time_scale_shift": "scale_shift", } @@ -77,7 +100,7 @@ def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attent return new_checkpoint -def con_pt_to_diffuser(checkpoint_path: str, output_path: str): +def con_pt_to_diffuser(checkpoint_path: str, unet_config): checkpoint = torch.load(checkpoint_path, map_location="cpu") new_checkpoint = {} @@ -86,14 +109,15 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] - new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] + if unet_config["num_class_embeds"] is not None: + new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] - down_block_types = UNET_CONFIG["down_block_types"] - layers_per_block = UNET_CONFIG["layers_per_block"] - attention_head_dim = UNET_CONFIG["attention_head_dim"] + down_block_types = unet_config["down_block_types"] + layers_per_block = unet_config["layers_per_block"] + attention_head_dim = unet_config["attention_head_dim"] current_layer = 1 for i, layer_type in enumerate(down_block_types): @@ -135,7 +159,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) current_layer = 0 - up_block_types = UNET_CONFIG["up_block_types"] + up_block_types = unet_config["up_block_types"] for i, layer_type in enumerate(up_block_types): if layer_type == "ResnetUpsampleBlock2D": @@ -174,12 +198,26 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): parser = argparse.ArgumentParser() parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model.") + parser.add_argument("--checkpoint_name", default="cd_imagenet64_l2", type=str, help="Checkpoint to convert.") args = parser.parse_args() - converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, args.dump_path) - image_unet = UNet2DModel(**UNET_CONFIG) + if args.checkpoint_name == "cd_imagenet64_l2": + unet_config = UNET_CONFIG + elif args.checkpoint_name == "test": + unet_config = TEST_UNET_CONFIG + unet_config["num_class_embeds"] = None + elif args.checkpoint_name == "test_class_cond": + unet_config = TEST_UNET_CONFIG + else: + raise ValueError( + f"Checkpoint type {args.checkpoint_name} is not currently supported." + ) + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) + + image_unet = UNet2DModel(**unet_config) # print(image_unet) # exit() image_unet.load_state_dict(converted_unet_ckpt) From ed53b8592a8dc2dd3f0bbf8a28d134a6fee32e6e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 18:46:48 -0700 Subject: [PATCH 13/72] make style --- scripts/convert_consistency_to_diffusers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 78e998e1c577..0a64eb981a09 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -4,6 +4,7 @@ from diffusers.models.unet_2d import UNet2DModel + TEST_UNET_CONFIG = { "sample_size": 32, "in_channels": 3, @@ -198,7 +199,9 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): parser = argparse.ArgumentParser() parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model.") + parser.add_argument( + "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." + ) parser.add_argument("--checkpoint_name", default="cd_imagenet64_l2", type=str, help="Checkpoint to convert.") args = parser.parse_args() @@ -211,9 +214,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): elif args.checkpoint_name == "test_class_cond": unet_config = TEST_UNET_CONFIG else: - raise ValueError( - f"Checkpoint type {args.checkpoint_name} is not currently supported." - ) + raise ValueError(f"Checkpoint type {args.checkpoint_name} is not currently supported.") converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) From a505c6ccd442748979cbe1064c1e22f7812ff6ad Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 20:06:46 -0700 Subject: [PATCH 14/72] Change num_class_embeds to 1000 to better match the original consistency models implementation. --- scripts/convert_consistency_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 0a64eb981a09..dda6a6790ed2 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -10,7 +10,7 @@ "in_channels": 3, "out_channels": 3, "layers_per_block": 2, - "num_class_embeds": 10, + "num_class_embeds": 1000, "block_out_channels": [32, 64], "attention_head_dim": 8, "down_block_types": [ From a927a4af16011d4cd0b1a903129b9f5c4b1774d1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 20:12:02 -0700 Subject: [PATCH 15/72] Add support for distillation in pipeline_consistency_models.py. --- .../pipeline_consistency_models.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 53e90f97b226..774fc8ac29a6 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -22,7 +22,7 @@ class ConsistencyModelPipeline(DiffusionPipeline): Sampling pipeline for consistency models. """ - def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> None: + def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, distillation: bool = False) -> None: super().__init__() self.register_modules( @@ -30,6 +30,8 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> N scheduler=scheduler, ) + self.distillation = distillation + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): @@ -69,16 +71,28 @@ def get_scalings_for_boundary_condition(self, sigma, sigma_min: float = 0.002, s c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out, c_in - def denoise(self, x_t, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5, clip_denoised=True): + def denoise( + self, + x_t, + sigma, + sigma_min: float = 0.002, + sigma_data: float = 0.5, + clip_denoised=True, + ): """ Run the consistency model forward...? """ # sigma_min should be in original sigma space, not in karras sigma space # (e.g. not exponentiated by 1 / rho) - c_skip, c_out, c_in = [ - append_dims(x, x_t.ndim) - for x in self.get_scalings_for_boundary_condition(sigma, sigma_min=sigma_min, sigma_data=sigma_data) - ] + if self.distillation: + c_skip, c_out, c_in = [ + append_dims(x, x_t.ndim) + for x in self.get_scalings_for_boundary_condition(sigma, sigma_min=sigma_min, sigma_data=sigma_data) + ] + else: + c_skip, c_out, c_in = [ + append_dims(x, x_t.ndim) for x in self.get_scalings(sigma, sigma_data=sigma_data) + ] rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44) model_output = self.unet(c_in * x_t, rescaled_t).sample denoised = c_out * model_output + c_skip * x_t From b2e64243e8401cdba359ba580cf8bb849fbe45b1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 20:12:37 -0700 Subject: [PATCH 16/72] Improve consistency model tests: - Get small testing checkpoints from hub - Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline - Add onestep, multistep tests for distillation and distillation + class conditional - Add expected image slices for onestep tests --- .../test_consistency_models.py | 110 +++++++++++++++--- 1 file changed, 93 insertions(+), 17 deletions(-) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index eca99ec11c89..6c97120fd63f 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -26,27 +26,30 @@ class ConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTeste @property def dummy_uncond_unet(self): - torch.manual_seed(0) - model = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=3, - out_channels=3, - down_block_types=("ResnetDownsampleBlock2D", "AttnDownsampleBlock2D"), - up_block_types=("AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"), - num_class_embeds=10, - resnet_time_scale_shift="scale_shift", + unet = UNet2DModel.from_pretrained( + "dg845/consistency-models-test", + subfolder="test_unet", ) - return model + return unet + + @property + def dummy_cond_unet(self): + unet = UNet2DModel.from_pretrained( + "dg845/consistency-models-test", + subfolder="test_unet_class_cond", + ) + return unet - def get_dummy_components(self): - unet = self.dummy_uncond_unet + def get_dummy_components(self, class_cond=False): + if class_cond: + unet = self.dummy_cond_unet + else: + unet = self.dummy_uncond_unet # Default to CM multistep sampler # TODO: need to determine most sensible settings for these args scheduler = CMStochasticIterativeScheduler( - num_train_timesteps=1000, + num_train_timesteps=40, beta_start=0.0001, beta_end=0.02, ) @@ -54,6 +57,7 @@ def get_dummy_components(self): components = { "unet": unet, "scheduler": scheduler, + "distillation": False, } return components @@ -93,15 +97,33 @@ def test_consistency_model_pipeline_multistep(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_onestep(self): + def test_consistency_model_pipeline_multistep_distillation(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() + components["distillation"] = True + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + # TODO: get correct expected_slice + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_multistep_class_cond_distillation(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(class_cond=True) + components["distillation"] = True pipe = ConsistencyModelPipeline(**components) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 1 image = pipe(**inputs).images assert image.shape == (1, 32, 32, 3) @@ -111,6 +133,60 @@ def test_consistency_model_pipeline_onestep(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_consistency_model_pipeline_onestep(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep_distillation(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["distillation"] = True + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + print(f"Image slice: {image_slice.flatten()}") + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_consistency_model_pipeline_onestep_class_cond_distillation(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(class_cond=True) + components["distillation"] = True + pipe = ConsistencyModelPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + image = pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_consistency_model_pipeline_k_dpm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet From c37e302d50b59592416f9bb8cea092a235bf6916 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 20:16:04 -0700 Subject: [PATCH 17/72] make style --- .../consistency_models/pipeline_consistency_models.py | 4 +--- .../pipelines/consistency_models/test_consistency_models.py | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 774fc8ac29a6..385c870b2aa9 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -90,9 +90,7 @@ def denoise( for x in self.get_scalings_for_boundary_condition(sigma, sigma_min=sigma_min, sigma_data=sigma_data) ] else: - c_skip, c_out, c_in = [ - append_dims(x, x_t.ndim) for x in self.get_scalings(sigma, sigma_data=sigma_data) - ] + c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigma, sigma_data=sigma_data)] rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44) model_output = self.unet(c_in * x_t, rescaled_t).sample denoised = c_out * model_output + c_skip * x_t diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 6c97120fd63f..360672de568e 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -31,7 +31,7 @@ def dummy_uncond_unet(self): subfolder="test_unet", ) return unet - + @property def dummy_cond_unet(self): unet = UNet2DModel.from_pretrained( @@ -149,7 +149,7 @@ def test_consistency_model_pipeline_onestep(self): expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_onestep_distillation(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -168,7 +168,7 @@ def test_consistency_model_pipeline_onestep_distillation(self): expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_onestep_class_cond_distillation(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components(class_cond=True) From a0a164cfb0b67ba125b3cd8c9bb0345247b4e15d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 21:29:28 -0700 Subject: [PATCH 18/72] Improve ConsistencyModelPipeline: - Add initial support for class-conditional generation - Fix initial sigma for onestep generation - Fix some sigma shape issues --- .../pipeline_consistency_models.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 385c870b2aa9..fcccc4dc0f3d 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -31,6 +31,7 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, dist ) self.distillation = distillation + self.num_classes = unet.config.num_class_embeds # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler @@ -75,6 +76,7 @@ def denoise( self, x_t, sigma, + class_labels=None, sigma_min: float = 0.002, sigma_data: float = 0.5, clip_denoised=True, @@ -92,7 +94,7 @@ def denoise( else: c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigma, sigma_data=sigma_data)] rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44) - model_output = self.unet(c_in * x_t, rescaled_t).sample + model_output = self.unet(c_in * x_t, rescaled_t, class_labels=class_labels).sample denoised = c_out * model_output + c_skip * x_t if clip_denoised: denoised = denoised.clamp(-1, 1) @@ -106,6 +108,7 @@ def to_d(x, sigma, denoised): def __call__( self, batch_size: int = 1, + class_labels: Optional[Union[torch.IntTensor, List[int], int]] = None, num_inference_steps: int = 40, clip_denoised: bool = True, sigma_min: float = 0.002, @@ -144,27 +147,40 @@ def __call__( # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma - # 2. Set timesteps + # 2. Handle class_labels for class-conditional models + if self.num_classes is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + class_labels = torch.randint(0, self.num_classes, size=batch_size) + class_labels.to(device) + + # 3. Set timesteps self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps # Now get Karras sigma schedule (which I think the original implementation always uses) # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376 # TODO: how do we ensure that this in Karras sigma space rather than in "original" sigma space? - # 3. Get sigma schedule + # 4. Get sigma schedule assert hasattr(self.scheduler, "sigmas"), "Scheduler needs to operate in sigma space" sigmas = self.scheduler.sigmas - # 4. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, sigma_min, sigma_max) - # 5. Denoising loop + # 6. Denoising loop if num_inference_steps == 1: # Onestep sampling: simply evaluate the consistency model at the first sigma # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L643 - sigma = sigmas[0] + sigma = sigma_max + sigma_in = sample.new_ones([sample.shape[0]]) * sigma _, sample = self.denoise( - sample, sigma, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + sample, sigma_in, class_labels=class_labels, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised ) else: # Multistep sampling or Karras sampler @@ -173,11 +189,11 @@ def __call__( # Don't loop over last timestep for i, t in enumerate(timesteps[:-1]): sigma = sigmas[i] - # TODO: handle class labels + sigma_in = sample.new_ones([sample.shape[0]]) * sigma # TODO: check shapes, might need equivalent of s_in in original code # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510 model_output, denoised = self.denoise( - sample, sigma, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + sample, sigma_in, class_labels=class_labels, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised ) # Works for both Karras-style schedulers (e.g. Euler, Heun) and the CM multistep scheduler @@ -191,7 +207,7 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, sample) - # 6. Post-process image sample + # 7. Post-process image sample sample = (sample / 2 + 0.5).clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() From 0d1de0893d871d2c4a6662ebae03a9318e053ddf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 21:31:39 -0700 Subject: [PATCH 19/72] make style --- .../pipeline_consistency_models.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index fcccc4dc0f3d..7398b7dbe54e 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -180,7 +180,12 @@ def __call__( sigma = sigma_max sigma_in = sample.new_ones([sample.shape[0]]) * sigma _, sample = self.denoise( - sample, sigma_in, class_labels=class_labels, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + sample, + sigma_in, + class_labels=class_labels, + sigma_min=sigma_min, + sigma_data=sigma_data, + clip_denoised=clip_denoised, ) else: # Multistep sampling or Karras sampler @@ -193,7 +198,12 @@ def __call__( # TODO: check shapes, might need equivalent of s_in in original code # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510 model_output, denoised = self.denoise( - sample, sigma_in, class_labels=class_labels, sigma_min=sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised + sample, + sigma_in, + class_labels=class_labels, + sigma_min=sigma_min, + sigma_data=sigma_data, + clip_denoised=clip_denoised, ) # Works for both Karras-style schedulers (e.g. Euler, Heun) and the CM multistep scheduler From 5f4f4064fc47660fc98e2f772271333c8793fe56 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 22:23:28 -0700 Subject: [PATCH 20/72] Improve ConsistencyModelPipeline: - add latents __call__ argument and prepare_latents method - add check_inputs method - add initial docstrings for ConsistencyModelPipeline.__call__ --- .../pipeline_consistency_models.py | 93 +++++++++++++++++-- 1 file changed, 83 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 7398b7dbe54e..13dc0d0ded50 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -57,6 +57,25 @@ def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Unlike stable diffusion, no VAE so no vae_scale_factor, num_channels_latent => num_channels + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents def get_scalings(self, sigma, sigma_data: float = 0.5): c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -103,12 +122,28 @@ def denoise( def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / append_dims(sigma, x.ndim) + + def check_inputs(self, latents, batch_size, img_size, callback_steps): + if latents is not None: + expected_shape = (batch_size, 3, img_size, img_size) + if latents.shape != expected_shape: + raise ValueError( + f"The shape of latents is {latents.shape} but is expected to be {expected_shape}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) @torch.no_grad() def __call__( self, batch_size: int = 1, - class_labels: Optional[Union[torch.IntTensor, List[int], int]] = None, + class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, num_inference_steps: int = 40, clip_denoised: bool = True, sigma_min: float = 0.002, @@ -116,6 +151,7 @@ def __call__( sigma_data: float = 0.5, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -125,29 +161,66 @@ def __call__( Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. + class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): + Optional class labels for conditioning class-conditional consistency models. Will not be used if the + model is not class-conditional. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + clip_denoised (`bool`, *optional*, defaults to `True`): + Whether to clip the consistency model denoising output to `(0, 1)`. + sigma_min (`float`, *optional*, defaults to 0.002): + The minimum (and last) value in the sigma noise schedule. + sigma_max (`float`, *optional*, defaults to 80.0): + The maximum (and first) value in the sigma noise schedule. + sigma_data (`float`, *optional*, defaults to 0.5): + TODO eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + # 0. Prepare call parameters img_size = img_size = self.unet.config.sample_size - shape = (batch_size, 3, img_size, img_size) device = self.device - # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I) - sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma + # 1. Check inputs + self.check_inputs(latents, batch_size, img_size, callback_steps) + + # 2. Prepare image latents + # Sample image latents x_0 ~ N(0, sigma_0^2 * I) + sample = self.prepare_latents( + batch_size=batch_size, + num_channels=3, + height=img_size, + width=img_size, + dtype=self.unet.dtype, + device=device, + generator=generator, + latents=latents, + ) - # 2. Handle class_labels for class-conditional models + # 3. Handle class_labels for class-conditional models if self.num_classes is not None: if isinstance(class_labels, list): class_labels = torch.tensor(class_labels, dtype=torch.int) @@ -159,21 +232,21 @@ def __call__( class_labels = torch.randint(0, self.num_classes, size=batch_size) class_labels.to(device) - # 3. Set timesteps + # 4. Set timesteps self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps # Now get Karras sigma schedule (which I think the original implementation always uses) # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376 # TODO: how do we ensure that this in Karras sigma space rather than in "original" sigma space? - # 4. Get sigma schedule + # 5. Get sigma schedule assert hasattr(self.scheduler, "sigmas"), "Scheduler needs to operate in sigma space" sigmas = self.scheduler.sigmas - # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, sigma_min, sigma_max) - # 6. Denoising loop + # 7. Denoising loop if num_inference_steps == 1: # Onestep sampling: simply evaluate the consistency model at the first sigma # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L643 @@ -217,7 +290,7 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, sample) - # 7. Post-process image sample + # 8. Post-process image sample sample = (sample / 2 + 0.5).clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() From 213b25de786fb80bd9f00f60e74525fdec26c67a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 22:28:29 -0700 Subject: [PATCH 21/72] make style --- .../consistency_models/pipeline_consistency_models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 13dc0d0ded50..3f0cd090a91b 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -57,7 +57,7 @@ def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs - + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Unlike stable diffusion, no VAE so no vae_scale_factor, num_channels_latent => num_channels def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): @@ -122,14 +122,12 @@ def denoise( def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / append_dims(sigma, x.ndim) - + def check_inputs(self, latents, batch_size, img_size, callback_steps): if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) if latents.shape != expected_shape: - raise ValueError( - f"The shape of latents is {latents.shape} but is expected to be {expected_shape}." - ) + raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) From fbe34c3fc6463fafd8283607c2f69c3c453977f0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 May 2023 23:03:25 -0700 Subject: [PATCH 22/72] Fix bug when randomly generating class labels for class-conditional generation. --- .../pipelines/consistency_models/pipeline_consistency_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 3f0cd090a91b..0b36af183c0a 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -227,7 +227,7 @@ def __call__( class_labels = torch.tensor([class_labels], dtype=torch.int) elif class_labels is None: # Randomly generate batch_size class labels - class_labels = torch.randint(0, self.num_classes, size=batch_size) + class_labels = torch.randint(0, self.num_classes, size=(batch_size,)) class_labels.to(device) # 4. Set timesteps From 0e53d8bb6c6e70577154a5fea7a8419ebc516430 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 1 Jun 2023 03:52:08 -0700 Subject: [PATCH 23/72] Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests. --- .../pipeline_consistency_models.py | 4 + .../scheduling_consistency_models.py | 134 ++++-------------- .../test_consistency_models.py | 6 +- 3 files changed, 37 insertions(+), 107 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 0b36af183c0a..de6a2d01c1f2 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -239,6 +239,10 @@ def __call__( # TODO: how do we ensure that this in Karras sigma space rather than in "original" sigma space? # 5. Get sigma schedule assert hasattr(self.scheduler, "sigmas"), "Scheduler needs to operate in sigma space" + if hasattr(self.scheduler, "sigma_min"): + # Overwrite sigma_min with sigma_min from the scheduler + sigma_min = self.scheduler.sigma_min + sigma_max = self.scheduler.sigma_max sigmas = self.scheduler.sigmas # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index b261ddc32d49..42bd5d258253 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -122,50 +122,20 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + num_train_timesteps: int = 40, + sigma_min: float = 0.002, + sigma_max: float = 80.0, sigma_data: float = 0.5, rho: float = 7.0, - prediction_type: str = "sample", - interpolation_type: str = "linear", - use_karras_sigmas: Optional[bool] = True, + timesteps: Optional[np.ndarray] = None, ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() + self.init_noise_sigma = sigma_max # setable values self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.sigma_data = sigma_data - self.rho = rho + self.timesteps = timesteps self.is_scale_input_called = False - self.use_karras_sigmas = use_karras_sigmas def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: @@ -198,24 +168,17 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps - - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - - if self.config.interpolation_type == "linear": - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - elif self.config.interpolation_type == "log_linear": - sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() + + # TODO: timesteps should be increasing rather than decreasing?? + if self.timesteps is None: + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) else: - raise ValueError( - f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" - " 'linear' or 'log_linear'" - ) - - if self.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + timesteps = self.timesteps + + # Map timesteps to Karras sigmas directly for multistep sampling + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 + ramp = timesteps / (self.config.num_train_timesteps - 1) + sigmas = self._convert_to_karras(ramp) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) @@ -225,40 +188,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic else: self.timesteps = torch.from_numpy(timesteps).to(device=device) - # Copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(sigma) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - # Modified from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras - # Use self.rho instead of hardcoded 7.0 for rho - def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + # Use self.rho instead of hardcoded 7.0 for rho, sigma_min/max from config, configurable ramp function + def _convert_to_karras(self, ramp): """Constructs the noise schedule of Karras et al. (2022).""" - sigma_min: float = in_sigmas[-1].item() - sigma_max: float = in_sigmas[0].item() + sigma_min: float = self.config.sigma_min + sigma_max: float = self.config.sigma_max rho = self.rho - ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho @@ -297,14 +235,14 @@ def add_noise_to_input( self, sample: torch.FloatTensor, sigma: float, - sigma_min: float = 0.002, - sigma_max: float = 80.0, generator: Optional[torch.Generator] = None, ) -> Tuple[torch.FloatTensor, float]: """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. TODO Args: """ + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max step_idx = (self.sigmas == sigma).nonzero().item() sigma_hat = self.sigmas[step_idx + 1].clamp(min=sigma_min, max=sigma_max) @@ -321,8 +259,6 @@ def step( model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, - sigma_min: float = 0.002, - sigma_max: float = 80.0, s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, @@ -369,12 +305,17 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) + + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max - step_index = (self.timesteps == timestep).nonzero().item() + step_index = self.index_for_timestep(timestep) + + # sigma corresponds to t and sigma_next corresponds to next_t in original implementation sigma = self.sigmas[step_index] sigma_next = self.sigmas[step_index + 1] - # sample z ~ N(0, s_noise^2 * I) + # 1. Sample z ~ N(0, s_noise^2 * I) noise = randn_tensor( model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator ) @@ -382,24 +323,9 @@ def step( sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma_hat * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - # 2. Return noisy sample # tau = sigma_hat, eps = sigma_min - prev_sample = pred_original_sample + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + prev_sample = model_output + z * (sigma_hat**2 - sigma_min**2) ** 0.5 if not return_dict: return (prev_sample,) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 360672de568e..a2cce3a283b0 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -50,8 +50,9 @@ def get_dummy_components(self, class_cond=False): # TODO: need to determine most sensible settings for these args scheduler = CMStochasticIterativeScheduler( num_train_timesteps=40, - beta_start=0.0001, - beta_end=0.02, + sigma_min=0.002, + sigma_max=80.0, + timesteps=np.array([22, 39]), ) components = { @@ -164,7 +165,6 @@ def test_consistency_model_pipeline_onestep_distillation(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - print(f"Image slice: {image_slice.flatten()}") expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From c3b242edf6d4ab6a87ef8646f6af1fba7a6970a6 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 1 Jun 2023 03:58:00 -0700 Subject: [PATCH 24/72] Remove some unused code and make style. --- .../scheduling_consistency_models.py | 42 ++++--------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 42bd5d258253..c7253ef9b391 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -52,32 +51,6 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput): # pred_original_sample: Optional[torch.FloatTensor] = None -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the - cumulative product of (1-beta) up to that part of the diffusion process. - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and @@ -168,13 +141,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps - + # TODO: timesteps should be increasing rather than decreasing?? if self.timesteps is None: timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) else: timesteps = self.timesteps - + # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 ramp = timesteps / (self.config.num_train_timesteps - 1) @@ -196,7 +169,7 @@ def _convert_to_karras(self, ramp): sigma_min: float = self.config.sigma_min sigma_max: float = self.config.sigma_max - rho = self.rho + rho = self.config.rho min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho @@ -305,14 +278,13 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - + sigma_min = self.config.sigma_min sigma_max = self.config.sigma_max step_index = self.index_for_timestep(timestep) - - # sigma corresponds to t and sigma_next corresponds to next_t in original implementation - sigma = self.sigmas[step_index] + + # sigma_next corresponds to next_t in original implementation sigma_next = self.sigmas[step_index + 1] # 1. Sample z ~ N(0, s_noise^2 * I) From 43e437927c7a10cfd01380dbd56cb3ee7d9aeacd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 1 Jun 2023 04:26:05 -0700 Subject: [PATCH 25/72] Fix small bug in CMStochasticIterativeScheduler. --- src/diffusers/schedulers/scheduling_consistency_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index c7253ef9b391..df08d0ff85e9 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -146,7 +146,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic if self.timesteps is None: timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) else: - timesteps = self.timesteps + timesteps = self.timesteps.astype(np.float32) # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 @@ -265,7 +265,7 @@ def step( raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + f" `{self.__class__}.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." ), ) From 94c99ca4888bb1966c2c0d023621daf4ecac7b73 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 1 Jun 2023 21:18:39 -0700 Subject: [PATCH 26/72] Add expected slices for multistep sampling tests and make them pass. --- .../consistency_models/pipeline_consistency_models.py | 3 +-- .../schedulers/scheduling_consistency_models.py | 6 ++++-- .../consistency_models/test_consistency_models.py | 11 ++++------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index de6a2d01c1f2..cc8a7bdfcd83 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -266,8 +266,7 @@ def __call__( # Multistep sampling or Karras sampler num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: - # Don't loop over last timestep - for i, t in enumerate(timesteps[:-1]): + for i, t in enumerate(timesteps): sigma = sigmas[i] sigma_in = sample.new_ones([sample.shape[0]]) * sigma # TODO: check shapes, might need equivalent of s_in in original code diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index df08d0ff85e9..2c3b1fb642c5 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -142,7 +142,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - # TODO: timesteps should be increasing rather than decreasing?? + # Note: timesteps are expected to be increasing rather than decreasing, following original implementation if self.timesteps is None: timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) else: @@ -150,7 +150,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 - ramp = timesteps / (self.config.num_train_timesteps - 1) + num_train_timesteps = self.config.num_train_timesteps + # Append num_train_timesteps - 1 so sigmas[-1] == sigma_min + ramp = np.append(timesteps, [num_train_timesteps - 1]) / (num_train_timesteps - 1) sigmas = self._convert_to_karras(ramp) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index a2cce3a283b0..8a6affb292fb 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -52,7 +52,7 @@ def get_dummy_components(self, class_cond=False): num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0, - timesteps=np.array([22, 39]), + timesteps=np.array([0, 22]), ) components = { @@ -93,8 +93,7 @@ def test_consistency_model_pipeline_multistep(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + expected_slice = np.array([0.3576, 0.6270, 0.4034, 0.3964, 0.4323, 0.5728, 0.5265, 0.4781, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -111,8 +110,7 @@ def test_consistency_model_pipeline_multistep_distillation(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -129,8 +127,7 @@ def test_consistency_model_pipeline_multistep_class_cond_distillation(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From 0773b271775ba03bf8f4c796ad8a1214fd273ef1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 1 Jun 2023 23:35:16 -0700 Subject: [PATCH 27/72] Work on consistency model fast tests: - in pipeline, call self.scheduler.scale_model_input before denoising - get expected slices for Euler and Heun scheduler tests - make Euler test pass - mark Heun test as expected fail because it doesn't support prediction_type "sample" yet - remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas --- .../pipeline_consistency_models.py | 5 +- .../test_consistency_models.py | 73 ++++++------------- 2 files changed, 26 insertions(+), 52 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index cc8a7bdfcd83..76b28aee59c6 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -269,8 +269,9 @@ def __call__( for i, t in enumerate(timesteps): sigma = sigmas[i] sigma_in = sample.new_ones([sample.shape[0]]) * sigma - # TODO: check shapes, might need equivalent of s_in in original code - # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510 + + # TODO: should we call scale_model_input here? + sample = self.scheduler.scale_model_input(sample, t) model_output, denoised = self.denoise( sample, sigma_in, diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 8a6affb292fb..adc4114f6423 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -1,4 +1,5 @@ import gc +import pytest import unittest import numpy as np @@ -7,10 +8,8 @@ from diffusers import ( CMStochasticIterativeScheduler, ConsistencyModelPipeline, - EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, - KDPM2DiscreteScheduler, UNet2DModel, ) from diffusers.utils import slow @@ -20,7 +19,8 @@ from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -class ConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ConsistencyModelPipeline params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS @@ -184,49 +184,17 @@ def test_consistency_model_pipeline_onestep_class_cond_distillation(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_k_dpm(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_uncond_unet - # TODO: get reasonable args for KDPM2DiscreteScheduler - scheduler = KDPM2DiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - assert image.shape == (1, 32, 32, 3) - - image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_k_euler(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet - # TODO: get reasonable args for EulerDiscreteScheduler - scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - assert image.shape == (1, 32, 32, 3) - - image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - def test_consistency_model_pipeline_k_euler_ancestral(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_uncond_unet - # TODO: get reasonable args for EulerAncestralDiscreteScheduler - scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") + scheduler = EulerDiscreteScheduler( + num_train_timesteps=2, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + prediction_type="sample", + use_karras_sigmas=True, + ) pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -236,16 +204,22 @@ def test_consistency_model_pipeline_k_euler_ancestral(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + expected_slice = np.array([0.5157, 0.5143, 0.4804, 0.5273, 0.4146, 0.5619, 0.4651, 0.4359, 0.4540]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + + @pytest.mark.xfail(reason="Heun scheduler does not implement prediction_type 'sample' yet") def test_consistency_model_pipeline_k_heun(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_uncond_unet - # TODO: get reasonable args for HeunDiscreteScheduler - scheduler = HeunDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear") + scheduler = HeunDiscreteScheduler( + num_train_timesteps=2, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + prediction_type="sample", + use_karras_sigmas=True, + ) pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -255,8 +229,7 @@ def test_consistency_model_pipeline_k_heun(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - # TODO: get correct expected_slice - expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) + expected_slice = np.array([0.5159, 0.5145, 0.4801, 0.5277, 0.4134, 0.5628, 0.4646, 0.4350, 0.4533]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From 6867a3ad17970cd4f15d3a8d739500ec3d41a496 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 1 Jun 2023 23:39:14 -0700 Subject: [PATCH 28/72] make style --- .../pipelines/consistency_models/test_consistency_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index adc4114f6423..01473dc92140 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -1,8 +1,8 @@ import gc -import pytest import unittest import numpy as np +import pytest import torch from diffusers import ( @@ -16,7 +16,7 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineTesterMixin class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @@ -207,7 +207,7 @@ def test_consistency_model_pipeline_k_euler(self): expected_slice = np.array([0.5157, 0.5143, 0.4804, 0.5273, 0.4146, 0.5619, 0.4651, 0.4359, 0.4540]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + @pytest.mark.xfail(reason="Heun scheduler does not implement prediction_type 'sample' yet") def test_consistency_model_pipeline_k_heun(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator From f0c85d3e7b53f6ff921f91067bdf962e53a26498 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 20:32:38 -0700 Subject: [PATCH 29/72] Refactor conversion script to make it easier to add more model architectures to convert in the future. --- scripts/convert_consistency_to_diffusers.py | 30 ++++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index dda6a6790ed2..3cd6951e016d 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -24,7 +24,7 @@ "resnet_time_scale_shift": "scale_shift", } -UNET_CONFIG = { +IMAGENET_64_UNET_CONFIG = { "sample_size": 64, "in_channels": 3, "out_channels": 3, @@ -48,6 +48,20 @@ } +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") + + def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False): new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"] new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"] @@ -203,18 +217,20 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." ) parser.add_argument("--checkpoint_name", default="cd_imagenet64_l2", type=str, help="Checkpoint to convert.") + parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.") args = parser.parse_args() + args.class_cond = str2bool(args.class_cond) - if args.checkpoint_name == "cd_imagenet64_l2": - unet_config = UNET_CONFIG - elif args.checkpoint_name == "test": - unet_config = TEST_UNET_CONFIG - unet_config["num_class_embeds"] = None - elif args.checkpoint_name == "test_class_cond": + if "imagenet64" in args.checkpoint_name: + unet_config = IMAGENET_64_UNET_CONFIG + elif "test" in args.checkpoint_name: unet_config = TEST_UNET_CONFIG else: raise ValueError(f"Checkpoint type {args.checkpoint_name} is not currently supported.") + + if not args.class_cond: + unet_config["num_class_embeds"] = None converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) From 6adb58922164c64e0acd121f19befc807a2e61cd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 20:35:23 -0700 Subject: [PATCH 30/72] Work on ConsistencyModelPipeline tests: - Fix device bug when handling class labels in ConsistencyModelPipeline.__call__ - Add slow tests for onestep and multistep sampling and make them pass - Refactor fast tests - Refactor ConsistencyModelPipeline.__init__ --- .../pipeline_consistency_models.py | 9 +- .../test_consistency_models.py | 102 ++++++++++++++---- 2 files changed, 85 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 76b28aee59c6..9d868a8dc5e6 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -22,7 +22,7 @@ class ConsistencyModelPipeline(DiffusionPipeline): Sampling pipeline for consistency models. """ - def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, distillation: bool = False) -> None: + def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, distillation: bool = True) -> None: super().__init__() self.register_modules( @@ -31,7 +31,6 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, dist ) self.distillation = distillation - self.num_classes = unet.config.num_class_embeds # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler @@ -219,7 +218,7 @@ def __call__( ) # 3. Handle class_labels for class-conditional models - if self.num_classes is not None: + if self.unet.config.num_class_embeds is not None: if isinstance(class_labels, list): class_labels = torch.tensor(class_labels, dtype=torch.int) elif isinstance(class_labels, int): @@ -227,8 +226,8 @@ def __call__( class_labels = torch.tensor([class_labels], dtype=torch.int) elif class_labels is None: # Randomly generate batch_size class labels - class_labels = torch.randint(0, self.num_classes, size=(batch_size,)) - class_labels.to(device) + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) # 4. Set timesteps self.scheduler.set_timesteps(num_inference_steps) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 01473dc92140..ec837cb2d980 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -12,7 +12,7 @@ HeunDiscreteScheduler, UNet2DModel, ) -from diffusers.utils import slow +from diffusers.utils import slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS @@ -47,7 +47,6 @@ def get_dummy_components(self, class_cond=False): unet = self.dummy_uncond_unet # Default to CM multistep sampler - # TODO: need to determine most sensible settings for these args scheduler = CMStochasticIterativeScheduler( num_train_timesteps=40, sigma_min=0.002, @@ -58,7 +57,7 @@ def get_dummy_components(self, class_cond=False): components = { "unet": unet, "scheduler": scheduler, - "distillation": False, + "distillation": True, } return components @@ -93,14 +92,13 @@ def test_consistency_model_pipeline_multistep(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.3576, 0.6270, 0.4034, 0.3964, 0.4323, 0.5728, 0.5265, 0.4781, 0.5004]) + expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_multistep_distillation(self): + def test_consistency_model_pipeline_multistep_class_cond(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - components["distillation"] = True + components = self.get_dummy_components(class_cond=True) pipe = ConsistencyModelPipeline(**components) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -113,11 +111,11 @@ def test_consistency_model_pipeline_multistep_distillation(self): expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - def test_consistency_model_pipeline_multistep_class_cond_distillation(self): + + def test_consistency_model_pipeline_multistep_edm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components(class_cond=True) - components["distillation"] = True + components = self.get_dummy_components() + components["distillation"] = False pipe = ConsistencyModelPipeline(**components) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -127,7 +125,7 @@ def test_consistency_model_pipeline_multistep_class_cond_distillation(self): assert image.shape == (1, 32, 32, 3) image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) + expected_slice = np.array([0.3576, 0.6270, 0.4034, 0.3964, 0.4323, 0.5728, 0.5265, 0.4781, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -148,10 +146,9 @@ def test_consistency_model_pipeline_onestep(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_onestep_distillation(self): + def test_consistency_model_pipeline_onestep_class_cond(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - components["distillation"] = True + components = self.get_dummy_components(class_cond=True) pipe = ConsistencyModelPipeline(**components) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -165,11 +162,11 @@ def test_consistency_model_pipeline_onestep_distillation(self): expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - def test_consistency_model_pipeline_onestep_class_cond_distillation(self): + + def test_consistency_model_pipeline_onestep_edm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components(class_cond=True) - components["distillation"] = True + components = self.get_dummy_components() + components["distillation"] = False pipe = ConsistencyModelPipeline(**components) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -195,7 +192,7 @@ def test_consistency_model_pipeline_k_euler(self): prediction_type="sample", use_karras_sigmas=True, ) - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler, distillation=False) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -220,7 +217,7 @@ def test_consistency_model_pipeline_k_heun(self): prediction_type="sample", use_karras_sigmas=True, ) - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler, distillation=False) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -241,3 +238,66 @@ def tearDown(self): super().tearDown() gc.collect() torch.cuda.empty_cache() + + def get_inputs(self, seed=0): + generator = torch.manual_seed(seed) + + inputs = { + "num_inference_steps": 2, + "class_labels": 0, + "clip_denoised": True, + "sigma_min": 0.002, + "sigma_max": 80.0, + "sigma_data": 0.5, + "generator": generator, + "output_type": "numpy", + } + + return inputs + + def test_consistency_model_cd_multistep(self): + unet = UNet2DModel.from_pretrained( + "ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2" + ) + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + timesteps=np.array([0, 22]), + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.2645, 0.3386, 0.1928, 0.1284, 0.1215, 0.0285, 0.0800, 0.1213, 0.3331]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 3e-3 + + def test_consistency_model_cd_onestep(self): + unet = UNet2DModel.from_pretrained( + "ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2" + ) + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + timesteps=np.array([0, 22]), + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + inputs["num_inference_steps"] = 1 + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.2480, 0.1257, 0.0852, 0.2474, 0.3226, 0.1637, 0.3169, 0.2660, 0.3875]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 3e-3 From c213bf7ed623fd7fe2fbaec89260254c55e04837 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 20:41:55 -0700 Subject: [PATCH 31/72] make style --- scripts/convert_consistency_to_diffusers.py | 2 +- .../test_consistency_models.py | 18 +++++++----------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 3cd6951e016d..6a8e8eb938e0 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -228,7 +228,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): unet_config = TEST_UNET_CONFIG else: raise ValueError(f"Checkpoint type {args.checkpoint_name} is not currently supported.") - + if not args.class_cond: unet_config["num_class_embeds"] = None diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index ec837cb2d980..e930b88af68c 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -111,7 +111,7 @@ def test_consistency_model_pipeline_multistep_class_cond(self): expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_multistep_edm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -162,7 +162,7 @@ def test_consistency_model_pipeline_onestep_class_cond(self): expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + def test_consistency_model_pipeline_onestep_edm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -238,7 +238,7 @@ def tearDown(self): super().tearDown() gc.collect() torch.cuda.empty_cache() - + def get_inputs(self, seed=0): generator = torch.manual_seed(seed) @@ -254,11 +254,9 @@ def get_inputs(self, seed=0): } return inputs - + def test_consistency_model_cd_multistep(self): - unet = UNet2DModel.from_pretrained( - "ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2" - ) + unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") scheduler = CMStochasticIterativeScheduler( num_train_timesteps=40, sigma_min=0.002, @@ -277,11 +275,9 @@ def test_consistency_model_cd_multistep(self): expected_slice = np.array([0.2645, 0.3386, 0.1928, 0.1284, 0.1215, 0.0285, 0.0800, 0.1213, 0.3331]) assert np.abs(image_slice.flatten() - expected_slice).max() < 3e-3 - + def test_consistency_model_cd_onestep(self): - unet = UNet2DModel.from_pretrained( - "ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2" - ) + unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") scheduler = CMStochasticIterativeScheduler( num_train_timesteps=40, sigma_min=0.002, From 3be6e086d05c88cfb362d9dcb1a344e306a7ae9e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 20:54:21 -0700 Subject: [PATCH 32/72] Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now. --- .../scheduling_consistency_models.py | 52 ------------------- 1 file changed, 52 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 2c3b1fb642c5..28639330c75f 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -177,58 +177,6 @@ def _convert_to_karras(self, ramp): sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - # TODO: add_noise meant to be called during training (forward diffusion process) - # TODO: need to check if this corresponds to noise added during training - # TODO: may want multiple add_noise-type methods for CD, CT - # Copied from diffusers.schedulers.scheduling_euler_discrete.add_noise - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def add_noise_to_input( - self, - sample: torch.FloatTensor, - sigma: float, - generator: Optional[torch.Generator] = None, - ) -> Tuple[torch.FloatTensor, float]: - """ - Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a - higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. TODO Args: - """ - sigma_min = self.config.sigma_min - sigma_max = self.config.sigma_max - step_idx = (self.sigmas == sigma).nonzero().item() - sigma_hat = self.sigmas[step_idx + 1].clamp(min=sigma_min, max=sigma_max) - - # sample z ~ N(0, s_noise^2 * I) - z = self.config.s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device) - - # tau = sigma_hat, eps = sigma_min - sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z) - - return sample_hat, sigma_hat - def step( self, model_output: torch.FloatTensor, From 660909f8d9c80d46ee308e7ea188e774047f23e8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 21:07:44 -0700 Subject: [PATCH 33/72] Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler. --- src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index e07b7cb27da7..87828155032a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -210,6 +210,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ConsistencyModelPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DanceDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -390,6 +405,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CMStochasticIterativeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] From ad5abdc064ed6118591b2b10065d8e1b888cfcc9 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 22:33:52 -0700 Subject: [PATCH 34/72] Make fast tests from PipelineTesterMixin pass. --- .../pipeline_consistency_models.py | 101 ++++++++++++++++-- .../scheduling_consistency_models.py | 7 +- .../test_consistency_models.py | 25 ++++- 3 files changed, 121 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 9d868a8dc5e6..22021ad3c0e7 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -5,7 +5,11 @@ from ...models import UNet2DModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import randn_tensor +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + randn_tensor, +) from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -22,7 +26,7 @@ class ConsistencyModelPipeline(DiffusionPipeline): Sampling pipeline for consistency models. """ - def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, distillation: bool = True) -> None: + def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> None: super().__init__() self.register_modules( @@ -30,7 +34,90 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers, dist scheduler=scheduler, ) - self.distillation = distillation + self.distillation = True + self.safety_checker = None + + def set_consistency(self): + self.distillation = True + + def set_edm(self): + self.distillation = False + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + # Modified to only offload self.unet + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + # Modified to only offload self.unet + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler @@ -198,8 +285,8 @@ def __call__( True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. Prepare call parameters - img_size = img_size = self.unet.config.sample_size - device = self.device + img_size = self.unet.config.sample_size + device = self._execution_device # 1. Check inputs self.check_inputs(latents, batch_size, img_size, callback_steps) @@ -301,6 +388,8 @@ def __call__( if not return_dict: return (sample,) - # TODO: Offload to cpu? + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() return ImagePipelineOutput(images=sample) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 28639330c75f..a44dbccd5a7c 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -146,7 +146,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic if self.timesteps is None: timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) else: - timesteps = self.timesteps.astype(np.float32) + if isinstance(self.timesteps, list): + timesteps = np.array(self.timesteps, dtype=np.float32) + elif isinstance(self.timesteps, np.ndarray): + timesteps = self.timesteps.astype(np.float32) + else: + timesteps = self.timesteps.numpy().astype(np.float32) # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index e930b88af68c..88b3f5cc0b3e 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -24,6 +24,19 @@ class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS + # Override required_optional_params to remove num_images_per_prompt + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "output_type", + "return_dict", + "callback", + "callback_steps", + ] + ) + @property def dummy_uncond_unet(self): unet = UNet2DModel.from_pretrained( @@ -57,7 +70,6 @@ def get_dummy_components(self, class_cond=False): components = { "unet": unet, "scheduler": scheduler, - "distillation": True, } return components @@ -69,6 +81,7 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.Generator(device=device).manual_seed(seed) inputs = { + "batch_size": 1, "num_inference_steps": 2, "clip_denoised": True, "sigma_min": 0.002, @@ -115,8 +128,8 @@ def test_consistency_model_pipeline_multistep_class_cond(self): def test_consistency_model_pipeline_multistep_edm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - components["distillation"] = False pipe = ConsistencyModelPipeline(**components) + pipe.set_edm() pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -166,8 +179,8 @@ def test_consistency_model_pipeline_onestep_class_cond(self): def test_consistency_model_pipeline_onestep_edm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - components["distillation"] = False pipe = ConsistencyModelPipeline(**components) + pipe.set_edm() pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -192,7 +205,8 @@ def test_consistency_model_pipeline_k_euler(self): prediction_type="sample", use_karras_sigmas=True, ) - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler, distillation=False) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.set_edm() pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -217,7 +231,8 @@ def test_consistency_model_pipeline_k_heun(self): prediction_type="sample", use_karras_sigmas=True, ) - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler, distillation=False) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.set_edm() pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) From a1c11d3016cbaa8471c7ac5fed4507ef3db788c1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 3 Jun 2023 22:36:45 -0700 Subject: [PATCH 35/72] make style --- .../consistency_models/pipeline_consistency_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 22021ad3c0e7..6a0469631139 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -36,13 +36,13 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> N self.distillation = True self.safety_checker = None - + def set_consistency(self): self.distillation = True - + def set_edm(self): self.distillation = False - + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Modified to only offload self.unet def enable_sequential_cpu_offload(self, gpu_id=0): @@ -99,7 +99,7 @@ def enable_model_cpu_offload(self, gpu_id=0): # We'll offload the last model manually. self.final_offload_hook = hook - + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): From 2dc2fa311260adf07d64d1bbf6586b452fbafaba Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jun 2023 23:03:58 -0700 Subject: [PATCH 36/72] Refactor consistency models pipeline and scheduler: - Remove support for Karras schedulers (only support CMStochasticIterativeScheduler) - Move sigma manipulation, input scaling, denoising from pipeline to scheduler - Make corresponding changes to tests and ensure they pass --- .../pipeline_consistency_models.py | 201 +++++------------- .../scheduling_consistency_models.py | 87 ++++++-- .../test_consistency_models.py | 98 +-------- 3 files changed, 130 insertions(+), 256 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 6a0469631139..a4ba318a14c0 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -34,15 +34,8 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> N scheduler=scheduler, ) - self.distillation = True self.safety_checker = None - def set_consistency(self): - self.distillation = True - - def set_edm(self): - self.distillation = False - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Modified to only offload self.unet def enable_sequential_cpu_offload(self, gpu_id=0): @@ -119,9 +112,8 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - # Additionally prepare sigma_min, sigma_max kwargs for CM multistep scheduler - def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -132,12 +124,6 @@ def prepare_extra_step_kwargs(self, generator, eta, sigma_min, sigma_max): if accepts_eta: extra_step_kwargs["eta"] = eta - accepts_sigma_min = "sigma_min" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_sigma_min: - # Assume accepting sigma_min always means scheduler also accepts sigma_max - extra_step_kwargs["sigma_min"] = sigma_min - extra_step_kwargs["sigma_max"] = sigma_max - # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: @@ -162,54 +148,30 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - - def get_scalings(self, sigma, sigma_data: float = 0.5): - c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) - c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 - return c_skip, c_out, c_in - - def get_scalings_for_boundary_condition(self, sigma, sigma_min: float = 0.002, sigma_data: float = 0.5): - # sigma_min should be in original sigma space, not in karras sigma space - # (e.g. not exponentiated by 1 / rho) - c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) - c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 - return c_skip, c_out, c_in - - def denoise( - self, - x_t, - sigma, - class_labels=None, - sigma_min: float = 0.002, - sigma_data: float = 0.5, - clip_denoised=True, - ): - """ - Run the consistency model forward...? - """ - # sigma_min should be in original sigma space, not in karras sigma space - # (e.g. not exponentiated by 1 / rho) - if self.distillation: - c_skip, c_out, c_in = [ - append_dims(x, x_t.ndim) - for x in self.get_scalings_for_boundary_condition(sigma, sigma_min=sigma_min, sigma_data=sigma_data) - ] + + def prepare_class_labels(self, batch_size, device, class_labels=None): + if self.unet.config.num_class_embeds is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) else: - c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigma, sigma_data=sigma_data)] - rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44) - model_output = self.unet(c_in * x_t, rescaled_t, class_labels=class_labels).sample - denoised = c_out * model_output + c_skip * x_t - if clip_denoised: - denoised = denoised.clamp(-1, 1) - return model_output, denoised - - def to_d(x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / append_dims(sigma, x.ndim) - - def check_inputs(self, latents, batch_size, img_size, callback_steps): + class_labels = None + return class_labels + + def check_inputs(self, num_inference_steps, latents, batch_size, img_size, callback_steps): + if self.scheduler.timesteps is not None and len(self.scheduler.timesteps) < num_inference_steps: + raise ValueError( + f"The scheduler's timestep schedule: {self.scheduler.timesteps} is shorter than num_inference_steps:" + " {num_inference_steps}, but is expected to be at least as long." + ) + if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) if latents.shape != expected_shape: @@ -229,10 +191,6 @@ def __call__( batch_size: int = 1, class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, num_inference_steps: int = 40, - clip_denoised: bool = True, - sigma_min: float = 0.002, - sigma_max: float = 80.0, - sigma_data: float = 0.5, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -251,14 +209,6 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - clip_denoised (`bool`, *optional*, defaults to `True`): - Whether to clip the consistency model denoising output to `(0, 1)`. - sigma_min (`float`, *optional*, defaults to 0.002): - The minimum (and last) value in the sigma noise schedule. - sigma_max (`float`, *optional*, defaults to 80.0): - The maximum (and first) value in the sigma noise schedule. - sigma_data (`float`, *optional*, defaults to 0.5): - TODO eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -289,7 +239,7 @@ def __call__( device = self._execution_device # 1. Check inputs - self.check_inputs(latents, batch_size, img_size, callback_steps) + self.check_inputs(num_inference_steps, latents, batch_size, img_size, callback_steps) # 2. Prepare image latents # Sample image latents x_0 ~ N(0, sigma_0^2 * I) @@ -305,80 +255,41 @@ def __call__( ) # 3. Handle class_labels for class-conditional models - if self.unet.config.num_class_embeds is not None: - if isinstance(class_labels, list): - class_labels = torch.tensor(class_labels, dtype=torch.int) - elif isinstance(class_labels, int): - assert batch_size == 1, "Batch size must be 1 if classes is an int" - class_labels = torch.tensor([class_labels], dtype=torch.int) - elif class_labels is None: - # Randomly generate batch_size class labels - class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) - class_labels = class_labels.to(device) + class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) # 4. Set timesteps self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps - # Now get Karras sigma schedule (which I think the original implementation always uses) - # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376 - # TODO: how do we ensure that this in Karras sigma space rather than in "original" sigma space? - # 5. Get sigma schedule - assert hasattr(self.scheduler, "sigmas"), "Scheduler needs to operate in sigma space" - if hasattr(self.scheduler, "sigma_min"): - # Overwrite sigma_min with sigma_min from the scheduler - sigma_min = self.scheduler.sigma_min - sigma_max = self.scheduler.sigma_max - sigmas = self.scheduler.sigmas - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, sigma_min, sigma_max) - - # 7. Denoising loop - if num_inference_steps == 1: - # Onestep sampling: simply evaluate the consistency model at the first sigma - # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L643 - sigma = sigma_max - sigma_in = sample.new_ones([sample.shape[0]]) * sigma - _, sample = self.denoise( - sample, - sigma_in, - class_labels=class_labels, - sigma_min=sigma_min, - sigma_data=sigma_data, - clip_denoised=clip_denoised, - ) - else: - # Multistep sampling or Karras sampler - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - sigma = sigmas[i] - sigma_in = sample.new_ones([sample.shape[0]]) * sigma - - # TODO: should we call scale_model_input here? - sample = self.scheduler.scale_model_input(sample, t) - model_output, denoised = self.denoise( - sample, - sigma_in, - class_labels=class_labels, - sigma_min=sigma_min, - sigma_data=sigma_data, - clip_denoised=clip_denoised, - ) - - # Works for both Karras-style schedulers (e.g. Euler, Heun) and the CM multistep scheduler - sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample - - # Note: differs from callback support in original code - # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459 - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, sample) - - # 8. Post-process image sample + # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6. Denoising loop + # Multistep sampling: implements Algorithm 1 in the paper + sigma = torch.tensor([self.scheduler.init_noise_sigma], dtype=torch.float, device=device) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Onestep sampling does not use random noise + use_noise = False if num_inference_steps == 1 else True + + rescaled_sample = self.scheduler.scale_model_input(sample, t) + rescaled_t = self.scheduler.scale_timestep(sigma) + model_output = self.unet(rescaled_sample, rescaled_t, class_labels=class_labels).sample + + sample, sigma = self.scheduler.step( + model_output, t, sample, use_noise=use_noise, return_dict=False, **extra_step_kwargs + ) + + # Note: differs from callback support in original code + # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459 + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + # 7. Post-process image sample sample = (sample / 2 + 0.5).clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index a44dbccd5a7c..b3b4b601023b 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import numpy as np import torch @@ -30,6 +30,14 @@ def append_zero(x): return torch.cat([x, x.new_zeros([1])]) +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + @dataclass class CMStochasticIterativeSchedulerOutput(BaseOutput): """ @@ -47,8 +55,7 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput): """ prev_sample: torch.FloatTensor - # derivative: torch.FloatTensor - # pred_original_sample: Optional[torch.FloatTensor] = None + sigma_next: torch.FloatTensor class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): @@ -100,7 +107,8 @@ def __init__( sigma_max: float = 80.0, sigma_data: float = 0.5, rho: float = 7.0, - timesteps: Optional[np.ndarray] = None, + clip_denoised: bool = True, + timesteps: Optional[Union[List, np.ndarray, torch.Tensor]] = None, ): # standard deviation of the initial noise distribution self.init_noise_sigma = sigma_max @@ -116,19 +124,53 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() return indices.item() + + def get_scalings(self, sigma): + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out, c_in + + def get_scalings_for_boundary_condition(self, sigma): + # sigma_min should be in original sigma space, not in karras sigma space + # (e.g. not exponentiated by 1 / rho) + sigma_min = self.config.sigma_min + sigma_data = self.config.sigma_data - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) + c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out, c_in + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain Returns: `torch.FloatTensor`: scaled input sample """ + # Get sigma corresponding to timestep + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_idx = self.index_for_timestep(timestep) + sigma = self.sigmas[step_idx] + + c_skip, c_out, c_in = self.get_scalings_for_boundary_condition(sigma) + sample = c_in * sample + + self.is_scale_input_called = True return sample + + def scale_timestep(self, sigma): + return 1000 * 0.25 * torch.log(sigma + 1e-44) def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -152,6 +194,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = self.timesteps.astype(np.float32) else: timesteps = self.timesteps.numpy().astype(np.float32) + timesteps = timesteps[:self.num_inference_steps] # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 @@ -160,7 +203,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ramp = np.append(timesteps, [num_train_timesteps - 1]) / (num_train_timesteps - 1) sigmas = self._convert_to_karras(ramp) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + # sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = sigmas.astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): # mps does not support float64 @@ -188,6 +232,7 @@ def step( timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, s_noise: float = 1.0, + use_noise: bool = True, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: @@ -240,24 +285,36 @@ def step( step_index = self.index_for_timestep(timestep) # sigma_next corresponds to next_t in original implementation + sigma = self.sigmas[step_index] sigma_next = self.sigmas[step_index + 1] - # 1. Sample z ~ N(0, s_noise^2 * I) - noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator - ) + # Get scalings for boundary conditions + c_skip, c_out, c_in = self.get_scalings_for_boundary_condition(sigma) + + # 1. Denoise model output using boundary conditions + denoised = c_out * model_output + c_skip * sample + if self.config.clip_denoised: + denoised = denoised.clamp(-1, 1) + + # 2. Sample z ~ N(0, s_noise^2 * I) + if use_noise: + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + else: + noise = torch.zeros_like(model_output) z = noise * s_noise sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) - # 2. Return noisy sample + # 3. Return noisy sample # tau = sigma_hat, eps = sigma_min - prev_sample = model_output + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 if not return_dict: - return (prev_sample,) + return (prev_sample, sigma_next) - return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample, sigma_next=sigma_next) def __len__(self): return self.config.num_train_timesteps diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 88b3f5cc0b3e..46dd5014399c 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -83,10 +83,6 @@ def get_dummy_inputs(self, device, seed=0): inputs = { "batch_size": 1, "num_inference_steps": 2, - "clip_denoised": True, - "sigma_min": 0.002, - "sigma_max": 80.0, - "sigma_data": 0.5, "generator": generator, "output_type": "numpy", } @@ -125,23 +121,6 @@ def test_consistency_model_pipeline_multistep_class_cond(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_multistep_edm(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - pipe = ConsistencyModelPipeline(**components) - pipe.set_edm() - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - assert image.shape == (1, 32, 32, 3) - - image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.3576, 0.6270, 0.4034, 0.3964, 0.4323, 0.5728, 0.5265, 0.4781, 0.5004]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_onestep(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -176,75 +155,6 @@ def test_consistency_model_pipeline_onestep_class_cond(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - def test_consistency_model_pipeline_onestep_edm(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - pipe = ConsistencyModelPipeline(**components) - pipe.set_edm() - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 1 - image = pipe(**inputs).images - assert image.shape == (1, 32, 32, 3) - - image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - def test_consistency_model_pipeline_k_euler(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_uncond_unet - scheduler = EulerDiscreteScheduler( - num_train_timesteps=2, - beta_start=0.00085, - beta_end=0.012, - beta_schedule="linear", - prediction_type="sample", - use_karras_sigmas=True, - ) - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) - pipe.set_edm() - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - assert image.shape == (1, 32, 32, 3) - - image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.5157, 0.5143, 0.4804, 0.5273, 0.4146, 0.5619, 0.4651, 0.4359, 0.4540]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - @pytest.mark.xfail(reason="Heun scheduler does not implement prediction_type 'sample' yet") - def test_consistency_model_pipeline_k_heun(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_uncond_unet - scheduler = HeunDiscreteScheduler( - num_train_timesteps=2, - beta_start=0.00085, - beta_end=0.012, - beta_schedule="linear", - prediction_type="sample", - use_karras_sigmas=True, - ) - pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) - pipe.set_edm() - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - assert image.shape == (1, 32, 32, 3) - - image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.5159, 0.5145, 0.4801, 0.5277, 0.4134, 0.5628, 0.4646, 0.4350, 0.4533]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - @slow @require_torch_gpu @@ -260,10 +170,6 @@ def get_inputs(self, seed=0): inputs = { "num_inference_steps": 2, "class_labels": 0, - "clip_denoised": True, - "sigma_min": 0.002, - "sigma_max": 80.0, - "sigma_data": 0.5, "generator": generator, "output_type": "numpy", } @@ -289,7 +195,7 @@ def test_consistency_model_cd_multistep(self): image_slice = image[0, -3:, -3:, -1] expected_slice = np.array([0.2645, 0.3386, 0.1928, 0.1284, 0.1215, 0.0285, 0.0800, 0.1213, 0.3331]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 3e-3 + assert np.abs(image_slice.flatten() - expected_slice).max() < 4e-3 def test_consistency_model_cd_onestep(self): unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") @@ -311,4 +217,4 @@ def test_consistency_model_cd_onestep(self): image_slice = image[0, -3:, -3:, -1] expected_slice = np.array([0.2480, 0.1257, 0.0852, 0.2474, 0.3226, 0.1637, 0.3169, 0.2660, 0.3875]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 3e-3 + assert np.abs(image_slice.flatten() - expected_slice).max() < 4e-3 From 7d3dbe3d15b6288cf39eaee74eaa76618b71c017 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jun 2023 23:09:29 -0700 Subject: [PATCH 37/72] make style --- .../consistency_models/pipeline_consistency_models.py | 4 ++-- src/diffusers/schedulers/scheduling_consistency_models.py | 8 ++++---- .../consistency_models/test_consistency_models.py | 3 --- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index a4ba318a14c0..de998139e24d 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -148,7 +148,7 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - + def prepare_class_labels(self, batch_size, device, class_labels=None): if self.unet.config.num_class_embeds is not None: if isinstance(class_labels, list): @@ -171,7 +171,7 @@ def check_inputs(self, num_inference_steps, latents, batch_size, img_size, callb f"The scheduler's timestep schedule: {self.scheduler.timesteps} is shorter than num_inference_steps:" " {num_inference_steps}, but is expected to be at least as long." ) - + if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) if latents.shape != expected_shape: diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index b3b4b601023b..cb821d3f3644 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -124,7 +124,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() return indices.item() - + def get_scalings(self, sigma): sigma_data = self.config.sigma_data @@ -168,7 +168,7 @@ def scale_model_input( self.is_scale_input_called = True return sample - + def scale_timestep(self, sigma): return 1000 * 0.25 * torch.log(sigma + 1e-44) @@ -194,7 +194,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = self.timesteps.astype(np.float32) else: timesteps = self.timesteps.numpy().astype(np.float32) - timesteps = timesteps[:self.num_inference_steps] + timesteps = timesteps[: self.num_inference_steps] # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 46dd5014399c..75358c4f65f5 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -2,14 +2,11 @@ import unittest import numpy as np -import pytest import torch from diffusers import ( CMStochasticIterativeScheduler, ConsistencyModelPipeline, - EulerDiscreteScheduler, - HeunDiscreteScheduler, UNet2DModel, ) from diffusers.utils import slow, torch_device From 55f80ed98261d51c04e6f5943925e32c89a9b965 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Jun 2023 01:12:49 -0700 Subject: [PATCH 38/72] Add docstrings and further refactor pipeline and scheduler. --- .../pipeline_consistency_models.py | 43 +++--- .../scheduling_consistency_models.py | 142 ++++++++++-------- 2 files changed, 100 insertions(+), 85 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index de998139e24d..e6f5d41379af 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -13,17 +13,22 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") - return x[(...,) + (None,) * dims_to_append] - - class ConsistencyModelPipeline(DiffusionPipeline): r""" - Sampling pipeline for consistency models. + Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1]. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 + + Args: + unet ([`UNet2DModel`]): + Unconditional or class-conditional U-Net architecture to denoise image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible + with [`CMStochasticIterativeScheduler`]. """ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> None: @@ -143,7 +148,7 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - latents = latents.to(device) + latents = latents.to(device=device, dtype=dtype) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma @@ -266,20 +271,18 @@ def __call__( # 6. Denoising loop # Multistep sampling: implements Algorithm 1 in the paper - sigma = torch.tensor([self.scheduler.init_noise_sigma], dtype=torch.float, device=device) + # Onestep sampling does not use random noise + use_noise = False if num_inference_steps == 1 else True num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # Onestep sampling does not use random noise - use_noise = False if num_inference_steps == 1 else True - - rescaled_sample = self.scheduler.scale_model_input(sample, t) - rescaled_t = self.scheduler.scale_timestep(sigma) - model_output = self.unet(rescaled_sample, rescaled_t, class_labels=class_labels).sample + scaled_sample = self.scheduler.scale_model_input(sample, t) + scaled_t = self.scheduler.scale_timestep(t) + model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample - sample, sigma = self.scheduler.step( - model_output, t, sample, use_noise=use_noise, return_dict=False, **extra_step_kwargs - ) + sample = self.scheduler.step( + model_output, t, sample, use_noise=use_noise, **extra_step_kwargs + ).prev_sample # Note: differs from callback support in original code # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459 diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index cb821d3f3644..05a305dab43c 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -47,54 +47,47 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput): prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Derivative of predicted original image sample (x_0). - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor - sigma_next: torch.FloatTensor class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): """ - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. + Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in + the paper [1]. [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" https://arxiv.org/pdf/2303.01469 + [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - prediction_type (`str`, default `"epsilon"`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - interpolation_type (`str`, default `"linear"`, optional): - interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of - [`"linear"`, `"log_linear"`]. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the - noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence - of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + sigma_min (`float`): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation. + sigma_max (`float`): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation. + sigma_data (`float`): + The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the + original implementation, which is also the original value suggested in the EDM paper. + s_noise (`float`): + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is + [1.000, 1.011]. This was set to 1.0 in the original implementation. + rho (`float`): + The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This + was set to 7.0 in the original implementation, which is also the original value suggested in the EDM + paper. + clip_denoised (`bool`): + Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`. + timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): + Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing + order. """ order = 1 @@ -106,6 +99,7 @@ def __init__( sigma_min: float = 0.002, sigma_max: float = 80.0, sigma_data: float = 0.5, + s_noise: float = 1.0, rho: float = 7.0, clip_denoised: bool = True, timesteps: Optional[Union[List, np.ndarray, torch.Tensor]] = None, @@ -125,31 +119,11 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() return indices.item() - def get_scalings(self, sigma): - sigma_data = self.config.sigma_data - - c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) - c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 - return c_skip, c_out, c_in - - def get_scalings_for_boundary_condition(self, sigma): - # sigma_min should be in original sigma space, not in karras sigma space - # (e.g. not exponentiated by 1 / rho) - sigma_min = self.config.sigma_min - sigma_data = self.config.sigma_data - - c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) - c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5 - return c_skip, c_out, c_in - def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. + Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model. Args: sample (`torch.FloatTensor`): input sample @@ -163,13 +137,26 @@ def scale_model_input( step_idx = self.index_for_timestep(timestep) sigma = self.sigmas[step_idx] - c_skip, c_out, c_in = self.get_scalings_for_boundary_condition(sigma) - sample = c_in * sample + sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5) self.is_scale_input_called = True return sample - def scale_timestep(self, sigma): + def scale_timestep(self, timestep: Union[float, torch.FloatTensor]): + """ + Scales the timestep based on the associated Karras sigma, for input to the consistency model. + + Args: + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + Returns: + `torch.FloatTensor`: scaled input timestep + """ + # Get sigma corresponding to timestep + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_idx = self.index_for_timestep(timestep) + sigma = self.sigmas[step_idx] + return 1000 * 0.25 * torch.log(sigma + 1e-44) def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): @@ -203,7 +190,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ramp = np.append(timesteps, [num_train_timesteps - 1]) / (num_train_timesteps - 1) sigmas = self._convert_to_karras(ramp) - # sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = sigmas.astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): @@ -225,13 +211,41 @@ def _convert_to_karras(self, ramp): max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas + + def get_scalings(self, sigma): + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def get_scalings_for_boundary_condition(self, sigma): + """ + Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper. + This enforces the consistency model boundary condition. + + Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min. + + Args: + sigma (`torch.FloatTensor`): + The current sigma in the Karras sigma schedule. + Returns: + `tuple`: + A two-element tuple where c_skip (which weights the current sample) is the first element and c_out + (which weights the consistency model output) is the second element. + """ + sigma_min = self.config.sigma_min + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) + c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, - s_noise: float = 1.0, use_noise: bool = True, generator: Optional[torch.Generator] = None, return_dict: bool = True, @@ -245,11 +259,9 @@ def step( timestep (`float`): current timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - s_churn (`float`) - s_tmin (`float`) - s_tmax (`float`) - s_noise (`float`) - generator (`torch.Generator`, optional): Random number generator. + use_noise: (`bool`, *optional*, defaults to `True`): + Whether to inject noise during the step. Noise is not used for onestep sampling. + generator (`torch.Generator`, *optional*): Random number generator. return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class Returns: [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`: @@ -289,7 +301,7 @@ def step( sigma_next = self.sigmas[step_index + 1] # Get scalings for boundary conditions - c_skip, c_out, c_in = self.get_scalings_for_boundary_condition(sigma) + c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) # 1. Denoise model output using boundary conditions denoised = c_out * model_output + c_skip * sample @@ -303,7 +315,7 @@ def step( ) else: noise = torch.zeros_like(model_output) - z = noise * s_noise + z = noise * self.config.s_noise sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) @@ -312,9 +324,9 @@ def step( prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 if not return_dict: - return (prev_sample, sigma_next) + return (prev_sample) - return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample, sigma_next=sigma_next) + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) def __len__(self): return self.config.num_train_timesteps From 3234543ff47182c0bb753d23ef79ade95cf6495c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Jun 2023 01:16:06 -0700 Subject: [PATCH 39/72] make style --- .../scheduling_consistency_models.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 05a305dab43c..a5a59403cb72 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -54,13 +54,12 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput): class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): """ - Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in - the paper [1]. + Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the + paper [1]. [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" - https://arxiv.org/pdf/2303.01469 - [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 + https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based + Generative Models." https://arxiv.org/abs/2206.00364 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. @@ -77,17 +76,16 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the original implementation, which is also the original value suggested in the EDM paper. s_noise (`float`): - The amount of additional noise to counteract loss of detail during sampling. A reasonable range is - [1.000, 1.011]. This was set to 1.0 in the original implementation. + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, + 1.011]. This was set to 1.0 in the original implementation. rho (`float`): - The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This - was set to 7.0 in the original implementation, which is also the original value suggested in the EDM - paper. + The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was + set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper. clip_denoised (`bool`): Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`. timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing - order. + order. """ order = 1 @@ -211,7 +209,7 @@ def _convert_to_karras(self, ramp): max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - + def get_scalings(self, sigma): sigma_data = self.config.sigma_data @@ -324,7 +322,7 @@ def step( prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 if not return_dict: - return (prev_sample) + return (prev_sample,) return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) From dc4349cadc35d4515ef76cdf5981fc5b2fac5d99 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jun 2023 23:04:55 -0700 Subject: [PATCH 40/72] Add initial version of the consistency models documentation. --- docs/source/en/_toctree.yml | 4 ++ .../en/api/pipelines/consistency_models.mdx | 54 +++++++++++++++++++ .../schedulers/cm_stochastic_iterative.mdx | 11 ++++ .../scheduling_consistency_models.py | 2 +- 4 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/pipelines/consistency_models.mdx create mode 100644 docs/source/en/api/schedulers/cm_stochastic_iterative.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index aa2d907da4bd..b857f6192722 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -152,6 +152,8 @@ title: Audio Diffusion - local: api/pipelines/audioldm title: AudioLDM + - local: api/pipelines/consistency_models + title: Consistency Models - local: api/pipelines/controlnet title: ControlNet - local: api/pipelines/cycle_diffusion @@ -236,6 +238,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/cm_stochastic_iterative + title: Consistency Model Multistep Scheduler - local: api/schedulers/ddim title: DDIM - local: api/schedulers/ddim_inverse diff --git a/docs/source/en/api/pipelines/consistency_models.mdx b/docs/source/en/api/pipelines/consistency_models.mdx new file mode 100644 index 000000000000..9707cc910359 --- /dev/null +++ b/docs/source/en/api/pipelines/consistency_models.mdx @@ -0,0 +1,54 @@ +# Consistency Models + +Consistency Models were proposed in [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. + +The abstract of the [paper](https://arxiv.org/pdf/2303.01469.pdf) is as follows: + +*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256. * + +Resources: + +* [Paper](https://arxiv.org/abs/2303.01469) +* [Original Code](https://github.com/openai/consistency_models) + +Available Checkpoints are: +- *cd_imagenet64_l2 (64x64 resolution)* [dg845/consistency-model-pipelines](https://huggingface.co/dg845/consistency-model-pipelines) +- TODO: add more checkpoints from original release? + +## Available Pipelines + +| Pipeline | Tasks | Demo | Colab | +|:---:|:---:|:---:|:---:| +| [ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_consistency_models.py) | *Unconditional Image Generation* | | | + +## Usage Example + +```python +import torch + +from diffusers import ConsistencyModelPipeline + +device = "cuda" +# Load the cd_imagenet64_l2 checkpoint. +model_id_or_path = "dg845/consistency-model-pipelines" +pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Onestep Sampling +image = pipe(num_inference_steps=1).images[0] +image.save("consistency_model_onestep_sample.png") + +# Onestep sampling, class-conditional image generation +# ImageNet-64 class label 145 corresponds to king penguins +image = pipe(num_inference_steps=1, class_labels=145).images[0] +image.save("consistency_model_onestep_sample_penguin.png") + +# Multistep sampling, class-conditional image generation +image = pipe(num_inference_steps=2, class_labels=145).images[0] +image.save("consistency_model_multistep_sample_penguin.png") +``` + +## ConsistencyModelPipeline +[[autodoc]] ConsistencyModelPipeline + - all + - __call__ \ No newline at end of file diff --git a/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx b/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx new file mode 100644 index 000000000000..0cc40bde47a0 --- /dev/null +++ b/docs/source/en/api/schedulers/cm_stochastic_iterative.mdx @@ -0,0 +1,11 @@ +# Consistency Model Multistep Scheduler + +## Overview + +Multistep and onestep scheduler (Algorithm 1) introduced alongside consistency models in the paper [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. +Based on the [original consistency models implementation](https://github.com/openai/consistency_models). +Should generate good samples from [`ConsistencyModelPipeline`] in one or a small number of steps. + +## CMStochasticIterativeScheduler +[[autodoc]] CMStochasticIterativeScheduler + diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index a5a59403cb72..216c7e60b7e8 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -257,7 +257,7 @@ def step( timestep (`float`): current timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - use_noise: (`bool`, *optional*, defaults to `True`): + use_noise (`bool`, *optional*, defaults to `True`): Whether to inject noise during the step. Noise is not used for onestep sampling. generator (`torch.Generator`, *optional*): Random number generator. return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class From 787e65af93e226316ffdb163a45929cba92c5a8d Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:40:57 +0530 Subject: [PATCH 41/72] Minor --- examples/consistency_models/script.sh | 3 + .../train_consistency_distillation.py | 59 ++++++++++--------- 2 files changed, 33 insertions(+), 29 deletions(-) create mode 100644 examples/consistency_models/script.sh diff --git a/examples/consistency_models/script.sh b/examples/consistency_models/script.sh new file mode 100644 index 000000000000..5b5f6b4710e9 --- /dev/null +++ b/examples/consistency_models/script.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +accelerate launch train_consistency_distillation.py --dataset_name="huggan/flowers-102-categories" --resolution=64 --center_crop --random_flip --output_dir="ddpm-ema-flowers-64" --train_batch_size=16 --num_epochs=100 --gradient_accumulation_steps=1 --use_ema --learning_rate=1e-4 --lr_warmup_steps=500 --mixed_precision=no --push_to_hub \ No newline at end of file diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index a14d652c208b..022d4e5a63ef 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -75,7 +75,7 @@ def parse_args(): "--model_config_name_or_path", type=str, default=None, - help="The config of the UNet model to train, leave as None to use standard DDPM configuration.", + help="The config of the UNet model to train, leave as None to use standard Consistency configuration.", ) parser.add_argument( "--train_data_dir", @@ -378,31 +378,31 @@ def load_model_hook(models, input_dir): # Initialize the model if args.model_config_name_or_path is None: model = UNet2DModel( - sample_size=args.resolution, - in_channels=3, - out_channels=3, - layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) + sample_size= args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + num_class_embeds=1000, + block_out_channels= [32, 64], + attention_head_dim=8, + down_block_types= [ + "ResnetDownsampleBlock2D", + "AttnDownsampleBlock2D", + ], + up_block_types= [ + "AttnUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + resnet_time_scale_shift="scale_shift", + + ) else: config = UNet2DModel.load_config(args.model_config_name_or_path) model = UNet2DModel.from_config(config) + + teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet + # print(teacher_model) + # Create EMA for the model. if args.use_ema: @@ -478,11 +478,11 @@ def load_model_hook(models, input_dir): ) def transform_images(examples): - images = [augmentations(image.convert("RGB")) for image in examples["image"]] - return {"input": images} + images = [augmentations(image.convert("RGB")) for image in examples["img"]] + labels = [torch.tensor(label) for label in examples["label"]] + return {"input": images, "labels": labels} logger.info(f"Dataset size: {len(dataset)}") - dataset.set_transform(transform_images) train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers @@ -497,8 +497,8 @@ def transform_images(examples): ) # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler + model, optimizer, train_dataloader, lr_scheduler, teacher_model = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler, teacher_model ) if args.use_ema: @@ -563,6 +563,7 @@ def transform_images(examples): continue clean_images = batch["input"] + labels = batch["labels"] # Sample noise that we'll add to the images noise = torch.randn(clean_images.shape).to(clean_images.device) bsz = clean_images.shape[0] @@ -577,7 +578,7 @@ def transform_images(examples): with accelerator.accumulate(model): # Predict the noise residual - model_output = model(noisy_images, timesteps).sample + model_output = model(noisy_images, timesteps, labels).sample if args.prediction_type == "epsilon": loss = F.mse_loss(model_output, noise) # this could have different weights! From 6530b17e6af1452dae45966d492eb6f251fa72cd Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Fri, 9 Jun 2023 17:17:25 +0530 Subject: [PATCH 42/72] Add training code --- .../train_consistency_distillation.py | 97 ++++++++++++------- 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 022d4e5a63ef..da362e3c55b3 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -20,7 +20,7 @@ from tqdm.auto import tqdm import diffusers -from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, CMStochasticIterativeScheduler from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available @@ -34,6 +34,29 @@ logger = get_logger(__name__, log_level="INFO") +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def heun_solver(samples, t, next_t, x0): + dims = samples.ndim + x = samples + denoiser = teacher_denoise_fn(x, t) + + d = (x - denoiser) / append_dims(t, dims) + samples = x + d * append_dims(next_t - t, dims) + denoiser = teacher_denoise_fn(samples, next_t) + + next_d = (samples - denoiser) / append_dims(next_t, dims) + samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims) + + return samples + + def _extract_into_tensor(arr, timesteps, broadcast_shape): """ @@ -401,10 +424,13 @@ def load_model_hook(models, input_dir): model = UNet2DModel.from_config(config) teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet + noise_scheduler = CMStochasticIterativeScheduler() + num_scales = 40 + noise_scheduler.set_timesteps(num_scales) # print(teacher_model) - # Create EMA for the model. + # Create EMA for the model, this is the target model in the paper if args.use_ema: ema_model = EMAModel( model.parameters(), @@ -429,17 +455,6 @@ def load_model_hook(models, input_dir): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - # Initialize the scheduler - accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) - if accepts_prediction_type: - noise_scheduler = DDPMScheduler( - num_train_timesteps=args.ddpm_num_steps, - beta_schedule=args.ddpm_beta_schedule, - prediction_type=args.prediction_type, - ) - else: - noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) - # Initialize the optimizer optimizer = torch.optim.AdamW( model.parameters(), @@ -569,30 +584,46 @@ def transform_images(examples): bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device + 0, noise_scheduler.config.num_train_timesteps-1, (bsz,), device=clean_images.device ).long() - - # Add noise to the clean images according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) + timesteps_prev = timesteps + 1 + # TO-DO, we should have an add noise in the scheduler maybe? + noised_image = clean_images + noise*append_dims(timesteps, clean_images.ndims) + scaled_timesteps = noise_scheduler.scale_timesteps(timesteps) + scaled_timesteps_prev = noise_scheduler.scale_timesteps(timesteps_prev) with accelerator.accumulate(model): # Predict the noise residual - model_output = model(noisy_images, timesteps, labels).sample - - if args.prediction_type == "epsilon": - loss = F.mse_loss(model_output, noise) # this could have different weights! - elif args.prediction_type == "sample": - alpha_t = _extract_into_tensor( - noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) - ) - snr_weights = alpha_t / (1 - alpha_t) - loss = snr_weights * F.mse_loss( - model_output, clean_images, reduction="none" - ) # use SNR weighting from distillation paper - loss = loss.mean() - else: - raise ValueError(f"Unsupported prediction type: {args.prediction_type}") + model_output = model(noised_image, scaled_timesteps, class_labels=labels).sample + distiller = noise_scheduler.step( + model_output, timesteps, noised_image, use_noise=False + ).prev_sample + + # Heun Solver to get previous timestep image + samples = noised_image + x = samples + model_output = teacher_model(x, scaled_timesteps, class_labels=labels).sample + teacher_denoiser = noise_scheduler.step( + model_output, timesteps, x, use_noise=False + ).prev_sample + d = (x - teacher_denoiser) / append_dims(scaled_timesteps, x.ndims) + samples = x + d * append_dims(scaled_timesteps_prev - scaled_timesteps, x.ndims) + model_output = teacher_model(samples, scaled_timesteps_prev, class_labels=labels).sample + teacher_denoiser = noise_scheduler.step( + model_output, timesteps_prev, samples, use_noise=False + ).prev_sample + + next_d = (samples - teacher_denoiser) / append_dims(scaled_timesteps_prev, x.ndims) + denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndims) + + # get output from target model + model_output = ema_model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample + distiller_target = noise_scheduler.step( + model_output, timesteps_prev, denoised_image, use_noise=False + ).prev_sample + + loss = F.mse_loss(distiller, distiller_target) # this could have different weights! + loss = loss.mean() accelerator.backward(loss) From 8f858cb6d4d7f160e0b9378f7e331e2b689068d5 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Fri, 9 Jun 2023 18:40:47 +0530 Subject: [PATCH 43/72] Fix bugs in training --- .../train_consistency_distillation.py | 46 ++++++++----------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index da362e3c55b3..a591542a831c 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -42,19 +42,7 @@ def append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -def heun_solver(samples, t, next_t, x0): - dims = samples.ndim - x = samples - denoiser = teacher_denoise_fn(x, t) - d = (x - denoiser) / append_dims(t, dims) - samples = x + d * append_dims(next_t - t, dims) - denoiser = teacher_denoise_fn(samples, next_t) - - next_d = (samples - denoiser) / append_dims(next_t, dims) - samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims) - - return samples @@ -424,9 +412,12 @@ def load_model_hook(models, input_dir): model = UNet2DModel.from_config(config) teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet + model = model.double() + teacher_model = teacher_model.double() noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 noise_scheduler.set_timesteps(num_scales) + timesteps = noise_scheduler.timesteps # print(teacher_model) @@ -583,20 +574,21 @@ def transform_images(examples): noise = torch.randn(clean_images.shape).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps-1, (bsz,), device=clean_images.device + index = torch.randint( + 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=clean_images.device ).long() - timesteps_prev = timesteps + 1 + timestep = timesteps[index] + timestep_prev = timestep + 1 # TO-DO, we should have an add noise in the scheduler maybe? - noised_image = clean_images + noise*append_dims(timesteps, clean_images.ndims) - scaled_timesteps = noise_scheduler.scale_timesteps(timesteps) - scaled_timesteps_prev = noise_scheduler.scale_timesteps(timesteps_prev) + noised_image = clean_images + noise*append_dims(timestep, clean_images.ndim) + scaled_timesteps = noise_scheduler.scale_timestep(timestep) + scaled_timesteps_prev = noise_scheduler.scale_timestep(timestep_prev) with accelerator.accumulate(model): # Predict the noise residual model_output = model(noised_image, scaled_timesteps, class_labels=labels).sample distiller = noise_scheduler.step( - model_output, timesteps, noised_image, use_noise=False + model_output, timestep, noised_image, use_noise=False ).prev_sample # Heun Solver to get previous timestep image @@ -604,22 +596,22 @@ def transform_images(examples): x = samples model_output = teacher_model(x, scaled_timesteps, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( - model_output, timesteps, x, use_noise=False + model_output, timestep, x, use_noise=False ).prev_sample - d = (x - teacher_denoiser) / append_dims(scaled_timesteps, x.ndims) - samples = x + d * append_dims(scaled_timesteps_prev - scaled_timesteps, x.ndims) + d = (x - teacher_denoiser) / append_dims(scaled_timesteps, x.ndim) + samples = x + d * append_dims(scaled_timesteps_prev - scaled_timesteps, x.ndim) model_output = teacher_model(samples, scaled_timesteps_prev, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( - model_output, timesteps_prev, samples, use_noise=False + model_output, timestep_prev, samples, use_noise=False ).prev_sample - next_d = (samples - teacher_denoiser) / append_dims(scaled_timesteps_prev, x.ndims) - denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndims) + next_d = (samples - teacher_denoiser) / append_dims(scaled_timesteps_prev, x.ndim) + denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndim) # get output from target model - model_output = ema_model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample + model_output = model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample distiller_target = noise_scheduler.step( - model_output, timesteps_prev, denoised_image, use_noise=False + model_output, timestep_prev, denoised_image, use_noise=False ).prev_sample loss = F.mse_loss(distiller, distiller_target) # this could have different weights! From e56b870d2f3acd63dcccf4c770005a87247246d6 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Fri, 9 Jun 2023 23:08:57 +0530 Subject: [PATCH 44/72] Remove some args, add target model --- .../train_consistency_distillation.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index a591542a831c..c4712fa4efcc 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -235,16 +235,6 @@ def parse_args(): "and an Nvidia Ampere GPU." ), ) - parser.add_argument( - "--prediction_type", - type=str, - default="epsilon", - choices=["epsilon", "sample"], - help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", - ) - parser.add_argument("--ddpm_num_steps", type=int, default=1000) - parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000) - parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") parser.add_argument( "--checkpointing_steps", type=int, @@ -406,13 +396,34 @@ def load_model_hook(models, input_dir): ], resnet_time_scale_shift="scale_shift", + ) + target_model = UNet2DModel( + sample_size= args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + num_class_embeds=1000, + block_out_channels= [32, 64], + attention_head_dim=8, + down_block_types= [ + "ResnetDownsampleBlock2D", + "AttnDownsampleBlock2D", + ], + up_block_types= [ + "AttnUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + resnet_time_scale_shift="scale_shift", + ) else: config = UNet2DModel.load_config(args.model_config_name_or_path) model = UNet2DModel.from_config(config) + target_model = UNet2DModel.from_config(config) teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet model = model.double() + target_model = target_model.double() teacher_model = teacher_model.double() noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 @@ -503,8 +514,8 @@ def transform_images(examples): ) # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, lr_scheduler, teacher_model = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler, teacher_model + model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, ema_model = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, ema_model ) if args.use_ema: @@ -583,6 +594,7 @@ def transform_images(examples): noised_image = clean_images + noise*append_dims(timestep, clean_images.ndim) scaled_timesteps = noise_scheduler.scale_timestep(timestep) scaled_timesteps_prev = noise_scheduler.scale_timestep(timestep_prev) + ema_model.copy_to(target_model.parameters()) with accelerator.accumulate(model): # Predict the noise residual @@ -609,7 +621,7 @@ def transform_images(examples): denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndim) # get output from target model - model_output = model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample + model_output = target_model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample distiller_target = noise_scheduler.step( model_output, timestep_prev, denoised_image, use_noise=False ).prev_sample From 85eb7961d862c2ba0c5810bc798f1c84b3a126cd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 9 Jun 2023 18:54:22 -0700 Subject: [PATCH 45/72] Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging. --- .../pipeline_consistency_models.py | 36 ++++++---- .../scheduling_consistency_models.py | 69 +++++++++++++------ .../test_consistency_models.py | 12 ++-- 3 files changed, 79 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index e6f5d41379af..b6e4ef11e46e 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Union import torch +from torch.backends.cuda import sdp_kernel from ...models import UNet2DModel from ...schedulers import KarrasDiffusionSchedulers @@ -170,13 +171,13 @@ def prepare_class_labels(self, batch_size, device, class_labels=None): class_labels = None return class_labels - def check_inputs(self, num_inference_steps, latents, batch_size, img_size, callback_steps): - if self.scheduler.timesteps is not None and len(self.scheduler.timesteps) < num_inference_steps: - raise ValueError( - f"The scheduler's timestep schedule: {self.scheduler.timesteps} is shorter than num_inference_steps:" - " {num_inference_steps}, but is expected to be at least as long." - ) - + def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) if latents.shape != expected_shape: @@ -196,6 +197,7 @@ def __call__( batch_size: int = 1, class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, num_inference_steps: int = 40, + timesteps: List[int] = None, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -214,6 +216,9 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -244,7 +249,7 @@ def __call__( device = self._execution_device # 1. Check inputs - self.check_inputs(num_inference_steps, latents, batch_size, img_size, callback_steps) + self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) # 2. Prepare image latents # Sample image latents x_0 ~ N(0, sigma_0^2 * I) @@ -262,9 +267,14 @@ def __call__( # 3. Handle class_labels for class-conditional models class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) - # 4. Set timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -278,7 +288,9 @@ def __call__( for i, t in enumerate(timesteps): scaled_sample = self.scheduler.scale_model_input(sample, t) scaled_t = self.scheduler.scale_timestep(t) - model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample + # model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample + with sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample sample = self.scheduler.step( model_output, t, sample, use_noise=use_noise, **extra_step_kwargs diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 216c7e60b7e8..21400f8d8a97 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -100,14 +100,14 @@ def __init__( s_noise: float = 1.0, rho: float = 7.0, clip_denoised: bool = True, - timesteps: Optional[Union[List, np.ndarray, torch.Tensor]] = None, ): # standard deviation of the initial noise distribution self.init_noise_sigma = sigma_max # setable values self.num_inference_steps = None - self.timesteps = timesteps + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + self.custom_timesteps = False, self.is_scale_input_called = False def index_for_timestep(self, timestep, schedule_timesteps=None): @@ -157,7 +157,12 @@ def scale_timestep(self, timestep: Union[float, torch.FloatTensor]): return 1000 * 0.25 * torch.log(sigma + 1e-44) - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -167,34 +172,56 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps - - # Note: timesteps are expected to be increasing rather than decreasing, following original implementation - if self.timesteps is None: - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + + # Follow DDPMScheduler custom timesteps logic + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True else: - if isinstance(self.timesteps, list): - timesteps = np.array(self.timesteps, dtype=np.float32) - elif isinstance(self.timesteps, np.ndarray): - timesteps = self.timesteps.astype(np.float32) - else: - timesteps = self.timesteps.numpy().astype(np.float32) - timesteps = timesteps[: self.num_inference_steps] + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.custom_timesteps = False + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float64) # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 num_train_timesteps = self.config.num_train_timesteps # Append num_train_timesteps - 1 so sigmas[-1] == sigma_min - ramp = np.append(timesteps, [num_train_timesteps - 1]) / (num_train_timesteps - 1) + ramp = np.append(timesteps[::-1].copy(), [num_train_timesteps - 1]) + ramp = ramp / (num_train_timesteps - 1) sigmas = self._convert_to_karras(ramp) sigmas = sigmas.astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - self.timesteps = torch.from_numpy(timesteps).to(device=device) # Modified from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras # Use self.rho instead of hardcoded 7.0 for rho, sigma_min/max from config, configurable ramp function diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 75358c4f65f5..7a29f22bfdab 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -61,7 +61,6 @@ def get_dummy_components(self, class_cond=False): num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0, - timesteps=np.array([0, 22]), ) components = { @@ -79,7 +78,8 @@ def get_dummy_inputs(self, device, seed=0): inputs = { "batch_size": 1, - "num_inference_steps": 2, + "num_inference_steps": None, + "timesteps": [22, 0], "generator": generator, "output_type": "numpy", } @@ -127,6 +127,7 @@ def test_consistency_model_pipeline_onestep(self): inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None image = pipe(**inputs).images assert image.shape == (1, 32, 32, 3) @@ -144,6 +145,7 @@ def test_consistency_model_pipeline_onestep_class_cond(self): inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None image = pipe(**inputs).images assert image.shape == (1, 32, 32, 3) @@ -165,7 +167,8 @@ def get_inputs(self, seed=0): generator = torch.manual_seed(seed) inputs = { - "num_inference_steps": 2, + "num_inference_steps": None, + "timesteps": [22, 0], "class_labels": 0, "generator": generator, "output_type": "numpy", @@ -179,7 +182,6 @@ def test_consistency_model_cd_multistep(self): num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0, - timesteps=np.array([0, 22]), ) pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe.to(torch_device) @@ -200,7 +202,6 @@ def test_consistency_model_cd_onestep(self): num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0, - timesteps=np.array([0, 22]), ) pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) pipe.to(torch_device) @@ -208,6 +209,7 @@ def test_consistency_model_cd_onestep(self): inputs = self.get_inputs() inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None image = pipe(**inputs).images assert image.shape == (1, 64, 64, 3) From bf3a40566609619644422773335fcea92db5eb7d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 9 Jun 2023 18:56:27 -0700 Subject: [PATCH 46/72] make style --- .../consistency_models/pipeline_consistency_models.py | 4 ++-- src/diffusers/schedulers/scheduling_consistency_models.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index b6e4ef11e46e..b9afaba7e3df 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -174,10 +174,10 @@ def prepare_class_labels(self, batch_size, device, class_labels=None): def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): if num_inference_steps is None and timesteps is None: raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") - + if num_inference_steps is not None and timesteps is not None: raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") - + if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) if latents.shape != expected_shape: diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 21400f8d8a97..1ad03b5ce65b 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -107,7 +107,7 @@ def __init__( # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - self.custom_timesteps = False, + self.custom_timesteps = False self.is_scale_input_called = False def index_for_timestep(self, timestep, schedule_timesteps=None): @@ -174,10 +174,10 @@ def set_timesteps( """ if num_inference_steps is None and timesteps is None: raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") - + if num_inference_steps is not None and timesteps is not None: raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") - + # Follow DDPMScheduler custom timesteps logic if timesteps is not None: for i in range(1, len(timesteps)): @@ -205,7 +205,7 @@ def set_timesteps( step_ratio = self.config.num_train_timesteps // self.num_inference_steps timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False - + if str(device).startswith("mps"): # mps does not support float64 self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) From 1488180a4710e9cdcead5295ebd8ff1e4eb8443a Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Mon, 12 Jun 2023 17:59:03 +0530 Subject: [PATCH 47/72] attention weight loading fix --- .../train_consistency_distillation.py | 10 +++--- scripts/convert_consistency_to_diffusers.py | 34 ++++++------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index c4712fa4efcc..b664cbe8bb39 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -20,7 +20,7 @@ from tqdm.auto import tqdm import diffusers -from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, CMStochasticIterativeScheduler +from diffusers import DDPMPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available @@ -584,7 +584,7 @@ def transform_images(examples): # Sample noise that we'll add to the images noise = torch.randn(clean_images.shape).to(clean_images.device) bsz = clean_images.shape[0] - # Sample a random timestep for each image + # Sample a random timestep for each image, TODO - allow different timesteps in a batch index = torch.randint( 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=clean_images.device ).long() @@ -668,7 +668,7 @@ def transform_images(examples): ema_model.store(unet.parameters()) ema_model.copy_to(unet.parameters()) - pipeline = DDPMPipeline( + pipeline = ConsistencyModelPipeline( unet=unet, scheduler=noise_scheduler, ) @@ -678,7 +678,7 @@ def transform_images(examples): images = pipeline( generator=generator, batch_size=args.eval_batch_size, - num_inference_steps=args.ddpm_num_inference_steps, + num_inference_steps=1, output_type="numpy", ).images @@ -709,7 +709,7 @@ def transform_images(examples): ema_model.store(unet.parameters()) ema_model.copy_to(unet.parameters()) - pipeline = DDPMPipeline( + pipeline = ConsistencyModelPipeline( unet=unet, scheduler=noise_scheduler, ) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 6a8e8eb938e0..923b72815e66 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -81,35 +81,21 @@ def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip= return new_checkpoint -def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim=64): - c, _, _, _ = checkpoint[f"{old_prefix}.qkv.weight"].shape - n_heads = c // (attention_head_dim * 3) - old_weights = checkpoint[f"{old_prefix}.qkv.weight"].reshape(n_heads, attention_head_dim * 3, -1, 1, 1) - old_biases = checkpoint[f"{old_prefix}.qkv.bias"].reshape(n_heads, attention_head_dim * 3, -1, 1, 1) - - weight_q, weight_k, weight_v = old_weights.chunk(3, dim=1) - weight_q = weight_q.reshape(n_heads * attention_head_dim, -1, 1, 1) - weight_k = weight_k.reshape(n_heads * attention_head_dim, -1, 1, 1) - weight_v = weight_v.reshape(n_heads * attention_head_dim, -1, 1, 1) - - bias_q, bias_k, bias_v = old_biases.chunk(3, dim=1) - bias_q = bias_q.reshape(n_heads * attention_head_dim, -1, 1, 1) - bias_k = bias_k.reshape(n_heads * attention_head_dim, -1, 1, 1) - bias_v = bias_v.reshape(n_heads * attention_head_dim, -1, 1, 1) +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] - new_checkpoint[f"{new_prefix}.to_q.weight"] = torch.squeeze(weight_q) - new_checkpoint[f"{new_prefix}.to_q.bias"] = torch.squeeze(bias_q) - new_checkpoint[f"{new_prefix}.to_k.weight"] = torch.squeeze(weight_k) - new_checkpoint[f"{new_prefix}.to_k.bias"] = torch.squeeze(bias_k) - new_checkpoint[f"{new_prefix}.to_v.weight"] = torch.squeeze(weight_v) - new_checkpoint[f"{new_prefix}.to_v.bias"] = torch.squeeze(bias_v) + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( - checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) - ) + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) return new_checkpoint From fd0a2535848814b3b8175a165dbb4a4c2972d695 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jun 2023 20:58:53 -0700 Subject: [PATCH 48/72] Convert current slow tests to use fp16 and flash attention. --- .../pipeline_consistency_models.py | 5 +- .../test_consistency_models.py | 52 +++++++++++++------ 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index b9afaba7e3df..c622407b5dbb 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -2,7 +2,6 @@ from typing import Callable, List, Optional, Union import torch -from torch.backends.cuda import sdp_kernel from ...models import UNet2DModel from ...schedulers import KarrasDiffusionSchedulers @@ -288,9 +287,7 @@ def __call__( for i, t in enumerate(timesteps): scaled_sample = self.scheduler.scale_model_input(sample, t) scaled_t = self.scheduler.scale_timestep(t) - # model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample - with sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): - model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample + model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample sample = self.scheduler.step( model_output, t, sample, use_noise=use_noise, **extra_step_kwargs diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 7a29f22bfdab..43c2ab5e45f2 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -3,14 +3,15 @@ import numpy as np import torch +from torch.backends.cuda import sdp_kernel from diffusers import ( CMStochasticIterativeScheduler, ConsistencyModelPipeline, UNet2DModel, ) -from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils import slow, torch_device, randn_tensor +from diffusers.utils.testing_utils import require_torch_2, require_torch_gpu from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -163,7 +164,7 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def get_inputs(self, seed=0): + def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): generator = torch.manual_seed(seed) inputs = { @@ -174,9 +175,21 @@ def get_inputs(self, seed=0): "output_type": "numpy", } - return inputs + if get_fixed_latents: + latents = self.get_fixed_latents(seed=seed, device=device, dtype=dtype, shape=shape) + inputs["latents"] = latents - def test_consistency_model_cd_multistep(self): + return inputs + + def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): + if type(device) == str: + device = torch.device(device) + generator = torch.Generator(device=device).manual_seed(seed) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @require_torch_2 + def test_consistency_model_cd_multistep_flash_attn(self): unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") scheduler = CMStochasticIterativeScheduler( num_train_timesteps=40, @@ -184,19 +197,22 @@ def test_consistency_model_cd_multistep(self): sigma_max=80.0, ) pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) - pipe.to(torch_device) + pipe.to(torch_device=torch_device, torch_dtype=torch.float16) pipe.set_progress_bar_config(disable=None) - inputs = self.get_inputs() - image = pipe(**inputs).images + inputs = self.get_inputs(get_fixed_latents=True, device=torch_device) + # Ensure usage of flash attention in torch 2.0 + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + image = pipe(**inputs).images assert image.shape == (1, 64, 64, 3) image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.2645, 0.3386, 0.1928, 0.1284, 0.1215, 0.0285, 0.0800, 0.1213, 0.3331]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 4e-3 + expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]) - def test_consistency_model_cd_onestep(self): + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + @require_torch_2 + def test_consistency_model_cd_onestep_flash_attn(self): unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") scheduler = CMStochasticIterativeScheduler( num_train_timesteps=40, @@ -204,16 +220,18 @@ def test_consistency_model_cd_onestep(self): sigma_max=80.0, ) pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) - pipe.to(torch_device) + pipe.to(torch_device=torch_device, torch_dtype=torch.float16) pipe.set_progress_bar_config(disable=None) - inputs = self.get_inputs() + inputs = self.get_inputs(get_fixed_latents=True, device=torch_device) inputs["num_inference_steps"] = 1 inputs["timesteps"] = None - image = pipe(**inputs).images + # Ensure usage of flash attention in torch 2.0 + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + image = pipe(**inputs).images assert image.shape == (1, 64, 64, 3) image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.2480, 0.1257, 0.0852, 0.2474, 0.3226, 0.1637, 0.3169, 0.2660, 0.3875]) + expected_slice = np.array([0.1623, 0.2009, 0.2387, 0.1731, 0.1168, 0.1202, 0.2031, 0.1327, 0.2447]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 4e-3 + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From 7c03ff6d123cb6703771df69af77101755c3f7f0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jun 2023 21:01:44 -0700 Subject: [PATCH 49/72] make style --- .../consistency_models/test_consistency_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 43c2ab5e45f2..8bcc33a12078 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -10,7 +10,7 @@ ConsistencyModelPipeline, UNet2DModel, ) -from diffusers.utils import slow, torch_device, randn_tensor +from diffusers.utils import randn_tensor, slow, torch_device from diffusers.utils.testing_utils import require_torch_2, require_torch_gpu from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS @@ -180,14 +180,14 @@ def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch. inputs["latents"] = latents return inputs - + def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): if type(device) == str: device = torch.device(device) generator = torch.Generator(device=device).manual_seed(seed) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents - + @require_torch_2 def test_consistency_model_cd_multistep_flash_attn(self): unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") @@ -210,7 +210,7 @@ def test_consistency_model_cd_multistep_flash_attn(self): expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - + @require_torch_2 def test_consistency_model_cd_onestep_flash_attn(self): unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") From 80d7745243fbf4c916a54eaaaf7af12cc2100060 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jun 2023 22:49:11 -0700 Subject: [PATCH 50/72] Add slow tests for normal attention on cuda device. --- .../test_consistency_models.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 8bcc33a12078..551913ce22f0 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -187,6 +187,48 @@ def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, generator = torch.Generator(device=device).manual_seed(seed) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents + + def test_consistency_model_cd_multistep(self): + unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0059, 0.0003, 0.0000, 0.0023, 0.0052, 0.0007, 0.0165, 0.0081, 0.0095]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 + + def test_consistency_model_cd_onestep(self): + unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") + scheduler = CMStochasticIterativeScheduler( + num_train_timesteps=40, + sigma_min=0.002, + sigma_max=80.0, + ) + pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device=torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs() + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = None + image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0146, 0.0158, 0.0092, 0.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0058]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 @require_torch_2 def test_consistency_model_cd_multistep_flash_attn(self): From 0662f63f2e98e3a273689db78541f84deda02dc3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jun 2023 22:50:31 -0700 Subject: [PATCH 51/72] make style --- tests/pipelines/consistency_models/test_consistency_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 551913ce22f0..1c2e9acf624a 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -187,7 +187,7 @@ def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, generator = torch.Generator(device=device).manual_seed(seed) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents - + def test_consistency_model_cd_multistep(self): unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2") scheduler = CMStochasticIterativeScheduler( From cce04c09b343976119e6837c2717565741bff6e4 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Tue, 13 Jun 2023 11:24:22 +0530 Subject: [PATCH 52/72] Fix attention weights loading --- scripts/convert_consistency_to_diffusers.py | 34 ++++++--------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 6a8e8eb938e0..923b72815e66 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -81,35 +81,21 @@ def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip= return new_checkpoint -def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim=64): - c, _, _, _ = checkpoint[f"{old_prefix}.qkv.weight"].shape - n_heads = c // (attention_head_dim * 3) - old_weights = checkpoint[f"{old_prefix}.qkv.weight"].reshape(n_heads, attention_head_dim * 3, -1, 1, 1) - old_biases = checkpoint[f"{old_prefix}.qkv.bias"].reshape(n_heads, attention_head_dim * 3, -1, 1, 1) - - weight_q, weight_k, weight_v = old_weights.chunk(3, dim=1) - weight_q = weight_q.reshape(n_heads * attention_head_dim, -1, 1, 1) - weight_k = weight_k.reshape(n_heads * attention_head_dim, -1, 1, 1) - weight_v = weight_v.reshape(n_heads * attention_head_dim, -1, 1, 1) - - bias_q, bias_k, bias_v = old_biases.chunk(3, dim=1) - bias_q = bias_q.reshape(n_heads * attention_head_dim, -1, 1, 1) - bias_k = bias_k.reshape(n_heads * attention_head_dim, -1, 1, 1) - bias_v = bias_v.reshape(n_heads * attention_head_dim, -1, 1, 1) +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] - new_checkpoint[f"{new_prefix}.to_q.weight"] = torch.squeeze(weight_q) - new_checkpoint[f"{new_prefix}.to_q.bias"] = torch.squeeze(bias_q) - new_checkpoint[f"{new_prefix}.to_k.weight"] = torch.squeeze(weight_k) - new_checkpoint[f"{new_prefix}.to_k.bias"] = torch.squeeze(bias_k) - new_checkpoint[f"{new_prefix}.to_v.weight"] = torch.squeeze(weight_v) - new_checkpoint[f"{new_prefix}.to_v.bias"] = torch.squeeze(bias_v) + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( - checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) - ) + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) return new_checkpoint From 62a49a2fb70f2c18cd992a726463516546c611da Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jun 2023 23:33:08 -0700 Subject: [PATCH 53/72] Update consistency model fast tests for new test checkpoints with attention fix. --- tests/pipelines/consistency_models/test_consistency_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 1c2e9acf624a..2277d85e0b6e 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -111,6 +111,7 @@ def test_consistency_model_pipeline_multistep_class_cond(self): pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) + inputs["class_labels"] = 0 image = pipe(**inputs).images assert image.shape == (1, 32, 32, 3) @@ -147,6 +148,7 @@ def test_consistency_model_pipeline_onestep_class_cond(self): inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 1 inputs["timesteps"] = None + inputs["class_labels"] = 0 image = pipe(**inputs).images assert image.shape == (1, 32, 32, 3) From 58f12cad9abccc28844d517553c4d9c2e1e23569 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jun 2023 23:37:00 -0700 Subject: [PATCH 54/72] make style --- scripts/convert_consistency_to_diffusers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 923b72815e66..e47e4a572259 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -95,7 +95,9 @@ def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attent new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( + checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + ) new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) return new_checkpoint From 98e13816d91ed7b78b2e7386248ad3c4494166a2 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Thu, 15 Jun 2023 11:27:36 +0530 Subject: [PATCH 55/72] Renaming ema model to target --- .../train_consistency_distillation.py | 73 +++++++------------ 1 file changed, 28 insertions(+), 45 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index b664cbe8bb39..375ac4ab0466 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -185,11 +185,6 @@ def parse_args(): "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer." ) parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.") - parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to use Exponential Moving Average for the final model weights.", - ) parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") @@ -314,8 +309,7 @@ def main(args): if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): - if args.use_ema: - ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) + target_model_ema.save_pretrained(os.path.join(output_dir, "unet_ema")) for i, model in enumerate(models): model.save_pretrained(os.path.join(output_dir, "unet")) @@ -324,11 +318,10 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): - if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) - ema_model.load_state_dict(load_model.state_dict()) - ema_model.to(accelerator.device) - del load_model + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) + target_model_ema.load_state_dict(load_model.state_dict()) + target_model_ema.to(accelerator.device) + del load_model for i in range(len(models)): # pop models so that they are not loaded again @@ -429,20 +422,18 @@ def load_model_hook(models, input_dir): num_scales = 40 noise_scheduler.set_timesteps(num_scales) timesteps = noise_scheduler.timesteps - # print(teacher_model) # Create EMA for the model, this is the target model in the paper - if args.use_ema: - ema_model = EMAModel( - model.parameters(), - decay=args.ema_max_decay, - use_ema_warmup=True, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - model_cls=UNet2DModel, - model_config=model.config, - ) + target_model_ema = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + model_cls=UNet2DModel, + model_config=model.config, + ) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -514,12 +505,11 @@ def transform_images(examples): ) # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, ema_model = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, ema_model + model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema ) - if args.use_ema: - ema_model.to(accelerator.device) + target_model_ema.to(accelerator.device) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -583,7 +573,6 @@ def transform_images(examples): labels = batch["labels"] # Sample noise that we'll add to the images noise = torch.randn(clean_images.shape).to(clean_images.device) - bsz = clean_images.shape[0] # Sample a random timestep for each image, TODO - allow different timesteps in a batch index = torch.randint( 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=clean_images.device @@ -594,7 +583,7 @@ def transform_images(examples): noised_image = clean_images + noise*append_dims(timestep, clean_images.ndim) scaled_timesteps = noise_scheduler.scale_timestep(timestep) scaled_timesteps_prev = noise_scheduler.scale_timestep(timestep_prev) - ema_model.copy_to(target_model.parameters()) + target_model_ema.copy_to(target_model.parameters()) with accelerator.accumulate(model): # Predict the noise residual @@ -603,7 +592,7 @@ def transform_images(examples): model_output, timestep, noised_image, use_noise=False ).prev_sample - # Heun Solver to get previous timestep image + # Heun Solver to get previous timestep image using teacher model samples = noised_image x = samples model_output = teacher_model(x, scaled_timesteps, class_labels=labels).sample @@ -626,7 +615,7 @@ def transform_images(examples): model_output, timestep_prev, denoised_image, use_noise=False ).prev_sample - loss = F.mse_loss(distiller, distiller_target) # this could have different weights! + loss = F.mse_loss(distiller, distiller_target) loss = loss.mean() accelerator.backward(loss) @@ -639,8 +628,7 @@ def transform_images(examples): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - if args.use_ema: - ema_model.step(model.parameters()) + target_model_ema.step(model.parameters()) progress_bar.update(1) global_step += 1 @@ -651,8 +639,7 @@ def transform_images(examples): logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} - if args.use_ema: - logs["ema_decay"] = ema_model.cur_decay_value + logs["ema_decay"] = target_model_ema.cur_decay_value progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) progress_bar.close() @@ -664,9 +651,8 @@ def transform_images(examples): if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: unet = accelerator.unwrap_model(model) - if args.use_ema: - ema_model.store(unet.parameters()) - ema_model.copy_to(unet.parameters()) + target_model_ema.store(unet.parameters()) + target_model_ema.copy_to(unet.parameters()) pipeline = ConsistencyModelPipeline( unet=unet, @@ -682,8 +668,7 @@ def transform_images(examples): output_type="numpy", ).images - if args.use_ema: - ema_model.restore(unet.parameters()) + target_model_ema.restore(unet.parameters()) # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") @@ -705,9 +690,8 @@ def transform_images(examples): # save the model unet = accelerator.unwrap_model(model) - if args.use_ema: - ema_model.store(unet.parameters()) - ema_model.copy_to(unet.parameters()) + target_model_ema.store(unet.parameters()) + target_model_ema.copy_to(unet.parameters()) pipeline = ConsistencyModelPipeline( unet=unet, @@ -716,8 +700,7 @@ def transform_images(examples): pipeline.save_pretrained(args.output_dir) - if args.use_ema: - ema_model.restore(unet.parameters()) + target_model_ema.restore(unet.parameters()) if args.push_to_hub: repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) From baefc87c8a305ace0e26ae9a8a51bbb14efd7bf8 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Thu, 15 Jun 2023 11:39:59 +0530 Subject: [PATCH 56/72] Add some comments --- .../train_consistency_distillation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 375ac4ab0466..a6d131bf7290 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -235,7 +235,7 @@ def parse_args(): type=int, default=500, help=( - "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" " training using `--resume_from_checkpoint`." ), ) @@ -369,7 +369,7 @@ def load_model_hook(models, input_dir): elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # Initialize the model + # Initialize the model, using a smaller model than the one defined in the original paper by default if args.model_config_name_or_path is None: model = UNet2DModel( sample_size= args.resolution, @@ -414,9 +414,10 @@ def load_model_hook(models, input_dir): model = UNet2DModel.from_config(config) target_model = UNet2DModel.from_config(config) + # load the model to distill into a consistency model teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet model = model.double() - target_model = target_model.double() + target_model = target_model.double() # TODO : support half precision training teacher_model = teacher_model.double() noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 @@ -579,7 +580,7 @@ def transform_images(examples): ).long() timestep = timesteps[index] timestep_prev = timestep + 1 - # TO-DO, we should have an add noise in the scheduler maybe? + # TODO, we should have an add noise in the scheduler maybe? noised_image = clean_images + noise*append_dims(timestep, clean_images.ndim) scaled_timesteps = noise_scheduler.scale_timestep(timestep) scaled_timesteps_prev = noise_scheduler.scale_timestep(timestep_prev) @@ -593,6 +594,7 @@ def transform_images(examples): ).prev_sample # Heun Solver to get previous timestep image using teacher model + # TODO - make this cleaner samples = noised_image x = samples model_output = teacher_model(x, scaled_timesteps, class_labels=labels).sample From 071f8506247f99db3ff01b7a38131e0f8415d12d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 15 Jun 2023 17:59:13 -0700 Subject: [PATCH 57/72] apply suggestions --- .../pipeline_consistency_models.py | 8 -------- .../schedulers/scheduling_consistency_models.py | 17 ++--------------- 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index c622407b5dbb..ba9dcbc2dff4 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -41,8 +41,6 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasDiffusionSchedulers) -> N self.safety_checker = None - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - # Modified to only offload self.unet def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, @@ -68,8 +66,6 @@ def enable_sequential_cpu_offload(self, gpu_id=0): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - # Modified to only offload self.unet def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -135,8 +131,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - # Unlike stable diffusion, no VAE so no vae_scale_factor, num_channels_latent => num_channels def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels, height, width) if isinstance(generator, list) and len(generator) != batch_size: @@ -293,8 +287,6 @@ def __call__( model_output, t, sample, use_noise=use_noise, **extra_step_kwargs ).prev_sample - # Note: differs from callback support in original code - # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459 # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 1ad03b5ce65b..e19788e4b13d 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,18 +26,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") - return x[(...,) + (None,) * dims_to_append] - - @dataclass class CMStochasticIterativeSchedulerOutput(BaseOutput): """ @@ -223,8 +211,7 @@ def set_timesteps( sigmas = sigmas.astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) - # Modified from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras - # Use self.rho instead of hardcoded 7.0 for rho, sigma_min/max from config, configurable ramp function + # Modified _convert_to_karras implementation that takes in ramp as argument def _convert_to_karras(self, ramp): """Constructs the noise schedule of Karras et al. (2022).""" From cd460ca2f5107ca6053e27c6e92f0aee2248627c Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Sat, 17 Jun 2023 09:10:42 +0530 Subject: [PATCH 58/72] Remove xformers, refactor ckpt resuming --- examples/consistency_models/script.sh | 2 +- .../train_consistency_distillation.py | 40 ++++++------------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/examples/consistency_models/script.sh b/examples/consistency_models/script.sh index 5b5f6b4710e9..dfe6bcc5b7a0 100644 --- a/examples/consistency_models/script.sh +++ b/examples/consistency_models/script.sh @@ -1,3 +1,3 @@ #!/bin/bash -accelerate launch train_consistency_distillation.py --dataset_name="huggan/flowers-102-categories" --resolution=64 --center_crop --random_flip --output_dir="ddpm-ema-flowers-64" --train_batch_size=16 --num_epochs=100 --gradient_accumulation_steps=1 --use_ema --learning_rate=1e-4 --lr_warmup_steps=500 --mixed_precision=no --push_to_hub \ No newline at end of file +accelerate launch train_consistency_distillation.py --dataset_name="cifar10" --resolution=32 --center_crop --random_flip --output_dir="cifar10-32" --train_batch_size=16 --num_epochs=100 --gradient_accumulation_steps=1 --learning_rate=1e-4 --lr_warmup_steps=500 --mixed_precision=no --push_to_hub \ No newline at end of file diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index a6d131bf7290..2ffcd729887e 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -24,7 +24,6 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available #Copied from examples/unconditional_image_generation/train_unconditional.py for now @@ -258,9 +257,6 @@ def parse_args(): ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." - ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -436,19 +432,6 @@ def load_model_hook(models, input_dir): model_config=model.config, ) - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - model.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - # Initialize the optimizer optimizer = torch.optim.AdamW( model.parameters(), @@ -554,22 +537,25 @@ def transform_images(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # Train! for epoch in range(first_epoch, args.num_epochs): model.train() - progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) - progress_bar.set_description(f"Epoch {epoch}") for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - clean_images = batch["input"] labels = batch["labels"] # Sample noise that we'll add to the images From 6968615482e94aceee912a14bc5270f2ca9434db Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 17 Jun 2023 18:54:18 -0700 Subject: [PATCH 59/72] Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler). --- .../scheduling_consistency_models.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index e19788e4b13d..dab428208a03 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -340,5 +340,31 @@ def step( return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + def __len__(self): return self.config.num_train_timesteps From 9a816425914dd6db6e7ab4baba9dbbcdcf9a75e4 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Mon, 19 Jun 2023 12:11:56 +0530 Subject: [PATCH 60/72] Add input scaling, disable gradients --- .../train_consistency_distillation.py | 59 ++++++++++--------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 2ffcd729887e..99f7d06fa097 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -412,9 +412,9 @@ def load_model_hook(models, input_dir): # load the model to distill into a consistency model teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet - model = model.double() - target_model = target_model.double() # TODO : support half precision training - teacher_model = teacher_model.double() + model = model.float() + target_model = target_model.float() # TODO : support half precision training + teacher_model = teacher_model.float() noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 noise_scheduler.set_timesteps(num_scales) @@ -566,42 +566,43 @@ def transform_images(examples): ).long() timestep = timesteps[index] timestep_prev = timestep + 1 - # TODO, we should have an add noise in the scheduler maybe? - noised_image = clean_images + noise*append_dims(timestep, clean_images.ndim) + noised_image = noise_scheduler.add_noise(clean_images, noise, timestep) scaled_timesteps = noise_scheduler.scale_timestep(timestep) scaled_timesteps_prev = noise_scheduler.scale_timestep(timestep_prev) target_model_ema.copy_to(target_model.parameters()) with accelerator.accumulate(model): # Predict the noise residual - model_output = model(noised_image, scaled_timesteps, class_labels=labels).sample + + model_output = model(noise_scheduler.scale_model_input(noised_image, timestep), scaled_timesteps, class_labels=labels).sample distiller = noise_scheduler.step( model_output, timestep, noised_image, use_noise=False ).prev_sample - # Heun Solver to get previous timestep image using teacher model - # TODO - make this cleaner - samples = noised_image - x = samples - model_output = teacher_model(x, scaled_timesteps, class_labels=labels).sample - teacher_denoiser = noise_scheduler.step( - model_output, timestep, x, use_noise=False - ).prev_sample - d = (x - teacher_denoiser) / append_dims(scaled_timesteps, x.ndim) - samples = x + d * append_dims(scaled_timesteps_prev - scaled_timesteps, x.ndim) - model_output = teacher_model(samples, scaled_timesteps_prev, class_labels=labels).sample - teacher_denoiser = noise_scheduler.step( - model_output, timestep_prev, samples, use_noise=False - ).prev_sample - - next_d = (samples - teacher_denoiser) / append_dims(scaled_timesteps_prev, x.ndim) - denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndim) - - # get output from target model - model_output = target_model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample - distiller_target = noise_scheduler.step( - model_output, timestep_prev, denoised_image, use_noise=False - ).prev_sample + with torch.no_grad(): + # Heun Solver to get previous timestep image using teacher model + # TODO - make this cleaner + samples = noised_image + x = samples + model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), scaled_timesteps, class_labels=labels).sample + teacher_denoiser = noise_scheduler.step( + model_output, timestep, x, use_noise=False + ).prev_sample + d = (x - teacher_denoiser) / append_dims(scaled_timesteps, x.ndim) + samples = x + d * append_dims(scaled_timesteps_prev - scaled_timesteps, x.ndim) + model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), scaled_timesteps_prev, class_labels=labels).sample + teacher_denoiser = noise_scheduler.step( + model_output, timestep_prev, samples, use_noise=False + ).prev_sample + + next_d = (samples - teacher_denoiser) / append_dims(scaled_timesteps_prev, x.ndim) + denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndim) + + # get output from target model + model_output = target_model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample + distiller_target = noise_scheduler.step( + model_output, timestep_prev, denoised_image, use_noise=False + ).prev_sample loss = F.mse_loss(distiller, distiller_target) loss = loss.mean() From ca15734d907996d1562b3ffb0f52ea988fc362cc Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 21 Jun 2023 03:21:46 -0700 Subject: [PATCH 61/72] Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers. --- scripts/convert_consistency_to_diffusers.py | 112 +++++++++++++++++--- 1 file changed, 95 insertions(+), 17 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index e47e4a572259..545e55a342aa 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -1,8 +1,13 @@ import argparse +import os import torch -from diffusers.models.unet_2d import UNet2DModel +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + UNet2DModel, +) TEST_UNET_CONFIG = { @@ -47,6 +52,51 @@ "resnet_time_scale_shift": "scale_shift", } +LSUN_256_UNET_CONFIG = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": None, + "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownsampleBlock2D", + "AttnDownsampleBlock2D", + "AttnDownsampleBlock2D", + ], + "up_block_types": [ + "AttnUpsampleBlock2D", + "AttnUpsampleBlock2D", + "AttnUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "default", +} + +CD_SCHEDULER_CONFIG = { + "num_train_timesteps": 40, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_IMAGENET_64_SCHEDULER_CONFIG = { + "num_train_timesteps": 201, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_LSUN_256_SCHEDULER_CONFIG = { + "num_train_timesteps": 151, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + def str2bool(v): """ @@ -121,22 +171,27 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): down_block_types = unet_config["down_block_types"] layers_per_block = unet_config["layers_per_block"] attention_head_dim = unet_config["attention_head_dim"] + channels_list = unet_config["block_out_channels"] current_layer = 1 + prev_channels = channels_list[0] for i, layer_type in enumerate(down_block_types): + current_channels = channels_list[i] + downsample_block_has_skip = current_channels != prev_channels if layer_type == "ResnetDownsampleBlock2D": for j in range(layers_per_block): new_prefix = f"down_blocks.{i}.resnets.{j}" old_prefix = f"input_blocks.{current_layer}.0" - new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) current_layer += 1 elif layer_type == "AttnDownsampleBlock2D": for j in range(layers_per_block): new_prefix = f"down_blocks.{i}.resnets.{j}" old_prefix = f"input_blocks.{current_layer}.0" - has_skip = True if j == 0 else False - new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip) + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) new_prefix = f"down_blocks.{i}.attentions.{j}" old_prefix = f"input_blocks.{current_layer}.1" new_checkpoint = convert_attention( @@ -149,6 +204,8 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): old_prefix = f"input_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) current_layer += 1 + + prev_channels = current_channels # hardcoded the mid-block for now new_prefix = "mid_block.resnets.0" @@ -171,6 +228,11 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): old_prefix = f"output_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.1" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) elif layer_type == "AttnUpsampleBlock2D": for j in range(layers_per_block + 1): new_prefix = f"up_blocks.{i}.resnets.{j}" @@ -182,12 +244,11 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim ) current_layer += 1 - - new_prefix = f"up_blocks.{i}.upsamplers.0" - old_prefix = f"output_blocks.{current_layer-1}.2" - # print(new_prefix) - # print(old_prefix) - new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] @@ -204,18 +265,23 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): parser.add_argument( "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." ) - parser.add_argument("--checkpoint_name", default="cd_imagenet64_l2", type=str, help="Checkpoint to convert.") parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.") args = parser.parse_args() args.class_cond = str2bool(args.class_cond) - if "imagenet64" in args.checkpoint_name: + ckpt_name = os.path.basename(args.unet_path) + print(f"Checkpoint: {ckpt_name}") + + # Get U-Net config + if "imagenet64" in ckpt_name: unet_config = IMAGENET_64_UNET_CONFIG - elif "test" in args.checkpoint_name: + elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + unet_config = LSUN_256_UNET_CONFIG + elif "test" in ckpt_name: unet_config = TEST_UNET_CONFIG else: - raise ValueError(f"Checkpoint type {args.checkpoint_name} is not currently supported.") + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") if not args.class_cond: unet_config["num_class_embeds"] = None @@ -223,7 +289,19 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) image_unet = UNet2DModel(**unet_config) - # print(image_unet) - # exit() image_unet.load_state_dict(converted_unet_ckpt) - image_unet.save_pretrained(args.dump_path) + + # Get scheduler config + if "cd" in ckpt_name or "test" in ckpt_name: + scheduler_config = CD_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "imagenet64" in ckpt_name: + scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config) + + consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler) + consistency_model.save_pretrained(args.dump_path) From a56d3d28c77a2a43e4bf8e21fd7c8132ad5f9383 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 21 Jun 2023 03:22:40 -0700 Subject: [PATCH 62/72] When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence). --- .../consistency_models/pipeline_consistency_models.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index ba9dcbc2dff4..7e17761e4e00 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -8,11 +8,15 @@ from ...utils import ( is_accelerate_available, is_accelerate_version, + logging, randn_tensor, ) from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class ConsistencyModelPipeline(DiffusionPipeline): r""" Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1]. @@ -169,7 +173,10 @@ def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_ raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") if num_inference_steps is not None and timesteps is not None: - raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + logger.warning( + f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" + " `timesteps` will be used over `num_inference_steps`." + ) if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) From a17847ee4032b78b3bcc8f632f5b31c297a3216f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 21 Jun 2023 03:24:52 -0700 Subject: [PATCH 63/72] make style --- scripts/convert_consistency_to_diffusers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index 545e55a342aa..12165be88251 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -204,7 +204,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): old_prefix = f"input_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) current_layer += 1 - + prev_channels = current_channels # hardcoded the mid-block for now @@ -228,7 +228,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): old_prefix = f"output_blocks.{current_layer}.0" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) current_layer += 1 - + if i != len(up_block_types) - 1: new_prefix = f"up_blocks.{i}.upsamplers.0" old_prefix = f"output_blocks.{current_layer-1}.1" @@ -244,7 +244,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim ) current_layer += 1 - + if i != len(up_block_types) - 1: new_prefix = f"up_blocks.{i}.upsamplers.0" old_prefix = f"output_blocks.{current_layer-1}.2" @@ -300,7 +300,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG else: raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") - + cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config) consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler) From 8214a33eb59f9d48ec010aceca9f206079e01bb7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 21 Jun 2023 05:01:49 -0700 Subject: [PATCH 64/72] Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example. --- docs/source/en/api/pipelines/consistency_models.mdx | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/consistency_models.mdx b/docs/source/en/api/pipelines/consistency_models.mdx index 9707cc910359..6a2f174e64a0 100644 --- a/docs/source/en/api/pipelines/consistency_models.mdx +++ b/docs/source/en/api/pipelines/consistency_models.mdx @@ -13,7 +13,14 @@ Resources: Available Checkpoints are: - *cd_imagenet64_l2 (64x64 resolution)* [dg845/consistency-model-pipelines](https://huggingface.co/dg845/consistency-model-pipelines) -- TODO: add more checkpoints from original release? +- *cd_imagenet64_lpips (64x64 resolution)* [dg845/diffusers-cd_imagenet64_lpips](https://huggingface.co/dg845/diffusers-cd_imagenet64_lpips) +- *ct_imagenet64 (64x64 resolution)* [dg845/diffusers-ct_imagenet64](https://huggingface.co/dg845/diffusers-ct_imagenet64) +- *cd_bedroom256_l2 (256x256 resolution)* [dg845/diffusers-cd_bedroom256_l2](https://huggingface.co/dg845/diffusers-cd_bedroom256_l2) +- *cd_bedroom256_lpips (256x256 resolution)* [dg845/diffusers-cd_bedroom256_lpips](https://huggingface.co/dg845/diffusers-cd_bedroom256_lpips) +- *ct_bedroom256 (256x256 resolution)* [dg845/diffusers-ct_bedroom256](https://huggingface.co/dg845/diffusers-ct_bedroom256) +- *cd_cat256_l2 (256x256 resolution)* [dg845/diffusers-cd_cat256_l2](https://huggingface.co/dg845/diffusers-cd_cat256_l2) +- *cd_cat256_lpips (256x256 resolution)* [dg845/diffusers-cd_cat256_lpips](https://huggingface.co/dg845/diffusers-cd_cat256_lpips) +- *ct_cat256 (256x256 resolution)* [dg845/diffusers-ct_cat256](https://huggingface.co/dg845/diffusers-ct_cat256) ## Available Pipelines @@ -44,7 +51,9 @@ image = pipe(num_inference_steps=1, class_labels=145).images[0] image.save("consistency_model_onestep_sample_penguin.png") # Multistep sampling, class-conditional image generation -image = pipe(num_inference_steps=2, class_labels=145).images[0] +# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo. +# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 +image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] image.save("consistency_model_multistep_sample_penguin.png") ``` From 2606d82eb86889d7648cd4c203b1aca3e5c17c46 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 28 Jun 2023 19:48:59 -0700 Subject: [PATCH 65/72] apply suggestions from review --- .../pipeline_consistency_models.py | 94 +++++++++---------- .../scheduling_consistency_models.py | 38 ++++---- .../test_consistency_models.py | 4 +- 3 files changed, 66 insertions(+), 70 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 7e17761e4e00..0b9ce6e61d59 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -6,6 +6,7 @@ from ...models import UNet2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -117,24 +118,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels, height, width) if isinstance(generator, list) and len(generator) != batch_size: @@ -151,6 +134,34 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess_image(self, latents: torch.FloatTensor, output_type: str = "pil"): + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + # Return latents without modification + return latents + + # Equivalent to diffusers.VaeImageProcessor.denormalize + latents = (latents / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return latents + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + latents = latents.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return latents + + # Output_type must be 'pil' + latents = self.numpy_to_pil(latents) + return latents def prepare_class_labels(self, batch_size, device, class_labels=None): if self.unet.config.num_class_embeds is not None: @@ -196,9 +207,8 @@ def __call__( self, batch_size: int = 1, class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, - num_inference_steps: int = 40, + num_inference_steps: int = 1, timesteps: List[int] = None, - eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -213,15 +223,12 @@ def __call__( class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): Optional class labels for conditioning class-conditional consistency models. Will not be used if the model is not class-conditional. - num_inference_steps (`int`, *optional*, defaults to 40): + num_inference_steps (`int`, *optional*, defaults to 1): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -255,7 +262,7 @@ def __call__( # Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = self.prepare_latents( batch_size=batch_size, - num_channels=3, + num_channels=self.unet.config.in_channels, height=img_size, width=img_size, dtype=self.unet.dtype, @@ -276,42 +283,31 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps - # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 6. Denoising loop + # 5. Denoising loop # Multistep sampling: implements Algorithm 1 in the paper - # Onestep sampling does not use random noise - use_noise = False if num_inference_steps == 1 else True - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + use_noise = False if num_inference_steps == 1 else True # Onestep sampling does not use random noise with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): scaled_sample = self.scheduler.scale_model_input(sample, t) - scaled_t = self.scheduler.scale_timestep(t) - model_output = self.unet(scaled_sample, scaled_t, class_labels=class_labels).sample + model_output = self.unet(scaled_sample, t, class_labels=class_labels).sample sample = self.scheduler.step( - model_output, t, sample, use_noise=use_noise, **extra_step_kwargs + model_output, t, sample, use_noise=use_noise, generator=generator ).prev_sample # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, sample) - - # 7. Post-process image sample - sample = (sample / 2 + 0.5).clamp(0, 1) - sample = sample.cpu().permute(0, 2, 3, 1).numpy() + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) - if output_type == "pil": - sample = self.numpy_to_pil(sample) - - if not return_dict: - return (sample,) + # 6. Post-process image sample + image = self.postprocess_image(sample, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() + + if not return_dict: + return (image,) - return ImagePipelineOutput(images=sample) + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index dab428208a03..3eab567e0a77 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -128,22 +128,18 @@ def scale_model_input( self.is_scale_input_called = True return sample - def scale_timestep(self, timestep: Union[float, torch.FloatTensor]): + def _sigma_to_t(self, sigmas: np.ndarray): """ - Scales the timestep based on the associated Karras sigma, for input to the consistency model. + Gets scaled timesteps from the Karras sigmas, for input to the consistency model. Args: - timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + sigmas (`np.ndarray`, *optional*): array of Karras sigmas Returns: - `torch.FloatTensor`: scaled input timestep + `np.ndarray`: scaled input timestep """ - # Get sigma corresponding to timestep - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_idx = self.index_for_timestep(timestep) - sigma = self.sigmas[step_idx] + timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44) - return 1000 * 0.25 * torch.log(sigma + 1e-44) + return timesteps def set_timesteps( self, @@ -159,6 +155,10 @@ def set_timesteps( the number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, optional): + custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` + must be `None`. """ if num_inference_steps is None and timesteps is None: raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") @@ -194,23 +194,23 @@ def set_timesteps( timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float64) - # Map timesteps to Karras sigmas directly for multistep sampling # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 num_train_timesteps = self.config.num_train_timesteps - # Append num_train_timesteps - 1 so sigmas[-1] == sigma_min - ramp = np.append(timesteps[::-1].copy(), [num_train_timesteps - 1]) + ramp = timesteps[::-1].copy() ramp = ramp / (num_train_timesteps - 1) sigmas = self._convert_to_karras(ramp) + timesteps = self._sigma_to_t(sigmas) - sigmas = sigmas.astype(np.float32) + sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + # Modified _convert_to_karras implementation that takes in ramp as argument def _convert_to_karras(self, ramp): """Constructs the noise schedule of Karras et al. (2022).""" diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index 2277d85e0b6e..361a49857e68 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -82,7 +82,7 @@ def get_dummy_inputs(self, device, seed=0): "num_inference_steps": None, "timesteps": [22, 0], "generator": generator, - "output_type": "numpy", + "output_type": "np", } return inputs @@ -174,7 +174,7 @@ def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch. "timesteps": [22, 0], "class_labels": 0, "generator": generator, - "output_type": "numpy", + "output_type": "np", } if get_fixed_latents: From 075351ad4d617e309d2eb2cd1ad642f9c5fdc448 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 28 Jun 2023 19:53:11 -0700 Subject: [PATCH 66/72] make style --- .../pipeline_consistency_models.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index 0b9ce6e61d59..a427f5bca5e6 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -1,4 +1,3 @@ -import inspect from typing import Callable, List, Optional, Union import torch @@ -134,7 +133,7 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - + # Follows diffusers.VaeImageProcessor.postprocess def postprocess_image(self, latents: torch.FloatTensor, output_type: str = "pil"): if output_type not in ["latent", "pt", "np", "pil"]: @@ -144,21 +143,21 @@ def postprocess_image(self, latents: torch.FloatTensor, output_type: str = "pil" ) deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" - + if output_type == "latent": # Return latents without modification return latents - + # Equivalent to diffusers.VaeImageProcessor.denormalize latents = (latents / 2 + 0.5).clamp(0, 1) if output_type == "pt": return latents - + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy latents = latents.cpu().permute(0, 2, 3, 1).numpy() if output_type == "np": return latents - + # Output_type must be 'pil' latents = self.numpy_to_pil(latents) return latents @@ -306,7 +305,7 @@ def __call__( # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - + if not return_dict: return (image,) From ea7d75f0e2ce51628a095f1039f737510a49547d Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Mon, 10 Jul 2023 12:58:05 +0530 Subject: [PATCH 67/72] Update training script to main, fix timesteps --- .../train_consistency_distillation.py | 79 ++++--- src/diffusers/models/unet_2d_blocks.py | 219 ------------------ .../scheduling_consistency_models.py | 3 +- 3 files changed, 52 insertions(+), 249 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 99f7d06fa097..0937f1332639 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -5,7 +5,7 @@ import os from pathlib import Path from typing import Optional - +import shutil import accelerate import datasets import torch @@ -23,7 +23,7 @@ from diffusers import DDPMPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available +from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available, is_xformers_available #Copied from examples/unconditional_image_generation/train_unconditional.py for now @@ -281,14 +281,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.logger, - logging_dir=logging_dir, project_config=accelerator_project_config, ) @@ -377,14 +375,15 @@ def load_model_hook(models, input_dir): attention_head_dim=8, down_block_types= [ "ResnetDownsampleBlock2D", - "AttnDownsampleBlock2D", + "AttnDownBlock2D", ], up_block_types= [ - "AttnUpsampleBlock2D", + "AttnUpBlock2D", "ResnetUpsampleBlock2D", ], resnet_time_scale_shift="scale_shift", - + upsample_type="resnet", + downsample_type="resnet" ) target_model = UNet2DModel( sample_size= args.resolution, @@ -396,19 +395,21 @@ def load_model_hook(models, input_dir): attention_head_dim=8, down_block_types= [ "ResnetDownsampleBlock2D", - "AttnDownsampleBlock2D", + "AttnDownBlock2D", ], up_block_types= [ - "AttnUpsampleBlock2D", + "AttnUpBlock2D", "ResnetUpsampleBlock2D", ], resnet_time_scale_shift="scale_shift", - + upsample_type="resnet", + downsample_type="resnet" ) else: config = UNet2DModel.load_config(args.model_config_name_or_path) model = UNet2DModel.from_config(config) target_model = UNet2DModel.from_config(config) + # load the model to distill into a consistency model teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet @@ -417,8 +418,6 @@ def load_model_hook(models, input_dir): teacher_model = teacher_model.float() noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 - noise_scheduler.set_timesteps(num_scales) - timesteps = noise_scheduler.timesteps # Create EMA for the model, this is the target model in the paper @@ -489,12 +488,12 @@ def transform_images(examples): ) # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema + model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema = accelerator.prepare( + model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema ) + noise_scheduler.set_timesteps(num_scales, device=accelerator.device) target_model_ema.to(accelerator.device) - # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: @@ -550,7 +549,8 @@ def transform_images(examples): disable=not accelerator.is_local_main_process, ) - + timesteps = noise_scheduler.timesteps + sigmas = noise_scheduler.sigmas # Train! for epoch in range(first_epoch, args.num_epochs): @@ -564,17 +564,18 @@ def transform_images(examples): index = torch.randint( 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=clean_images.device ).long() + # timestep is the scaled timestep, sigma is the unscaled timestep timestep = timesteps[index] - timestep_prev = timestep + 1 + sigma = sigmas[index] + timestep_prev = timesteps[index+1] + sigma_prev = sigmas[index+1] + # add noise expects the scaled timestep only and internally converts to sigma noised_image = noise_scheduler.add_noise(clean_images, noise, timestep) - scaled_timesteps = noise_scheduler.scale_timestep(timestep) - scaled_timesteps_prev = noise_scheduler.scale_timestep(timestep_prev) target_model_ema.copy_to(target_model.parameters()) with accelerator.accumulate(model): # Predict the noise residual - - model_output = model(noise_scheduler.scale_model_input(noised_image, timestep), scaled_timesteps, class_labels=labels).sample + model_output = model(noise_scheduler.scale_model_input(noised_image, timestep), timestep, class_labels=labels).sample distiller = noise_scheduler.step( model_output, timestep, noised_image, use_noise=False ).prev_sample @@ -584,22 +585,22 @@ def transform_images(examples): # TODO - make this cleaner samples = noised_image x = samples - model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), scaled_timesteps, class_labels=labels).sample + model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( model_output, timestep, x, use_noise=False ).prev_sample - d = (x - teacher_denoiser) / append_dims(scaled_timesteps, x.ndim) - samples = x + d * append_dims(scaled_timesteps_prev - scaled_timesteps, x.ndim) - model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), scaled_timesteps_prev, class_labels=labels).sample + d = (x - teacher_denoiser) / append_dims(sigma, x.ndim) + samples = x + d * append_dims(sigma_prev - sigma, x.ndim) + model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( model_output, timestep_prev, samples, use_noise=False ).prev_sample - next_d = (samples - teacher_denoiser) / append_dims(scaled_timesteps_prev, x.ndim) - denoised_image = x + (d + next_d) * append_dims((scaled_timesteps_prev - scaled_timesteps) /2, x.ndim) + next_d = (samples - teacher_denoiser) / append_dims(sigma_prev, x.ndim) + denoised_image = x + (d + next_d) * append_dims((sigma_prev - sigma) /2, x.ndim) # get output from target model - model_output = target_model(denoised_image, scaled_timesteps_prev, class_labels=labels).sample + model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample distiller_target = noise_scheduler.step( model_output, timestep_prev, denoised_image, use_noise=False ).prev_sample @@ -622,6 +623,26 @@ def transform_images(examples): global_step += 1 if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + if accelerator.is_main_process: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 25ac6d7b1186..cb3452f4459c 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -107,20 +107,6 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, ) - elif down_block_type == "AttnDownsampleBlock2D": - return AttnDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") @@ -361,20 +347,6 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, ) - elif up_block_type == "AttnUpsampleBlock2D": - return AttnUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( num_layers=num_layers, @@ -887,100 +859,6 @@ def forward(self, hidden_states, temb=None, upsample_size=None): return hidden_states, output_states -class AttnDownsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - Attention( - out_channels, - heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, - dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None, upsample_size=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - class CrossAttnDownBlock2D(nn.Module): def __init__( self, @@ -2114,103 +1992,6 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si return hidden_states -class AttnUpsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - Attention( - out_channels, - heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, - dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states - - class CrossAttnUpBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index fb296054d65b..f580c28453cd 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -268,6 +268,7 @@ def step( sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, + use_noise: bool = True ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -331,7 +332,7 @@ def step( # 2. Sample z ~ N(0, s_noise^2 * I) # Noise is not used for onestep sampling. - if len(self.timesteps) > 1: + if len(self.timesteps) > 1 and use_noise: noise = randn_tensor( model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator ) From a32b8691b04d8f417601a2db5b032041408b162f Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Mon, 10 Jul 2023 13:08:11 +0530 Subject: [PATCH 68/72] Fix bug in timestep ordering --- .../train_consistency_distillation.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 0937f1332639..31fe32bc70b8 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -549,8 +549,9 @@ def transform_images(examples): disable=not accelerator.is_local_main_process, ) - timesteps = noise_scheduler.timesteps - sigmas = noise_scheduler.sigmas + timesteps = noise_scheduler.timesteps + sigmas = noise_scheduler.sigmas # in reverse order, sigma0 is sigma_max + # Train! for epoch in range(first_epoch, args.num_epochs): @@ -565,10 +566,10 @@ def transform_images(examples): 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=clean_images.device ).long() # timestep is the scaled timestep, sigma is the unscaled timestep - timestep = timesteps[index] - sigma = sigmas[index] - timestep_prev = timesteps[index+1] - sigma_prev = sigmas[index+1] + timestep = timesteps[index+1] + sigma = sigmas[index+1] + timestep_prev = timesteps[index] + sigma_prev = sigmas[index] # add noise expects the scaled timestep only and internally converts to sigma noised_image = noise_scheduler.add_noise(clean_images, noise, timestep) target_model_ema.copy_to(target_model.parameters()) From 8742e4e57edc92215950ffaec07af57da48ee960 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Tue, 11 Jul 2023 12:57:52 +0530 Subject: [PATCH 69/72] Add review suggestions --- examples/consistency_models/requirements.txt | 2 + examples/consistency_models/script.sh | 3 - .../train_consistency_distillation.py | 65 ++++--------------- 3 files changed, 13 insertions(+), 57 deletions(-) delete mode 100644 examples/consistency_models/script.sh diff --git a/examples/consistency_models/requirements.txt b/examples/consistency_models/requirements.txt index f366720afd11..bc7b6a7f238b 100644 --- a/examples/consistency_models/requirements.txt +++ b/examples/consistency_models/requirements.txt @@ -1,3 +1,5 @@ accelerate>=0.16.0 torchvision datasets +wandb +tensrboard diff --git a/examples/consistency_models/script.sh b/examples/consistency_models/script.sh deleted file mode 100644 index dfe6bcc5b7a0..000000000000 --- a/examples/consistency_models/script.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -accelerate launch train_consistency_distillation.py --dataset_name="cifar10" --resolution=32 --center_crop --random_flip --output_dir="cifar10-32" --train_batch_size=16 --num_epochs=100 --gradient_accumulation_steps=1 --learning_rate=1e-4 --lr_warmup_steps=500 --mixed_precision=no --push_to_hub \ No newline at end of file diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 31fe32bc70b8..29cdf725e008 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -18,7 +18,7 @@ from packaging import version from torchvision import transforms from tqdm.auto import tqdm - +import wandb import diffusers from diffusers import DDPMPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline from diffusers.optimization import get_scheduler @@ -33,35 +33,6 @@ logger = get_logger(__name__, log_level="INFO") -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") - return x[(...,) + (None,) * dims_to_append] - - - - - - -def _extract_into_tensor(arr, timesteps, broadcast_shape): - """ - Extract values from a 1-D numpy array for a batch of indices. - - :param arr: the 1-D numpy array. - :param timesteps: a tensor of indices into the array to extract. - :param broadcast_shape: a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. - """ - if not isinstance(arr, torch.Tensor): - arr = torch.from_numpy(arr) - res = arr[timesteps].float().to(timesteps.device) - while len(res.shape) < len(broadcast_shape): - res = res[..., None] - return res.expand(broadcast_shape) - def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") @@ -290,15 +261,6 @@ def main(args): project_config=accelerator_project_config, ) - if args.logger == "tensorboard": - if not is_tensorboard_available(): - raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") - - elif args.logger == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb - # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -413,9 +375,6 @@ def load_model_hook(models, input_dir): # load the model to distill into a consistency model teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet - model = model.float() - target_model = target_model.float() # TODO : support half precision training - teacher_model = teacher_model.float() noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 @@ -586,24 +545,22 @@ def transform_images(examples): # TODO - make this cleaner samples = noised_image x = samples - model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample + teacher_model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( - model_output, timestep, x, use_noise=False + teacher_model_output, timestep, x, use_noise=False ).prev_sample - d = (x - teacher_denoiser) / append_dims(sigma, x.ndim) - samples = x + d * append_dims(sigma_prev - sigma, x.ndim) - model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample + d = (x - teacher_denoiser) / sigma[(...,) + (None,) * 3] + samples = x + d * (sigma_prev - sigma)[(...,) + (None,) * 3] + teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( - model_output, timestep_prev, samples, use_noise=False + teacher_model_output, timestep_prev, samples, use_noise=False ).prev_sample - - next_d = (samples - teacher_denoiser) / append_dims(sigma_prev, x.ndim) - denoised_image = x + (d + next_d) * append_dims((sigma_prev - sigma) /2, x.ndim) - + next_d = (samples - teacher_denoiser) / sigma_prev[(...,) + (None,) * 3] + denoised_image = x + (d + next_d) * ((sigma_prev - sigma) /2)[(...,) + (None,) * 3] # get output from target model - model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample + target_model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample distiller_target = noise_scheduler.step( - model_output, timestep_prev, denoised_image, use_noise=False + target_model_output, timestep_prev, denoised_image, use_noise=False ).prev_sample loss = F.mse_loss(distiller, distiller_target) From 943c88b05b629abd5dd4b604464040dee6e689be Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Wed, 12 Jul 2023 12:30:16 +0530 Subject: [PATCH 70/72] Integrate accelerator better, change model upload --- .../train_consistency_distillation.py | 131 +++++++++++++----- 1 file changed, 97 insertions(+), 34 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 29cdf725e008..9cd1a4ef56a3 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -14,7 +14,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import HfFolder, Repository, create_repo, whoami, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -34,6 +34,39 @@ logger = get_logger(__name__, log_level="INFO") +def save_model_card( + repo_id: str, + images=None, + base_model=str, + repo_folder=None, + pipeline: ConsistencyModelPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- consistency models +- diffusers +inference: true +--- + """ + model_card = f""" +# Consistency Model - {repo_id} + +This is a consistency model distilled from {base_model}. +You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -58,6 +91,13 @@ def parse_args(): default=None, help="The config of the UNet model to train, leave as None to use standard Consistency configuration.", ) + parser.add_argument( + "--pretrained_teacher_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models to be used as teacher model", + ) parser.add_argument( "--train_data_dir", type=str, @@ -314,14 +354,9 @@ def load_model_hook(models, input_dir): repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) else: repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token) repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) @@ -374,10 +409,27 @@ def load_model_hook(models, input_dir): # load the model to distill into a consistency model - teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet + teacher_model = DDPMPipeline.from_pretrained(args.pretrained_teacher_model_name_or_path).unet noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 + # Check that all trainable models are in full precision + low_precision_error_string = ( + "Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training. copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(model).dtype != torch.float32: + raise ValueError( + f"Unet loaded as datatype {accelerator.unwrap_model(model).dtype}. {low_precision_error_string}" + ) + + if args.train_text_encoder and accelerator.unwrap_model(teacher_model).dtype != torch.float32: + raise ValueError( + f"Text encoder loaded as datatype {accelerator.unwrap_model(teacher_model).dtype}." + f" {low_precision_error_string}" + ) + # Create EMA for the model, this is the target model in the paper target_model_ema = EMAModel( @@ -452,12 +504,10 @@ def transform_images(examples): ) noise_scheduler.set_timesteps(num_scales, device=accelerator.device) - target_model_ema.to(accelerator.device) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - run = os.path.split(__file__)[-1].split(".")[0] - accelerator.init_trackers(run) + accelerator.init_trackers("consistency-distillation", vars(args)) total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -519,10 +569,10 @@ def transform_images(examples): clean_images = batch["input"] labels = batch["labels"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn(clean_images.shape).to(accelerator.device) # Sample a random timestep for each image, TODO - allow different timesteps in a batch index = torch.randint( - 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=clean_images.device + 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=accelerator.device ).long() # timestep is the scaled timestep, sigma is the unscaled timestep timestep = timesteps[index+1] @@ -580,31 +630,31 @@ def transform_images(examples): progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) - if accelerator.is_main_process: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} logs["ema_decay"] = target_model_ema.cur_decay_value @@ -671,6 +721,19 @@ def transform_images(examples): target_model_ema.restore(unet.parameters()) if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_teacher_model_name_or_path, + repo_folder=args.output_dir, + pipeline=pipeline, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) accelerator.end_training() From 6b58d8195f60c4fb9eabaac10cdd0c41d23a7d2c Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Thu, 13 Jul 2023 12:44:07 +0530 Subject: [PATCH 71/72] Fix checkpointing and add test --- .../train_consistency_distillation.py | 100 +++++++++++------- examples/test_examples.py | 23 ++++ 2 files changed, 86 insertions(+), 37 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 9cd1a4ef56a3..8fb306cd514b 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -199,6 +199,8 @@ def parse_args(): parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--testing", action="store_true", help="If running a test") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", @@ -354,14 +356,23 @@ def load_model_hook(models, input_dir): repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) else: repo_name = args.hub_model_id - repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # Initialize the model, using a smaller model than the one defined in the original paper by default - if args.model_config_name_or_path is None: + # For testing use a dummy model + if args.testing: + config = UNet2DModel.load_config('diffusers/consistency-models-test', subfolder="test_unet") + elif args.model_config_name_or_path is not None: + config = UNet2DModel.load_config(args.model_config_name_or_path) + # Use the config if provided, model and target model have the same structure + if config is not None: + model = UNet2DModel.from_config(config) + target_model = UNet2DModel.from_config(config) + # Otherwise, use a default config + else: model = UNet2DModel( sample_size= args.resolution, in_channels=3, @@ -401,15 +412,12 @@ def load_model_hook(models, input_dir): resnet_time_scale_shift="scale_shift", upsample_type="resnet", downsample_type="resnet" - ) + ) + if args.testing: + teacher_model = UNet2DModel.from_config(config) else: - config = UNet2DModel.load_config(args.model_config_name_or_path) - model = UNet2DModel.from_config(config) - target_model = UNet2DModel.from_config(config) - - - # load the model to distill into a consistency model - teacher_model = DDPMPipeline.from_pretrained(args.pretrained_teacher_model_name_or_path).unet + # load the model to distill into a consistency model + teacher_model = DDPMPipeline.from_pretrained(args.pretrained_teacher_model_name_or_path).unet noise_scheduler = CMStochasticIterativeScheduler() num_scales = 40 @@ -421,12 +429,12 @@ def load_model_hook(models, input_dir): if accelerator.unwrap_model(model).dtype != torch.float32: raise ValueError( - f"Unet loaded as datatype {accelerator.unwrap_model(model).dtype}. {low_precision_error_string}" + f"Consistency Model loaded as datatype {accelerator.unwrap_model(model).dtype}. {low_precision_error_string}" ) - if args.train_text_encoder and accelerator.unwrap_model(teacher_model).dtype != torch.float32: + if accelerator.unwrap_model(teacher_model).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {accelerator.unwrap_model(teacher_model).dtype}." + f"Teacher_model loaded as datatype {accelerator.unwrap_model(teacher_model).dtype}." f" {low_precision_error_string}" ) @@ -442,7 +450,7 @@ def load_model_hook(models, input_dir): model_config=model.config, ) - # Initialize the optimizer + # Initialize the optimizer # TODO: Change this to match the paper, RAdam optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, @@ -480,12 +488,16 @@ def load_model_hook(models, input_dir): ) def transform_images(examples): - images = [augmentations(image.convert("RGB")) for image in examples["img"]] + img_key = "image" if "image" in examples else "img" + images = [augmentations(image.convert("RGB")) for image in examples[img_key]] labels = [torch.tensor(label) for label in examples["label"]] return {"input": images, "labels": labels} + + logger.info(f"Dataset size: {len(dataset)}") dataset.set_transform(transform_images) + train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) @@ -499,10 +511,11 @@ def transform_images(examples): ) # Prepare everything with our `accelerator`. - model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema = accelerator.prepare( - model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model, target_model_ema + model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model = accelerator.prepare( + model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model ) noise_scheduler.set_timesteps(num_scales, device=accelerator.device) + target_model_ema.to(accelerator.device) # TODO accelerate.prepare doesn't work on this for some reason # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -668,7 +681,6 @@ def transform_images(examples): if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: unet = accelerator.unwrap_model(model) - target_model_ema.store(unet.parameters()) target_model_ema.copy_to(unet.parameters()) @@ -683,7 +695,7 @@ def transform_images(examples): generator=generator, batch_size=args.eval_batch_size, num_inference_steps=1, - output_type="numpy", + output_type="np", ).images target_model_ema.restore(unet.parameters()) @@ -707,7 +719,6 @@ def transform_images(examples): if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model unet = accelerator.unwrap_model(model) - target_model_ema.store(unet.parameters()) target_model_ema.copy_to(unet.parameters()) @@ -717,24 +728,39 @@ def transform_images(examples): ) pipeline.save_pretrained(args.output_dir) - target_model_ema.restore(unet.parameters()) + + if accelerator.is_main_process and args.push_to_hub: + unet = accelerator.unwrap_model(model) + target_model_ema.copy_to(unet.parameters()) + + pipeline = ConsistencyModelPipeline( + unet=unet, + scheduler=noise_scheduler, + ) - if args.push_to_hub: - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_teacher_model_name_or_path, - repo_folder=args.output_dir, - pipeline=pipeline, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) + generator = torch.Generator(device=pipeline.device).manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline( + generator=generator, + batch_size=args.eval_batch_size, + num_inference_steps=1, + output_type="pil", + ).images + + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_teacher_model_name_or_path, + repo_folder=args.output_dir, + pipeline=pipeline, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/test_examples.py b/examples/test_examples.py index d11841350064..f1096f9e1289 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -96,6 +96,29 @@ def test_train_unconditional(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_train_consistency(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/consistency_models/train_consistency_distillation.py + --dataset_name hf-internal-testing/dummy_image_class_data + --resolution 32 + --output_dir {tmpdir} + --train_batch_size 2 + --num_epochs 1 + --gradient_accumulation_steps 1 + --learning_rate 1e-3 + --lr_warmup_steps 5 + --testing + --pretrained_teacher_model_name_or_path google/ddpm-cifar10-32 + """.split() + + run_command(self._launch_args + test_args, return_stdout=True) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + + + def test_textual_inversion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From 180b3e80a4f60341d4ff139156e7330290d3b9f3 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Wed, 19 Jul 2023 16:16:41 +0530 Subject: [PATCH 72/72] Remove hardcoded configs, add DiffusionPipeline --- .../train_consistency_distillation.py | 73 ++++--------------- 1 file changed, 15 insertions(+), 58 deletions(-) diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py index 8fb306cd514b..551b5d0c1b1e 100644 --- a/examples/consistency_models/train_consistency_distillation.py +++ b/examples/consistency_models/train_consistency_distillation.py @@ -20,7 +20,7 @@ from tqdm.auto import tqdm import wandb import diffusers -from diffusers import DDPMPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline +from diffusers import DiffusionPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available, is_xformers_available @@ -362,63 +362,19 @@ def load_model_hook(models, input_dir): elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # For testing use a dummy model - if args.testing: - config = UNet2DModel.load_config('diffusers/consistency-models-test', subfolder="test_unet") - elif args.model_config_name_or_path is not None: - config = UNet2DModel.load_config(args.model_config_name_or_path) # Use the config if provided, model and target model have the same structure - if config is not None: - model = UNet2DModel.from_config(config) - target_model = UNet2DModel.from_config(config) - # Otherwise, use a default config - else: - model = UNet2DModel( - sample_size= args.resolution, - in_channels=3, - out_channels=3, - layers_per_block=2, - num_class_embeds=1000, - block_out_channels= [32, 64], - attention_head_dim=8, - down_block_types= [ - "ResnetDownsampleBlock2D", - "AttnDownBlock2D", - ], - up_block_types= [ - "AttnUpBlock2D", - "ResnetUpsampleBlock2D", - ], - resnet_time_scale_shift="scale_shift", - upsample_type="resnet", - downsample_type="resnet" - ) - target_model = UNet2DModel( - sample_size= args.resolution, - in_channels=3, - out_channels=3, - layers_per_block=2, - num_class_embeds=1000, - block_out_channels= [32, 64], - attention_head_dim=8, - down_block_types= [ - "ResnetDownsampleBlock2D", - "AttnDownBlock2D", - ], - up_block_types= [ - "AttnUpBlock2D", - "ResnetUpsampleBlock2D", - ], - resnet_time_scale_shift="scale_shift", - upsample_type="resnet", - downsample_type="resnet" - ) - if args.testing: - teacher_model = UNet2DModel.from_config(config) + if args.model_config_name_or_path is not None: + config = UNet2DModel.load_config(args.model_config_name_or_path) + # Else use a default config else: - # load the model to distill into a consistency model - teacher_model = DDPMPipeline.from_pretrained(args.pretrained_teacher_model_name_or_path).unet + config = UNet2DModel.load_config("ayushtues/consistency_tiny_unet") + model = UNet2DModel.from_config(config) + target_model = UNet2DModel.from_config(config) noise_scheduler = CMStochasticIterativeScheduler() + # load the model to distill into a consistency model + teacher_pipeline = DiffusionPipeline.from_pretrained(args.pretrained_teacher_model_name_or_path) + teacher_model = teacher_pipeline.unet + teacher_scheduler = teacher_pipeline.scheduler num_scales = 40 # Check that all trainable models are in full precision @@ -493,11 +449,8 @@ def transform_images(examples): labels = [torch.tensor(label) for label in examples["label"]] return {"input": images, "labels": labels} - - logger.info(f"Dataset size: {len(dataset)}") dataset.set_transform(transform_images) - train_dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) @@ -572,6 +525,7 @@ def transform_images(examples): ) timesteps = noise_scheduler.timesteps + teacher_scheduler.set_timesteps(timesteps=timesteps, device=accelerator.device) sigmas = noise_scheduler.sigmas # in reverse order, sigma0 is sigma_max @@ -614,6 +568,9 @@ def transform_images(examples): ).prev_sample d = (x - teacher_denoiser) / sigma[(...,) + (None,) * 3] samples = x + d * (sigma_prev - sigma)[(...,) + (None,) * 3] + # We probably want to use Sigma for an arbitrary teacher model here, since that corresponds to the unscaled timestep + # We just want a denoised image from an input x, t using the teacher model, since that is used in the score function + # So we should figure out how to get the denoised image from the teacher model teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample teacher_denoiser = noise_scheduler.step( teacher_model_output, timestep_prev, samples, use_noise=False