-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Consistency models distillation examples #3992
Closed
ayushtues
wants to merge
80
commits into
huggingface:main
from
ayushtues:consistency-models-examples
Closed
Changes from 76 commits
Commits
Show all changes
80 commits
Select commit
Hold shift + click to select a range
799ab23
initial commit
dg845 63b7f01
Improve consistency models sampling implementation.
dg845 f2e53da
Add CMStochasticIterativeScheduler, which implements the multi-step s…
dg845 37e1895
Merge branch 'main2' into consistency_unet
7c1e81f
Add Unet blocks for consistency models
3a151bd
Add conversion script for Unet
b6c5e15
Fix bug in new unet blocks
4e93f09
Fix attention weight loading
d137d11
Merge pull request #1 from ayushtues/consistency_unet
dg845 9ae7669
Make design improvements to ConsistencyModelPipeline and CMStochastic…
dg845 54b287e
make style
dg845 067a9ef
Add initial training script
ayushtues 5a27c2f
Make small random test UNet class conditional and set resnet_time_sca…
dg845 f2783a8
Add support for converting a test UNet and non-class-conditional UNet…
dg845 ed53b85
make style
dg845 a505c6c
Change num_class_embeds to 1000 to better match the original consiste…
dg845 a927a4a
Add support for distillation in pipeline_consistency_models.py.
dg845 b2e6424
Improve consistency model tests:
dg845 c37e302
make style
dg845 a0a164c
Improve ConsistencyModelPipeline:
dg845 0d1de08
make style
dg845 5f4f406
Improve ConsistencyModelPipeline:
dg845 213b25d
make style
dg845 fbe34c3
Fix bug when randomly generating class labels for class-conditional g…
dg845 0e53d8b
Switch CMStochasticIterativeScheduler to configuring a sigma schedule…
dg845 c3b242e
Remove some unused code and make style.
dg845 43e4379
Fix small bug in CMStochasticIterativeScheduler.
dg845 94c99ca
Add expected slices for multistep sampling tests and make them pass.
dg845 0773b27
Work on consistency model fast tests:
dg845 6867a3a
make style
dg845 f0c85d3
Refactor conversion script to make it easier to add more model archit…
dg845 6adb589
Work on ConsistencyModelPipeline tests:
dg845 c213bf7
make style
dg845 3be6e08
Remove the add_noise and add_noise_to_input methods from CMStochastic…
dg845 660909f
Run python utils/check_copies.py --fix_and_overwrite
dg845 ad5abdc
Make fast tests from PipelineTesterMixin pass.
dg845 a1c11d3
make style
dg845 a15c168
Merge branch 'consistency-models-pipeline' into consistency-models-ex…
ayushtues 2dc2fa3
Refactor consistency models pipeline and scheduler:
dg845 7d3dbe3
make style
dg845 55f80ed
Add docstrings and further refactor pipeline and scheduler.
dg845 3234543
make style
dg845 dc4349c
Add initial version of the consistency models documentation.
dg845 787e65a
Minor
ayushtues 50136d0
Merge branch 'consistency-models-pipeline' into consistency-models-ex…
ayushtues 6530b17
Add training code
ayushtues 8f858cb
Fix bugs in training
ayushtues e56b870
Remove some args, add target model
ayushtues 85eb796
Refactor custom timesteps logic following DDPMScheduler/IFPipeline an…
dg845 bf3a405
make style
dg845 1488180
attention weight loading fix
ayushtues fd0a253
Convert current slow tests to use fp16 and flash attention.
dg845 7c03ff6
make style
dg845 80d7745
Add slow tests for normal attention on cuda device.
dg845 0662f63
make style
dg845 cce04c0
Fix attention weights loading
ayushtues c1ba8b9
Merge pull request #3 from ayushtues/fix-attention-loading
dg845 62a49a2
Update consistency model fast tests for new test checkpoints with att…
dg845 58f12ca
make style
dg845 98e1381
Renaming ema model to target
ayushtues baefc87
Add some comments
ayushtues 071f850
apply suggestions
dg845 cd460ca
Remove xformers, refactor ckpt resuming
ayushtues 6968615
Add add_noise method to CMStochasticIterativeScheduler (copied from E…
dg845 9bc3167
Merge branch 'consistency-models-pipeline' into consistency-models-ex…
ayushtues 9a81642
Add input scaling, disable gradients
ayushtues ca15734
Conversion script now outputs pipeline instead of UNet and add suppor…
dg845 a56d3d2
When both timesteps and num_inference_steps are supplied, raise warni…
dg845 a17847e
make style
dg845 8214a33
Add remaining diffusers model checkpoints for models in the original …
dg845 2606d82
apply suggestions from review
dg845 075351a
make style
dg845 0240407
Merge branch 'consistency-models-pipeline' into consistency-models-ex…
ayushtues 88356e6
Merge branch 'main_hf' into consistency-models-examples
ayushtues ea7d75f
Update training script to main, fix timesteps
ayushtues a32b869
Fix bug in timestep ordering
ayushtues 8742e4e
Add review suggestions
ayushtues 943c88b
Integrate accelerator better, change model upload
ayushtues 6b58d81
Fix checkpointing and add test
ayushtues 180b3e8
Remove hardcoded configs, add DiffusionPipeline
ayushtues File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
accelerate>=0.16.0 | ||
torchvision | ||
datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
accelerate launch train_consistency_distillation.py --dataset_name="cifar10" --resolution=32 --center_crop --random_flip --output_dir="cifar10-32" --train_batch_size=16 --num_epochs=100 --gradient_accumulation_steps=1 --learning_rate=1e-4 --lr_warmup_steps=500 --mixed_precision=no --push_to_hub |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't include shell-scripts in the examples.