diff --git a/examples/consistency_models/requirements.txt b/examples/consistency_models/requirements.txt new file mode 100644 index 000000000000..bc7b6a7f238b --- /dev/null +++ b/examples/consistency_models/requirements.txt @@ -0,0 +1,5 @@ +accelerate>=0.16.0 +torchvision +datasets +wandb +tensrboard diff --git a/examples/consistency_models/train_consistency_distillation.py b/examples/consistency_models/train_consistency_distillation.py new file mode 100644 index 000000000000..551b5d0c1b1e --- /dev/null +++ b/examples/consistency_models/train_consistency_distillation.py @@ -0,0 +1,727 @@ +import argparse +import inspect +import logging +import math +import os +from pathlib import Path +from typing import Optional +import shutil +import accelerate +import datasets +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from datasets import load_dataset +from huggingface_hub import HfFolder, Repository, create_repo, whoami, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +import wandb +import diffusers +from diffusers import DiffusionPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available, is_xformers_available + + +#Copied from examples/unconditional_image_generation/train_unconditional.py for now + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + repo_folder=None, + pipeline: ConsistencyModelPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- consistency models +- diffusers +inference: true +--- + """ + model_card = f""" +# Consistency Model - {repo_id} + +This is a consistency model distilled from {base_model}. +You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that HF Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the UNet model to train, leave as None to use standard Consistency configuration.", + ) + parser.add_argument( + "--pretrained_teacher_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models to be used as teacher model", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="ddpm-model-64", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--resolution", + type=int, + default=64, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + default=False, + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.") + parser.add_argument( + "--save_model_epochs", type=int, default=10, help="How often to save the model during training." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer." + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.") + parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") + parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") + parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--testing", action="store_true", help="If running a test") + + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--hub_private_repo", action="store_true", help="Whether or not to create a private repository." + ) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + help=( + "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)" + " for experiment tracking and logging of model metrics and model checkpoints" + ), + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.logger, + project_config=accelerator_project_config, + ) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + target_model_ema.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) + target_model_ema.load_state_dict(load_model.state_dict()) + target_model_ema.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Use the config if provided, model and target model have the same structure + if args.model_config_name_or_path is not None: + config = UNet2DModel.load_config(args.model_config_name_or_path) + # Else use a default config + else: + config = UNet2DModel.load_config("ayushtues/consistency_tiny_unet") + model = UNet2DModel.from_config(config) + target_model = UNet2DModel.from_config(config) + noise_scheduler = CMStochasticIterativeScheduler() + # load the model to distill into a consistency model + teacher_pipeline = DiffusionPipeline.from_pretrained(args.pretrained_teacher_model_name_or_path) + teacher_model = teacher_pipeline.unet + teacher_scheduler = teacher_pipeline.scheduler + num_scales = 40 + + # Check that all trainable models are in full precision + low_precision_error_string = ( + "Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training. copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(model).dtype != torch.float32: + raise ValueError( + f"Consistency Model loaded as datatype {accelerator.unwrap_model(model).dtype}. {low_precision_error_string}" + ) + + if accelerator.unwrap_model(teacher_model).dtype != torch.float32: + raise ValueError( + f"Teacher_model loaded as datatype {accelerator.unwrap_model(teacher_model).dtype}." + f" {low_precision_error_string}" + ) + + + # Create EMA for the model, this is the target model in the paper + target_model_ema = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + model_cls=UNet2DModel, + model_config=model.config, + ) + + # Initialize the optimizer # TODO: Change this to match the paper, RAdam + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + split="train", + ) + else: + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets and DataLoaders creation. + augmentations = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def transform_images(examples): + img_key = "image" if "image" in examples else "img" + images = [augmentations(image.convert("RGB")) for image in examples[img_key]] + labels = [torch.tensor(label) for label in examples["label"]] + return {"input": images, "labels": labels} + + logger.info(f"Dataset size: {len(dataset)}") + dataset.set_transform(transform_images) + train_dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + + # Initialize the learning rate scheduler + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs), + ) + + # Prepare everything with our `accelerator`. + model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model = accelerator.prepare( + model, optimizer, noise_scheduler, train_dataloader, lr_scheduler, teacher_model, target_model + ) + noise_scheduler.set_timesteps(num_scales, device=accelerator.device) + target_model_ema.to(accelerator.device) # TODO accelerate.prepare doesn't work on this for some reason + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("consistency-distillation", vars(args)) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + max_train_steps = args.num_epochs * num_update_steps_per_epoch + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + timesteps = noise_scheduler.timesteps + teacher_scheduler.set_timesteps(timesteps=timesteps, device=accelerator.device) + sigmas = noise_scheduler.sigmas # in reverse order, sigma0 is sigma_max + + + # Train! + for epoch in range(first_epoch, args.num_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + labels = batch["labels"] + # Sample noise that we'll add to the images + noise = torch.randn(clean_images.shape).to(accelerator.device) + # Sample a random timestep for each image, TODO - allow different timesteps in a batch + index = torch.randint( + 0, noise_scheduler.config.num_train_timesteps-1, (1,), device=accelerator.device + ).long() + # timestep is the scaled timestep, sigma is the unscaled timestep + timestep = timesteps[index+1] + sigma = sigmas[index+1] + timestep_prev = timesteps[index] + sigma_prev = sigmas[index] + # add noise expects the scaled timestep only and internally converts to sigma + noised_image = noise_scheduler.add_noise(clean_images, noise, timestep) + target_model_ema.copy_to(target_model.parameters()) + + with accelerator.accumulate(model): + # Predict the noise residual + model_output = model(noise_scheduler.scale_model_input(noised_image, timestep), timestep, class_labels=labels).sample + distiller = noise_scheduler.step( + model_output, timestep, noised_image, use_noise=False + ).prev_sample + + with torch.no_grad(): + # Heun Solver to get previous timestep image using teacher model + # TODO - make this cleaner + samples = noised_image + x = samples + teacher_model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample + teacher_denoiser = noise_scheduler.step( + teacher_model_output, timestep, x, use_noise=False + ).prev_sample + d = (x - teacher_denoiser) / sigma[(...,) + (None,) * 3] + samples = x + d * (sigma_prev - sigma)[(...,) + (None,) * 3] + # We probably want to use Sigma for an arbitrary teacher model here, since that corresponds to the unscaled timestep + # We just want a denoised image from an input x, t using the teacher model, since that is used in the score function + # So we should figure out how to get the denoised image from the teacher model + teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample + teacher_denoiser = noise_scheduler.step( + teacher_model_output, timestep_prev, samples, use_noise=False + ).prev_sample + next_d = (samples - teacher_denoiser) / sigma_prev[(...,) + (None,) * 3] + denoised_image = x + (d + next_d) * ((sigma_prev - sigma) /2)[(...,) + (None,) * 3] + # get output from target model + target_model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample + distiller_target = noise_scheduler.step( + target_model_output, timestep_prev, denoised_image, use_noise=False + ).prev_sample + + loss = F.mse_loss(distiller, distiller_target) + loss = loss.mean() + + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + target_model_ema.step(model.parameters()) + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + logs["ema_decay"] = target_model_ema.cur_decay_value + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # Generate sample images for visual inspection + if accelerator.is_main_process: + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + unet = accelerator.unwrap_model(model) + target_model_ema.store(unet.parameters()) + target_model_ema.copy_to(unet.parameters()) + + pipeline = ConsistencyModelPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + + generator = torch.Generator(device=pipeline.device).manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline( + generator=generator, + batch_size=args.eval_batch_size, + num_inference_steps=1, + output_type="np", + ).images + + target_model_ema.restore(unet.parameters()) + + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + + if args.logger == "tensorboard": + if is_accelerate_version(">=", "0.17.0.dev0"): + tracker = accelerator.get_tracker("tensorboard", unwrap=True) + else: + tracker = accelerator.get_tracker("tensorboard") + tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch) + elif args.logger == "wandb": + # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files + accelerator.get_tracker("wandb").log( + {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, + step=global_step, + ) + + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + unet = accelerator.unwrap_model(model) + target_model_ema.store(unet.parameters()) + target_model_ema.copy_to(unet.parameters()) + + pipeline = ConsistencyModelPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + + pipeline.save_pretrained(args.output_dir) + target_model_ema.restore(unet.parameters()) + + if accelerator.is_main_process and args.push_to_hub: + unet = accelerator.unwrap_model(model) + target_model_ema.copy_to(unet.parameters()) + + pipeline = ConsistencyModelPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + + generator = torch.Generator(device=pipeline.device).manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline( + generator=generator, + batch_size=args.eval_batch_size, + num_inference_steps=1, + output_type="pil", + ).images + + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_teacher_model_name_or_path, + repo_folder=args.output_dir, + pipeline=pipeline, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/test_examples.py b/examples/test_examples.py index d11841350064..f1096f9e1289 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -96,6 +96,29 @@ def test_train_unconditional(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_train_consistency(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/consistency_models/train_consistency_distillation.py + --dataset_name hf-internal-testing/dummy_image_class_data + --resolution 32 + --output_dir {tmpdir} + --train_batch_size 2 + --num_epochs 1 + --gradient_accumulation_steps 1 + --learning_rate 1e-3 + --lr_warmup_steps 5 + --testing + --pretrained_teacher_model_name_or_path google/ddpm-cifar10-32 + """.split() + + run_command(self._launch_args + test_args, return_stdout=True) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + + + def test_textual_inversion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index fb296054d65b..f580c28453cd 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -268,6 +268,7 @@ def step( sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, + use_noise: bool = True ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -331,7 +332,7 @@ def step( # 2. Sample z ~ N(0, s_noise^2 * I) # Noise is not used for onestep sampling. - if len(self.timesteps) > 1: + if len(self.timesteps) > 1 and use_noise: noise = randn_tensor( model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator )