
[WIP] Consistency models distillation examples #3992

Closed

Conversation

ayushtues
Contributor

Add distillation training scripts for Consistency models

Continuation of dg845#2, adding distillation examples for the recently added Consistency Models (paper, original code, diffusers pipeline)

Model/Pipeline Description

Consistency Models (paper, code) are a new family of generative models similar to continuous-time diffusion models which support fast one-step generation. Diffusion models can be distilled into a consistency model for faster sampling, and consistency models can also be trained from scratch. From the paper abstract:

Diffusion models have made significant breakthroughs in image, audio, and video generation, but they depend on an iterative generation process that causes slow sampling speed and caps their potential for real-time applications. To overcome this limitation, we propose consistency models, a new family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing for few-step sampling to trade compute for sample quality. They also support zero-shot data editing, like image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either as a way to distill pre-trained diffusion models, or as standalone generative models. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step generation. For example, we achieve the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained as standalone generative models, consistency models also outperform single-step, non-adversarial generative models on standard benchmarks like CIFAR-10, ImageNet 64x64 and LSUN 256x256.

In this PR, we implement the distillation procedure to distill a diffusion model into a consistency model as described in the paper.
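For context on what the distilled checkpoint enables, here is a minimal inference sketch with the existing ConsistencyModelPipeline (the checkpoint name is only an example of a converted consistency model; swap in whatever the distillation script produces):

import torch
from diffusers import ConsistencyModelPipeline

# Example converted consistency-model checkpoint (name assumed for illustration).
pipe = ConsistencyModelPipeline.from_pretrained("openai/diffusers-cd_imagenet64_l2", torch_dtype=torch.float16)
pipe.to("cuda")

# One-step generation is the main payoff of distillation.
image = pipe(num_inference_steps=1).images[0]
image.save("consistency_model_onestep.png")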

TODO

  • Add training scripts for distillation

CC
@dg845
@williamberman
@patrickvonplaten

dg845 and others added 30 commits May 15, 2023 08:04
…ampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
…IterativeScheduler and add initial version of tests.
…le_shift to 'scale_shift' to better match consistency model checkpoints.
…s to the consistency models conversion script.
	- Get small testing checkpoints from hub
	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
	- Add onestep, multistep tests for distillation and distillation + class conditional
	- Add expected image slices for onestep tests
	- Add initial support for class-conditional generation
	- Fix initial sigma for onestep generation
	- Fix some sigma shape issues
	- add latents __call__ argument and prepare_latents method
	- add check_inputs method
	- add initial docstrings for ConsistencyModelPipeline.__call__
… and make related changes to the pipeline and tests.
	- in pipeline, call self.scheduler.scale_model_input before denoising
	- get expected slices for Euler and Heun scheduler tests
	- make Euler test pass
	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sayakpaul
Member

@ayushtues, thanks for this work! Let us know when you'd like a review from our end. Just as an FYI, the PR doesn't have to be super tight for an initial review.

@ayushtues
Contributor Author

Hey @sayakpaul, yes, I think it is good enough for an initial review right now. One major thing we need to figure out is how to do the scaling of the inputs for different teacher models (as mentioned in dg845#2 (comment)). The original paper used EDM as the teacher model, which was trained with a Karras scheduler, so it followed the same scaling and processing as the Consistency model. But if we want to use something else like Stable Diffusion or DDPM as the teacher model, we need to figure out what input scaling and processing to use for them.
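(For reference, the scaling the consistency model itself expects follows the Karras/EDM parameterization from the paper, with sigma_data = 0.5 and sigma_min = 0.002. A quick sketch of those coefficients, paraphrased from the paper rather than taken from the PR code, just to make explicit what a non-EDM teacher would have to be mapped onto:

import math

SIGMA_DATA = 0.5   # Karras/EDM data standard deviation used in the paper
SIGMA_MIN = 0.002  # smallest noise level

def boundary_condition_scalings(sigma):
    # f_theta(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F_theta(x, sigma)
    c_skip = SIGMA_DATA**2 / ((sigma - SIGMA_MIN) ** 2 + SIGMA_DATA**2)
    c_out = SIGMA_DATA * (sigma - SIGMA_MIN) / math.sqrt(sigma**2 + SIGMA_DATA**2)
    # EDM-style input scaling applied before the U-Net forward pass
    c_in = 1.0 / math.sqrt(sigma**2 + SIGMA_DATA**2)
    return c_skip, c_out, c_in

A DDPM- or Stable-Diffusion-style teacher instead expects epsilon-prediction inputs on an alphas/betas schedule, so we'd need some mapping between the two conventions, or to step the teacher with its own scheduler.)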

@ayushtues
Contributor Author

ayushtues commented Jul 10, 2023

@dg845, @patrickvonplaten, @williamberman, @sayakpaul, do you think we should add examples of other applications of Consistency Models mentioned in the paper, like inpainting, colorization, and super-resolution? They mostly seem to involve some clever processing of the inputs and outputs of the model; for example, see the code reference in the original repo: https://github.com/openai/consistency_models/blob/ac278060af7175e37f4c1e79e69fa521234c04de/cm/karras_diffusion.py#L722

@sayakpaul
Member

sayakpaul commented Jul 10, 2023

one major thing we need to figure out is how to do the scaling of the inputs for different teacher models (as mentioned in dg845#2 (comment)).

If we decided to, say, distill https://huggingface.co/valhalla/sd-pokemon-model, can't we use the same scheduler for distilling the student (with appropriate changes as needed) and use the same preprocessing steps? I like what @dg845 is proposing here: dg845#2 (comment).

@dg845, @patrickvonplaten @williamberman @sayakpaul do you think we should add examples of other applications of Consistency models mentioned in the paper like inpainting, colorization, super-resolution?

Let's gauge the community interest for that first and then we can definitely consider.

@@ -0,0 +1,3 @@
#!/bin/bash
Member

We don't include shell scripts in the examples.

@@ -268,6 +268,7 @@ def step(
    sample: torch.FloatTensor,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
    use_noise: bool = True
Member

Why do we need to introduce this argument?

Contributor

Can we maybe make this a config attribute instead? use_noise=True will rarely be changed to use_noise=False during inference - either it's always False or True. Could we add a force_deterministic=True/False to the config?

Contributor

I think it would be better to put the section of the function that computes the unnoised prediction in its own function that can be called.

i.e.

def step(...):
    denoised = self.denoise(...)

    if len(self.timesteps) > 1:
        prev_sample = ...
    else:
        prev_sample = denoised

    return prev_sample

def denoise(...):
    ...

this would avoid adding extra global config to the class, and the training script could just call denoise()
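To make the proposed split a bit more concrete, here is a standalone sketch (not the actual scheduler internals: the scalings follow the boundary conditions from the paper and the noise-injection step is simplified):

import torch

def denoise(model_output, sample, sigma, sigma_data=0.5, sigma_min=0.002):
    # Consistency model output: f_theta(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F_theta(x, sigma)
    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    c_out = sigma_data * (sigma - sigma_min) / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip * sample + c_out * model_output

def step(model_output, sample, sigma, sigma_next, sigma_min=0.002, last_step=False):
    denoised = denoise(model_output, sample, sigma)
    if last_step:
        return denoised
    # Multistep ("stochastic iterative") sampling: re-noise the estimate to the next noise level.
    noise = torch.randn_like(sample)
    return denoised + noise * (sigma_next**2 - sigma_min**2) ** 0.5

The training script would then only ever need denoise(), which is exactly the point of the refactor.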

Contributor

That would also work for me - is probably indeed the better option (as my first proposal would force people to update the config after training, which is probs more brittle)

Comment on lines 293 to 300
if args.logger == "tensorboard":
    if not is_tensorboard_available():
        raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")

elif args.logger == "wandb":
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
    import wandb
Member

Both of these dependencies need to go to requirements.txt.

Comment on lines 410 to 411
model = UNet2DModel.from_config(config)
target_model = UNet2DModel.from_config(config)
Member

So this, in a sense, is step distillation, i.e., we're not compressing the model itself but reducing the number of sampling steps.

Contributor
@dg845 Jul 11, 2023

Perhaps it is a distinction without a difference, but I think of it as reducing the number of steps for a (diffusion model, sampler) pair. In my mind, this distinction is important because my guess is that the quality of the sampler affects the quality of the distilled consistency model: if we distill using a sampler that produces bad samples for the diffusion model, the consistency model might learn to faithfully reproduce samples similar to those of the (diffusion model, sampler) pair, which would still be poor quality because the samples it learned from are poor quality. (Conversely, using a better sampler might produce a better consistency model.)

[Another way to phrase it might be that it's compressing the sampling procedure, which depends on both the diffusion model and sampler.]



# load the model to distill into a consistency model
teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet
Member

Let's make this path ("google/ddpm-cifar10-32") a CLI arg.

Member

And it should also work with DiffusionPipeline.
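Putting both suggestions together, loading the teacher could look roughly like this (the --pretrained_teacher_model argument name is just a placeholder, not the final CLI):

import argparse
from diffusers import DiffusionPipeline

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_teacher_model", type=str, default="google/ddpm-cifar10-32")
args = parser.parse_args()

# DiffusionPipeline resolves the concrete pipeline class (e.g. DDPMPipeline) from the checkpoint,
# so the same code path works for DDPM as well as other teachers.
teacher_pipeline = DiffusionPipeline.from_pretrained(args.pretrained_teacher_model)
teacher_model = teacher_pipeline.unet
teacher_scheduler = teacher_pipeline.scheduler  # handy for the teacher sampling step discussed further down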



# load the model to distill into a consistency model
teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet
Member

Not clear what model, target_model, and teacher_model mean.

Contributor
@dg845 Jul 11, 2023

My understanding is as follows (things in square brackets [] are the variable names in the current PR code / terminology follows Section 4/Algorithm 2 in the paper):

  • The teacher_model $s_{\phi}$ is the already-trained diffusion model we want to distill to a consistency model (for example, google/ddpm-cifar10-32).
  • The model is a consistency model $\boldsymbol{f_\theta}$ which is used to calculate the current output of the distilled consistency model $\boldsymbol{f_\theta}(x_{t_{n + 1}}, t_{n + 1})$ [model_output no. 1*] on the noisy sample $x_{t_{n + 1}}$ [noised_image]. The parameters $\boldsymbol{\theta}$ are updated via gradient descent. (The paper calls this the "online" model, following RL terminology.)
  • The target_model is a consistency model $\boldsymbol{f_{\theta^-}}$ which is used to calculate the distillation target, which we want the consistency model output $\boldsymbol{f_\theta}(x_{t_{n + 1}}, t_{n + 1})$ to match. This is obtained by first running one step of the sampler (in the paper and code, one step of the Karras Heun sampler) on the teacher_model to obtain a previous noisy sample $\hat{x}_{t_n}^\phi$ [model_output no. 2*], and then getting the target_model consistency model's output $\boldsymbol{f_{\theta^-}}(\hat{x}_{t_{n}}^\phi, t_{n})$ [model_output no. 3*] on $\hat{x}_{t_n}^\phi$. The target parameters $\boldsymbol{\theta^-}$ are updated via an EMA update (using $\boldsymbol{\theta}$). (This is analogous to target models in DQN-style RL algorithms, and I believe the EMA update is equivalent to what is sometimes called polyak averaging in RL.)

The loss function $L^{\phi}(\boldsymbol{\theta}, \boldsymbol{\theta^-})$ is then calculated using a distance metric $d(\cdot, \cdot)$ (MSE in the code)**:

$$L^{\phi}(\boldsymbol{\theta}, \boldsymbol{\theta^-}) = d(\boldsymbol{f_\theta}(x_{t_{n + 1}}, t_{n + 1}), \boldsymbol{f_{\theta^-}}(\hat{x}_{t_{n}}^\phi, t_{n}))$$

and then we take a gradient descent step with respect to model's parameters $\boldsymbol{\theta}$ and run an EMA update for target_model's parameters $\boldsymbol{\theta^-}$. (It's not clear to me which of model and target_model is outputted as the final trained checkpoint in the end; according to the paper, theoretically the parameters for the models should converge during training [e.g. $\boldsymbol{\theta} = \boldsymbol{\theta^-}$ at the end of training], so using either one should be fine.)

Some proposals:

  • I agree that what model, target_model, teacher_model are doing might be hard to follow from the code alone, perhaps it might make sense to add comments explaining what each one does when they are initialized?
  • Rename model_output no. 2 to teacher_model_output and model_output no. 3 to target_model_output. (If we want to rename model [see below], perhaps model_output no. 1 can be renamed similarly.)
  • Perhaps renaming model to something like online_model or current_consistency_model might be more clear? (Not sure on this one, especially the current_consistency_model suggestion, since I'm not sure if it's accurate; would feel more comfortable with this if model is what is outputted as the final trained checkpoint, see above.)

(*) Technically, these quantities correspond to distiller, denoised_image, and distiller_target in the code, respectively, but they're all preceded by a calculated raw U-Net output named model_output.
(**) The loss function in Algorithm 2 includes a timestep-dependent weight function $\lambda(t_n)$, but in practice it seems that a uniform weight schedule ($\lambda(t_n) = 1$) is usually used for consistency distillation.
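To tie the above together, here is a rough sketch of one consistency-distillation step plus the EMA update. The helper names, the separate teacher_scheduler, and the use_noise=False argument mirror the discussion in this thread rather than the exact PR code, so treat it as illustrative only:

import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target_model, online_model, mu=0.95):
    # theta_minus <- mu * theta_minus + (1 - mu) * theta  (EMA / polyak update; decay value is a placeholder)
    for p_target, p_online in zip(target_model.parameters(), online_model.parameters()):
        p_target.mul_(mu).add_(p_online, alpha=1.0 - mu)

def consistency_distillation_loss(online_model, target_model, teacher_model,
                                  cm_scheduler, teacher_scheduler,
                                  noised_image, timestep, timestep_prev, labels):
    # f_theta(x_{t_{n+1}}, t_{n+1}): online consistency model output, trained by gradient descent.
    online_unet_out = online_model(
        cm_scheduler.scale_model_input(noised_image, timestep), timestep, class_labels=labels
    ).sample
    distiller = cm_scheduler.step(online_unet_out, timestep, noised_image, use_noise=False).prev_sample

    with torch.no_grad():
        # One solver step with the teacher diffusion model: x_{t_{n+1}} -> x_hat_{t_n}^phi.
        teacher_unet_out = teacher_model(
            teacher_scheduler.scale_model_input(noised_image, timestep), timestep, class_labels=labels
        ).sample
        x_prev = teacher_scheduler.step(teacher_unet_out, timestep, noised_image).prev_sample

        # f_{theta^-}(x_hat_{t_n}^phi, t_n): target consistency model output (EMA copy of the online weights).
        target_unet_out = target_model(
            cm_scheduler.scale_model_input(x_prev, timestep_prev), timestep_prev, class_labels=labels
        ).sample
        distiller_target = cm_scheduler.step(target_unet_out, timestep_prev, x_prev, use_noise=False).prev_sample

    # d(., .) with a uniform weight lambda(t_n) = 1; MSE, as in the current code.
    return F.mse_loss(distiller, distiller_target)

After backpropagating this loss into the online model, the loop would call ema_update(target_model, online_model) to refresh theta_minus.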

Member

Thanks so much for explaining @dg845. See my responses below:

I agree that what model, target_model, teacher_model are doing might be hard to follow from the code alone, perhaps it might make sense to add comments explaining what each one does when they are initialized?

100% agree. The math is heavy enough and we should grab the opportunity here to make it easily digestible to the end users.

Rename model_output no. 2 to teacher_model_output and model_output no. 3 to target_model_output.

Works for me.

Perhaps renaming model to something like online_model or current_consistency_model might be more clear? (Not sure on this one, especially the current_consistency_model suggestion, since I'm not sure if it's accurate; would feel more comfortable with this if model is what is outputted as the final trained checkpoint, see above.)

Let's maybe try to output EMA'd model only since that is more common in practice?

This is obtained by first running one step of the sampler (in the paper and code, one step of the Karras Heun sampler) on the teacher_model to obtain a previous noisy sample
[model_output no. 2*]

Have we checked if the teacher_model performs okay with that sampler? Does it matter?

Contributor

Let's maybe try to output EMA'd model only since that is more common in practice?

Sounds good.

Have we checked if the teacher_model performs okay with that sampler? Does it matter?

I think there is a bug in the current code where we use noise_scheduler when performing one step of sampling on the teacher_model:

teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample
teacher_denoiser = noise_scheduler.step(
    teacher_model_output, timestep_prev, samples, use_noise=False
).prev_sample

because noise_scheduler is a CMStochasticIterativeScheduler:

noise_scheduler = CMStochasticIterativeScheduler()

which isn't compatible with diffusion models. I think the scheduler should either be teacher_scheduler = DiffusionPipeline.from_pretrained(<model_id>).scheduler (if we want to use the scheduler the teacher_model was trained with) or HeunDiscreteScheduler (if we follow the original implementation). (We would still need noise_scheduler to get the outputs of the two consistency models, model and target_model.)

As for whether the Heun scheduler works well with google/ddpm-cifar10-32, which is presumably trained with DDPMScheduler, I'm not sure whether it will affect the training results. I guess practically speaking we could try it out as is, and if the resulting model doesn't create reasonable samples, then we can try either switching the sampler for getting teacher samples to DDPMScheduler or switching the teacher_model to something that was trained with HeunDiscreteScheduler.
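A sketch of that fix, stepping the teacher with its own scheduler while noise_scheduler stays dedicated to the consistency-model outputs (the placeholder batch and timestep are only there to show the call shape):

import torch
from diffusers import CMStochasticIterativeScheduler, DiffusionPipeline

# Used only for the consistency-model (student/target) outputs.
noise_scheduler = CMStochasticIterativeScheduler()

# Step the teacher with the scheduler it was trained with (or HeunDiscreteScheduler
# to follow the original implementation), not with the CM scheduler.
teacher_pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32")
teacher_model = teacher_pipeline.unet
teacher_scheduler = teacher_pipeline.scheduler
teacher_scheduler.set_timesteps(40)

samples = torch.randn(4, 3, 32, 32)              # placeholder noisy batch
timestep_prev = teacher_scheduler.timesteps[10]  # placeholder; mapping the CM timestep grid onto
                                                 # the teacher's schedule is the open question above

teacher_model_output = teacher_model(
    teacher_scheduler.scale_model_input(samples, timestep_prev), timestep_prev
).sample
teacher_denoiser = teacher_scheduler.step(teacher_model_output, timestep_prev, samples).prev_sample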

Member

Sounds like a plan! Thanks so much for being generous with your explanations!

# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    run = os.path.split(__file__)[-1].split(".")[0]
    accelerator.init_trackers(run)
Member

We usually perform this initialization like so:

accelerator.init_trackers("dreambooth-lora", config=vars(args))
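For this script that would translate to something along the lines of (the tracker name is arbitrary):

accelerator.init_trackers("consistency-distillation", config=vars(args))

so the full training config gets logged alongside the run.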


with accelerator.accumulate(model):
    # Predict the noise residual
    model_output = model(noise_scheduler.scale_model_input(noised_image, timestep), timestep, class_labels=labels).sample
Member

There are three model_output variables here. Let's maybe try to give them more appropriate names.

Contributor

See #3992 (comment) for some suggestions.

Member
@sayakpaul left a comment

Looking good.

From my end, I think we need to:

  • Integrate accelerate more deeply and cleanly.
  • Improve the readability of the code by using better variable names, adding comments wherever applicable, and refactoring some bits wherever applicable.
  • Use this method to upload the trained checkpoints to the Hub.
  • Add tests and docs.
  • Add utility for performing validation inference runs with Weights and Biases.
  • Demonstrate a full run with reasonable results.

More than happy to help with any of these :)

Also, IIUC, Consistency Models can either be trained in isolation or can be used to step-distill a pre-trained model. If that's the case, should we allow both from the training script? Or am I wrong?

@ayushtues
Contributor Author

Hey @sayakpaul can you help with the better integration with accelerator? I don't have much idea about how to use it better. Also can you link me to some relevant tests in existing examples which I can refer to

@sayakpaul
Member

Hey @sayakpaul can you help with the better integration with accelerator? I don't have much idea about how to use it better.

If you follow one of our examples from here, I think you will have a better idea (this one, for example).

Also can you link me to some relevant tests in existing examples which I can refer to

Here you go: https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py

Comment on lines 375 to 415
else:
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        num_class_embeds=1000,
        block_out_channels=[32, 64],
        attention_head_dim=8,
        down_block_types=[
            "ResnetDownsampleBlock2D",
            "AttnDownBlock2D",
        ],
        up_block_types=[
            "AttnUpBlock2D",
            "ResnetUpsampleBlock2D",
        ],
        resnet_time_scale_shift="scale_shift",
        upsample_type="resnet",
        downsample_type="resnet",
    )
    target_model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        num_class_embeds=1000,
        block_out_channels=[32, 64],
        attention_head_dim=8,
        down_block_types=[
            "ResnetDownsampleBlock2D",
            "AttnDownBlock2D",
        ],
        up_block_types=[
            "AttnUpBlock2D",
            "ResnetUpsampleBlock2D",
        ],
        resnet_time_scale_shift="scale_shift",
        upsample_type="resnet",
        downsample_type="resnet",
    )
Contributor

could we remove this? would prefer to not hardcode a config in the training script

Comment on lines 366 to 367
if args.testing:
    config = UNet2DModel.load_config('diffusers/consistency-models-test', subfolder="test_unet")
Contributor

can we remove this? we should be able to pass the preferred model in from the standard cli args
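A rough sketch of how that could look (argument names and the subfolder handling are assumptions, not the final CLI):

import argparse
from diffusers import UNet2DModel

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_teacher_model", type=str, default="google/ddpm-cifar10-32")
parser.add_argument(
    "--model_config_name_or_path", type=str, default=None,
    help="UNet config to initialize the consistency model from; defaults to the teacher's own config.",
)
args = parser.parse_args()

if args.model_config_name_or_path is None:
    # No hardcoded architecture and no special-cased test path: fall back to the teacher's config.
    config = UNet2DModel.load_config(args.pretrained_teacher_model, subfolder="unet")
else:
    config = UNet2DModel.load_config(args.model_config_name_or_path)

model = UNet2DModel.from_config(config)
target_model = UNet2DModel.from_config(config)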

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the stale label (Issues that haven't received updates) Aug 12, 2023
@github-actions bot closed this Aug 20, 2023