[WIP] Consistency models distillation examples #3992
Conversation
- …ampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
- [WIP] Add Unet for consistency models
- …IterativeScheduler and add initial version of tests.
- …le_shift to 'scale_shift' to better match consistency model checkpoints.
- …s to the consistency models conversion script.
- …ncy models implementation.
- Get small testing checkpoints from hub - Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline - Add onestep, multistep tests for distillation and distillation + class conditional - Add expected image slices for onestep tests
- Add initial support for class-conditional generation - Fix initial sigma for onestep generation - Fix some sigma shape issues
- add latents __call__ argument and prepare_latents method - add check_inputs method - add initial docstrings for ConsistencyModelPipeline.__call__
… and make related changes to the pipeline and tests.
- in pipeline, call self.scheduler.scale_model_input before denoising - get expected slices for Euler and Heun scheduler tests - make Euler test pass - mark Heun test as expected fail because it doesn't support prediction_type "sample" yet - remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@ayushtues, thanks for this work! Let us know when you'd like a review from our end. The PR doesn't have to be super tight for an initial review, just as an FYI.
Hey @sayakpaul, yes I think it is good enough for an initial review right now. One major thing we need to figure out is how to do the scaling of the inputs for different teacher models (as mentioned in dg845#2 (comment)). The original paper used EDM as the teacher model, which was trained using a Karras scheduler, so it followed the same scaling and processing as the consistency model; but if we want to use something else like Stable Diffusion or DDPM as the teacher model, we need to figure out what input scaling and processing to use for them.
@dg845 @patrickvonplaten @williamberman @sayakpaul do you think we should add examples of other applications of Consistency Models mentioned in the paper, like inpainting, colorization, and super-resolution? They mostly seem to involve some clever processing of the inputs and outputs of the model; for example, see the code reference in the original repo: https://github.com/openai/consistency_models/blob/ac278060af7175e37f4c1e79e69fa521234c04de/cm/karras_diffusion.py#L722
If we decided to, say, distill https://huggingface.co/valhalla/sd-pokemon-model, can't we use the same scheduler for distilling the student (with appropriate changes as needed) and use the same preprocessing steps? I like what @dg845 is proposing here: dg845#2 (comment).
Let's gauge the community interest for that first and then we can definitely consider it.
```
@@ -0,0 +1,3 @@
#!/bin/bash
```
We don't include shell scripts in the examples.
```
@@ -268,6 +268,7 @@ def step(
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        use_noise: bool = True
```
Why do we need to introduce this argument?
Can we maybe make this a config attribute instead? `use_noise=True` will rarely be changed to `use_noise=False` during inference - either it's always False or True. Could we add a `force_deterministic=True/False` to the config?
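A minimal sketch of how that could look, assuming the scheduler follows the usual diffusers `ConfigMixin` / `register_to_config` pattern (the flag name `force_deterministic` and the toy math are purely illustrative, not the PR implementation):

```python
import torch

from diffusers import ConfigMixin, SchedulerMixin
from diffusers.configuration_utils import register_to_config


class ToyConsistencyScheduler(SchedulerMixin, ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, force_deterministic: bool = False):
        pass

    def step(self, model_output, timestep, sample, generator=None):
        # Stand-in for the real denoising computation.
        denoised = model_output
        if self.config.force_deterministic:
            # Set once in the config instead of passing a per-call argument.
            return denoised
        noise = torch.randn(sample.shape, generator=generator, dtype=sample.dtype, device=sample.device)
        # Stand-in for the real noise injection of the stochastic sampler.
        return denoised + noise
```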
I think it would be better to put the section of the function that computes the unnoised prediction in its own function that can be called.
i.e.

```python
def step(...):
    denoised = self.denoise(...)
    if len(self.timesteps) > 1:
        prev_sample = ...
    else:
        prev_sample = denoised
    return prev_sample

def denoise(...):
    ...
```

This would avoid an extra global config on the class, and the training script could just call `denoise()`.
That would also work for me - it is probably indeed the better option (as my first proposal would force people to update the config after training, which is probably more brittle).
```python
if args.logger == "tensorboard":
    if not is_tensorboard_available():
        raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")

elif args.logger == "wandb":
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
    import wandb
```
Both of these dependencies need to go to `requirements.txt`.
```python
model = UNet2DModel.from_config(config)
target_model = UNet2DModel.from_config(config)
```
So this is, in a sense, step distillation, i.e., we're not compressing the model itself but reducing the number of steps.
Perhaps it is a distinction without a difference, but I think of it as reducing the number of steps for a (diffusion model, sampler) pair. In my mind, this distinction is important because my guess is that the quality of the sampler affects the quality of the distilled consistency model; if we distill using a sampler which produces bad samples for the diffusion model, the consistency model might learn to faithfully reproduce the kind of samples that (diffusion model, sampler) pair produces, but these will still be poor quality because the samples it learned from are poor quality. (Conversely, using a better sampler might produce a better consistency model.)
[Another way to phrase it might be that it's compressing the sampling procedure, which depends on both the diffusion model and sampler.]
```python
# load the model to distill into a consistency model
teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet
```
Let's make this path ("google/ddpm-cifar10-32") a CLI arg.
And it should also work with `DiffusionPipeline`.
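A rough sketch of what that could look like (the flag name `--teacher_model_name_or_path` is just an illustration, not the final API):

```python
import argparse

from diffusers import DiffusionPipeline

parser = argparse.ArgumentParser()
parser.add_argument(
    "--teacher_model_name_or_path",
    type=str,
    default="google/ddpm-cifar10-32",
    help="Hub id or local path of the pretrained diffusion model to distill.",
)
args = parser.parse_args()

# DiffusionPipeline resolves the concrete pipeline class from the checkpoint config,
# so this covers DDPM as well as other teacher pipelines.
teacher_pipeline = DiffusionPipeline.from_pretrained(args.teacher_model_name_or_path)
teacher_model = teacher_pipeline.unet
```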
```python
# load the model to distill into a consistency model
teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet
```
Not clear what `model`, `target_model`, and `teacher_model` mean.
My understanding is as follows (things in square brackets [] are the variable names in the current PR code / terminology follows Section 4/Algorithm 2 in the paper):

- The `teacher_model` $s_{\phi}$ is the already-trained diffusion model we want to distill to a consistency model (for example, `google/ddpm-cifar10-32`).
- The `model` is a consistency model $\boldsymbol{f_\theta}$ which is used to calculate the current output of the distilled consistency model $\boldsymbol{f_\theta}(x_{t_{n + 1}}, t_{n + 1})$ [`model_output` no. 1*] on the noisy sample $x_{t_{n + 1}}$ [`noised_image`]. The parameters $\boldsymbol{\theta}$ are updated via gradient descent. (The paper calls this the "online" model, following RL terminology.)
- The `target_model` is a consistency model $\boldsymbol{f_{\theta^-}}$ which is used to calculate the distillation target, which we want the consistency model output $\boldsymbol{f_\theta}(x_{t_{n + 1}}, t_{n + 1})$ to match. This is obtained by first running one step of the sampler (in the paper and code, one step of the Karras Heun sampler) on the `teacher_model` to obtain a previous noisy sample $\hat{x}_{t_n}^\phi$ [`model_output` no. 2*], and then getting the `target_model` consistency model's output $\boldsymbol{f_{\theta^-}}(\hat{x}_{t_{n}}^\phi, t_{n})$ [`model_output` no. 3*] on $\hat{x}_{t_n}^\phi$. The target parameters $\boldsymbol{\theta^-}$ are updated via an EMA update (using $\boldsymbol{\theta}$). (This is analogous to target models in DQN-style RL algorithms, and I believe the EMA update is equivalent to what is sometimes called Polyak averaging in RL.)
The loss function** compares these two consistency model outputs, and then we take a gradient descent step with respect to `model`'s parameters $\boldsymbol{\theta}$; `target_model`'s parameters $\boldsymbol{\theta^-}$ are only ever updated via the EMA update. I'm not sure which of `model` and `target_model` is outputted as the final trained checkpoint in the end; according to the paper, theoretically the parameters for the two models should converge during training.
Some proposals:
- I agree that what `model`, `target_model`, `teacher_model` are doing might be hard to follow from the code alone, perhaps it might make sense to add comments explaining what each one does when they are initialized?
- Rename `model_output` no. 2 to `teacher_model_output` and `model_output` no. 3 to `target_model_output`. (If we want to rename `model` [see below], perhaps `model_output` no. 1 can be renamed similarly.)
- Perhaps renaming `model` to something like `online_model` or `current_consistency_model` might be more clear? (Not sure on this one, especially the `current_consistency_model` suggestion, since I'm not sure if it's accurate; would feel more comfortable with this if `model` is what is outputted as the final trained checkpoint, see above.)
(*) Technically, these quantities correspond to `distiller`, `denoised_image`, and `distiller_target` in the code, respectively, but they're all preceded by a calculated raw U-Net output named `model_output`.

(**) The loss function in Algorithm 2 includes a timestep-dependent weight function $\lambda(t_n)$.
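For reference, a sketch of the Algorithm 2 objective in the notation above ($d(\cdot, \cdot)$ is the chosen distance, e.g. squared $\ell_2$ or LPIPS, and $\mu$ the EMA decay rate):

$$
\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\theta^-}; \phi) = \mathbb{E}\left[\lambda(t_n)\, d\!\left(\boldsymbol{f_\theta}(x_{t_{n+1}}, t_{n+1}),\ \boldsymbol{f_{\theta^-}}(\hat{x}_{t_n}^{\phi}, t_n)\right)\right],
\qquad
\boldsymbol{\theta^-} \leftarrow \operatorname{stopgrad}\left(\mu\, \boldsymbol{\theta^-} + (1 - \mu)\, \boldsymbol{\theta}\right)
$$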
Thanks so much for explaining @dg845. See my responses below:
> I agree that what `model`, `target_model`, `teacher_model` are doing might be hard to follow from the code alone, perhaps it might make sense to add comments explaining what each one does when they are initialized?

100% agree. The math is heavy enough and we should grab the opportunity here to make it easily digestible to the end users.

> Rename `model_output` no. 2 to `teacher_model_output` and `model_output` no. 3 to `target_model_output`.

Works for me.

> Perhaps renaming `model` to something like `online_model` or `current_consistency_model` might be more clear? (Not sure on this one, especially the `current_consistency_model` suggestion, since I'm not sure if it's accurate; would feel more comfortable with this if `model` is what is outputted as the final trained checkpoint, see above.)

Let's maybe try to output EMA'd model only since that is more common in practice?

> This is obtained by first running one step of the sampler (in the paper and code, one step of the Karras Heun sampler) on the `teacher_model` to obtain a previous noisy sample [`model_output` no. 2*]

Have we checked if the `teacher_model` performs okay with that sampler? Does it matter?
> Let's maybe try to output EMA'd model only since that is more common in practice?

Sounds good.

> Have we checked if the `teacher_model` performs okay with that sampler? Does it matter?
I think there is a bug in the current code where we use `noise_scheduler` when performing one step of sampling on the `teacher_model`:

diffusers/examples/consistency_models/train_consistency_distillation.py (lines 554 to 557 in 8742e4e):
```python
teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample
teacher_denoiser = noise_scheduler.step(
    teacher_model_output, timestep_prev, samples, use_noise=False
).prev_sample
```
because `noise_scheduler` is a `CMStochasticIterativeScheduler`:

```python
noise_scheduler = CMStochasticIterativeScheduler()
```
which isn't compatible with diffusion models. I think the scheduler should either be `teacher_scheduler = DiffusionPipeline.from_pretrained(<model_id>).scheduler` (if we want to use the scheduler the `teacher_model` was trained with) or `HeunDiscreteScheduler` (if we follow the original implementation). (We would still need `noise_scheduler` to get the model outputs for the consistency models `model` and `target_model`.)
As for whether the Heun scheduler works well with `google/ddpm-cifar10-32`, which is presumably trained with `DDPMScheduler`, I'm not sure whether it will affect the training results. I guess practically speaking we could try it out as is, and if the resulting model doesn't create reasonable samples, then we can try either switching the sampler for getting teacher samples to `DDPMScheduler` or switching the `teacher_model` to something that was trained with `HeunDiscreteScheduler`.
Sounds like a plan! Thanks so much for being generous with your explanations!
```python
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    run = os.path.split(__file__)[-1].split(".")[0]
    accelerator.init_trackers(run)
```
We usually perform this initialization like so:
```python
accelerator.init_trackers("dreambooth-lora", config=vars(args))
```
```python
with accelerator.accumulate(model):
    # Predict the noise residual
    model_output = model(noise_scheduler.scale_model_input(noised_image, timestep), timestep, class_labels=labels).sample
```
There are three `model_output` variables here. Let's maybe try to give them more appropriate names.
See #3992 (comment) for some suggestions.
Looking good.
From my end, I think we need to:
- Have `accelerate` deeply and better integrated.
- Improve the readability of the code by using better variable names, adding comments wherever applicable, and refactoring some bits wherever applicable.
- Use this method to upload the trained checkpoints to the Hub.
- Add tests and docs.
- Add utility for performing validation inference runs with Weights and Biases.
- Demonstrate a full run with reasonable results.
More than happy to help with any of these :)
Also, IIUC, Consistency Models can either be trained in isolation or can be used to step-distill a pre-trained model. If that's the case, should we allow both from the training script? Or am I wrong?
Hey @sayakpaul can you help with the better integration with `accelerate`?
If you follow one of our examples from here, I think you will have a better idea (this one, for example).
Here you go: https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py
```python
else:
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        num_class_embeds=1000,
        block_out_channels=[32, 64],
        attention_head_dim=8,
        down_block_types=[
            "ResnetDownsampleBlock2D",
            "AttnDownBlock2D",
        ],
        up_block_types=[
            "AttnUpBlock2D",
            "ResnetUpsampleBlock2D",
        ],
        resnet_time_scale_shift="scale_shift",
        upsample_type="resnet",
        downsample_type="resnet",
    )
    target_model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        num_class_embeds=1000,
        block_out_channels=[32, 64],
        attention_head_dim=8,
        down_block_types=[
            "ResnetDownsampleBlock2D",
            "AttnDownBlock2D",
        ],
        up_block_types=[
            "AttnUpBlock2D",
            "ResnetUpsampleBlock2D",
        ],
        resnet_time_scale_shift="scale_shift",
        upsample_type="resnet",
        downsample_type="resnet",
    )
```
could we remove this? would prefer to not hardcode a config in the training script
```python
if args.testing:
    config = UNet2DModel.load_config('diffusers/consistency-models-test', subfolder="test_unet")
```
can we remove this? we should be able to pass the preferred model in from the standard cli args
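For example, a minimal sketch of what that could look like (the flag name `--model_config_name_or_path` is an assumption, not the current PR code):

```python
parser.add_argument(
    "--model_config_name_or_path",
    type=str,
    default=None,
    help="Hub id or local path of a UNet2DModel config to initialize the student (and target) model from.",
)

# later, when building the student and target consistency models
if args.model_config_name_or_path is not None:
    config = UNet2DModel.load_config(args.model_config_name_or_path)
    model = UNet2DModel.from_config(config)
    target_model = UNet2DModel.from_config(config)
```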
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Add distillation training scripts for Consistency models
Continuation of dg845#2, adding distillation examples for the recently added Consistency Models (paper, original code, diffusers pipeline)
Model/Pipeline Description
Consistency Models (paper, code) are a new family of generative models similar to continuous-time diffusion models which support fast one-step generation. Diffusion models can be distilled into a consistency model for faster sampling, and consistency models can also be trained from scratch. From the paper abstract:
In this PR, we implement the distillation procedure to distill a diffusion model into a consistency model as described in the paper.
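For readers of this description, a very rough sketch of a single distillation update (not the code in this PR: it omits the teacher sampling step and the $c_{\text{skip}}/c_{\text{out}}$ boundary-condition scaling, uses a plain MSE distance, and the names `online_model` / `target_model` / `x_prev` are illustrative):

```python
import torch
import torch.nn.functional as F


def consistency_distillation_step(online_model, target_model, optimizer, x_next, t_next, x_prev, t_prev, ema_decay=0.95):
    # f_theta(x_{t_{n+1}}, t_{n+1}): output of the online consistency model on the noisy sample.
    online_pred = online_model(x_next, t_next).sample

    # f_{theta^-}(x_hat_{t_n}, t_n): distillation target from the EMA target model, evaluated on
    # the teacher's one-step-denoised sample x_prev; no gradient flows through the target.
    with torch.no_grad():
        target_pred = target_model(x_prev, t_prev).sample

    # d(., .) with a uniform weight lambda(t_n) = 1 (the paper also considers LPIPS).
    loss = F.mse_loss(online_pred, target_pred)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # EMA update of the target parameters: theta^- <- mu * theta^- + (1 - mu) * theta.
    with torch.no_grad():
        for p_target, p_online in zip(target_model.parameters(), online_model.parameters()):
            p_target.mul_(ema_decay).add_(p_online, alpha=1 - ema_decay)

    return loss.detach()
```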
TODO
CC
@dg845
@williamberman
@patrickvonplaten