From 27082465c7ff38106a4ee3964152a4ab2073e3ce Mon Sep 17 00:00:00 2001
From: kliyer
Date: Wed, 2 Oct 2024 16:38:18 +0200
Subject: [PATCH] add training code

---
 configs/experiment/sample_struct_attn.yaml |  41 +++
 configs/experiment/train_struct_sd15.yaml  |  38 +++
 configs/experiment/train_style_sd15.yaml   |  37 +++
 configs/model/sd15.yaml                    |   2 -
 configs/model/sdxl.yaml                    |   1 -
 configs/sample.yaml                        |   3 +-
 configs/train.yaml                         |  37 +++
 src/data/local.py                          |  53 +++-
 src/model.py                               | 116 +++-----
 train.py                                   | 297 +++++++++++++++++++++
 10 files changed, 536 insertions(+), 89 deletions(-)
 create mode 100644 configs/experiment/sample_struct_attn.yaml
 create mode 100644 configs/experiment/train_struct_sd15.yaml
 create mode 100644 configs/experiment/train_style_sd15.yaml
 create mode 100644 configs/train.yaml
 create mode 100644 train.py

diff --git a/configs/experiment/sample_struct_attn.yaml b/configs/experiment/sample_struct_attn.yaml
new file mode 100644
index 0000000..eb6f94c
--- /dev/null
+++ b/configs/experiment/sample_struct_attn.yaml
@@ -0,0 +1,41 @@
+# @package _global_
+
+defaults:
+  - /lora@lora.struct: struct_attn
+  - override /lora/encoder@lora.struct.encoder: midas # hed
+  - override /model: sd15
+  - override /data: local
+  - _self_
+
+
+size: 512
+n_samples: 4
+
+save_grid: true
+log_cond: true
+
+data:
+  caption_from_name: true
+  caption_prefix: "a picture of "
+  directories:
+    - data
+
+model:
+  guidance_scale: 7.5
+
+prompt: ''
+
+lora:
+  struct:
+    cfg: false
+    # ckpt_path: checkpoints/sd15-hed-128-only-res
+    ckpt_path: checkpoints/sd15-depth-02-self
+    config:
+
+      c_dim: 128
+      rank: 0.2
+      adaption_mode: only_self
+
+tag: struct
+
+bf16: true
\ No newline at end of file
diff --git a/configs/experiment/train_struct_sd15.yaml b/configs/experiment/train_struct_sd15.yaml
new file mode 100644
index 0000000..e7d7ece
--- /dev/null
+++ b/configs/experiment/train_struct_sd15.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+defaults:
+  - /lora@lora.struct: struct
+  - override /lora/encoder@lora.struct.encoder: midas
+  - override /model: sd15
+  - override /data: local
+  - _self_
+
+data:
+  batch_size: 8
+  caption_from_name: true
+  caption_prefix: "a picture of "
+  directories:
+    - data
+
+lora:
+  struct:
+    optimize: true
+
+
+size: 512
+
+log_c: true
+
+val_batches: 4
+
+learning_rate: 1e-4
+
+ckpt_steps: 3000
+val_steps: 3000
+
+epochs: 10
+
+prompt: null
+
+# model:
+#   guidance_scale: 1.5
\ No newline at end of file
diff --git a/configs/experiment/train_style_sd15.yaml b/configs/experiment/train_style_sd15.yaml
new file mode 100644
index 0000000..a3a4cc2
--- /dev/null
+++ b/configs/experiment/train_style_sd15.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+defaults:
+  - /lora@lora.style: style
+  - override /model: sd15
+  - override /data: local
+  - _self_
+
+data:
+  batch_size: 8
+  caption_from_name: true
+  caption_prefix: "a picture of "
+  directories:
+    - data
+
+val_batches: 1
+
+lora:
+  style:
+    # rank: 208
+    # rank: 16
+    adaption_mode: only_cross
+    optimize: true
+
+size: 512
+
+learning_rate: 1e-4
+
+ckpt_steps: 1000
+val_steps: 1000
+
+epochs: 100
+
+prompt: null
+
+# model:
+#   guidance_scale: 1.5
\ No newline at end of file
diff --git a/configs/model/sd15.yaml b/configs/model/sd15.yaml
index c6edfee..a20a8ac 100644
--- a/configs/model/sd15.yaml
+++ b/configs/model/sd15.yaml
@@ -4,6 +4,4 @@ defaults:
 _target_: src.model.SD15
 pipeline_type: diffusers.StableDiffusionPipeline
 model_name: runwayml/stable-diffusion-v1-5
-use_embeds: ${use_embeds}
-dtype: fp32
 local_files_only: ${local_files_only}
\ No newline at end of file
diff --git a/configs/model/sdxl.yaml b/configs/model/sdxl.yaml
index e3ab6b5..94be90b 100644
--- a/configs/model/sdxl.yaml
+++ b/configs/model/sdxl.yaml
@@ -4,5 +4,4 @@ defaults:
 _target_: src.model.SDXL
 pipeline_type: diffusers.StableDiffusionXLPipeline
 model_name: stabilityai/stable-diffusion-xl-base-1.0
-use_embeds: ${use_embeds}
 local_files_only: ${local_files_only}
\ No newline at end of file
diff --git a/configs/sample.yaml b/configs/sample.yaml
index 814d663..ea4b48d 100644
--- a/configs/sample.yaml
+++ b/configs/sample.yaml
@@ -22,5 +22,4 @@ hydra:
   job:
     chdir: true

-local_files_only: false
-use_embeds: false
\ No newline at end of file
+local_files_only: false
\ No newline at end of file
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100644
index 0000000..de41a63
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,37 @@
+defaults:
+  - data: ???
+  - model: ???
+  - _self_
+  - experiment: null
+
+size: ???
+max_train_steps: null
+epochs: 20
+learning_rate: 1e-4
+
+lr_warmup_steps: 0
+lr_scheduler: constant
+
+prompt: null
+gradient_accumulation_steps: 1
+
+ckpt_steps: 1000
+val_steps: 1000
+val_images: 4
+seed: 42
+n_samples: 4
+
+
+
+tag: ''
+
+local_files_only: false
+
+hydra:
+  run:
+    dir: outputs/train/${tag}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: outputs/train/${tag}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S}
+    subdir: ${hydra.job.num}
+  job:
+    chdir: true
\ No newline at end of file
diff --git a/src/data/local.py b/src/data/local.py
index 9badda2..31b3ae4 100644
--- a/src/data/local.py
+++ b/src/data/local.py
@@ -66,22 +66,53 @@ def __getitem__(self, idx: int):


 class ImageDataModule:
-    def __init__(self, directories: list[str], transform: list, batch_size: int = 32, caption_from_name: bool = False, caption_prefix: str = ""):
+    def __init__(
+        self,
+        directories: list[str],
+        transform: list,
+        val_directories: list[str] = [],
+        batch_size: int = 32,
+        val_batch_size: int = 1,
+        workers: int = 4,
+        val_workers: int = 1,
+        caption_from_name: bool = False,
+        caption_prefix: str = "",
+    ):
         super().__init__()
+
+        self.batch_size = batch_size
+        self.val_batch_size = val_batch_size
+        self.workers = workers
+        self.val_workers = val_workers
+
         project_root = Path(os.path.abspath(__file__)).parent.parent.parent
+
+        self.train_dataset = ZipDataset(
+            [
+                ImageFolderDataset(
+                    directory=Path(project_root, d),
+                    transform=transforms.Compose(transform),
+                    caption_from_name=caption_from_name,
+                    caption_prefix=caption_prefix,
+                )
+                for d in directories
+            ]
+        )
+
         self.val_dataset = ZipDataset(
-            [ImageFolderDataset(directory=Path(project_root, d), transform=transforms.Compose(transform), caption_from_name=caption_from_name, caption_prefix=caption_prefix) for d in directories]
+            [
+                ImageFolderDataset(
+                    directory=Path(project_root, d),
+                    transform=transforms.Compose(transform),
+                    caption_from_name=caption_from_name,
+                    caption_prefix=caption_prefix,
+                )
+                for d in val_directories
+            ]
         )
-        self.batch_size = batch_size

     def train_dataloader(self):
-        raise Exception("Not implemented")
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers)

     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
-
-    def test_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
-
-    def predict_dataloader(self):
-        raise Exception("Not implemented")
+        return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.val_workers)
diff --git a/src/model.py b/src/model.py
index dd4db85..b727333 100644
--- a/src/model.py
+++ b/src/model.py
@@ -56,13 +56,11 @@ def __init__(
         self,
         pipeline_type: str,
         model_name: str,
-        dtype: str = "fp32",
         local_files_only: bool = True,
         c_dropout: float = 0.05,
         guidance_scale: float = 7.5,
         use_controlnet: bool = False,
         annotator: None | nn.Module = None,
-        use_embeds: bool = False,
         tiny_vae: bool = False,
     ) -> None:
         super().__init__()
@@ -80,7 +78,6 @@ def __init__(
         self.guidance_scale = guidance_scale

         self.use_controlnet = use_controlnet
-        self.use_embeds = use_embeds

         addition_config = {}
         # Note that this requires the controlnet pipe which also has to be set in the config
@@ -89,21 +86,11 @@ def __init__(
             vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", local_files_only=local_files_only)
             addition_config["vae"] = vae

-        if dtype == "fp16":
-            addition_config["torch_dtype"] = torch.float16
-            addition_config["variant"] = "fp16"
-
-            # is this needed?
-            # vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
-            # note that this might conflict with the tiny vae
-
         if self.use_controlnet:
             assert annotator is not None, "Need annotator for controlnet"

             controlnet = ControlNetModel.from_pretrained(
                 "lllyasviel/sd-controlnet-depth",
-                # "lllyasviel/control_v11f1p_sd15_depth",
-                # "lllyasviel/sd-controlnet-hed",
                 use_safetensors=True,
                 local_files_only=local_files_only,
                 **addition_config,
@@ -137,33 +124,17 @@ def __init__(
         self.unet = self.pipe.unet
         self.unet.requires_grad_(False)

-        if not use_embeds:
-            self.vae = self.pipe.vae
-            self.text_encoder = self.pipe.text_encoder
-            self.tokenizer = self.pipe.tokenizer
+        self.vae = self.pipe.vae
+        self.text_encoder = self.pipe.text_encoder
+        self.tokenizer = self.pipe.tokenizer

-            self.vae = self.pipe.vae
-            self.text_encoder.requires_grad_(False)
-
-            # handle sdxl case
-            if hasattr(self.pipe, "text_encoder_2"):
-                self.text_encoder_2 = self.pipe.text_encoder_2
-                self.text_encoder_2.requires_grad_(False)
-        else:
-            # maybe only delete encoder
-            # such that decoder can be used for validation
-            self.vae = self.pipe.vae
-            self.vae = self.pipe.vae
+        self.vae = self.pipe.vae
+        self.text_encoder.requires_grad_(False)

-            del self.pipe.vae.encoder  # keep decoder for val samples
-            del self.pipe.text_encoder
-            del self.pipe.tokenizer
-
-            if hasattr(self.pipe, "text_encoder_2"):
-                print("deleting sdxl text encoder 2")
-                del self.pipe.text_encoder_2
-                # this is needed otherwise some weird ref is left
-                self.pipe.text_encoder_2 = None
+        # handle sdxl case
+        if hasattr(self.pipe, "text_encoder_2"):
+            self.text_encoder_2 = self.pipe.text_encoder_2
+            self.text_encoder_2.requires_grad_(False)

     def add_lora_to_unet(
         self,
@@ -234,7 +205,11 @@ def add_lora_to_unet(
         if adaption_mode == "b-lora_style" and ("up_blocks.0.attentions.1" in path and "attn" in path):
             _continue = False

-        if adaption_mode == "b-lora" and ("up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path) and "attn" in path:
+        if (
+            adaption_mode == "b-lora"
+            and ("up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path)
+            and "attn" in path
+        ):
             _continue = False

         # supposed setting content to have no effect
@@ -242,13 +217,25 @@ def add_lora_to_unet(
         #     class_config["lora_scale"] = 0.0

         # "down_blocks.2.attentions.1" in path or
-        if adaption_mode == "sdxl_inner" and ("mid_block" in path or "up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path) and "attn2" in path:
+        if (
+            adaption_mode == "sdxl_inner"
+            and ("mid_block" in path or "up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path)
+            and "attn2" in path
+        ):
             _continue = False

-        if adaption_mode == "sdxl_cross" and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path) and "attn2" in path:
+        if (
+            adaption_mode == "sdxl_cross"
+            and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path)
+            and "attn2" in path
+        ):
             _continue = False

-        if adaption_mode == "sdxl_self" and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path) and "attn1" in path:
+        if (
+            adaption_mode == "sdxl_self"
+            and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path)
+            and "attn1" in path
+        ):
             _continue = False

         if _continue:
@@ -485,7 +472,9 @@ def forward(
             dp.set_batch(mapped_cond)

         # Predict the noise residual
-        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states=prompt_embeds, **additional_inputs).sample
+        model_pred = self.unet(
+            noisy_latents, timesteps, encoder_hidden_states=prompt_embeds, **additional_inputs
+        ).sample

         # get x0 prediction
         alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device=model_pred.device, dtype=model_pred.dtype)
@@ -572,7 +561,9 @@ def sample_custom(

         device = self.unet.device

-        prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(prompt, device, num_images_per_prompt, True)  # do cfg
+        prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
+            prompt, device, num_images_per_prompt, True
+        )  # do cfg
         dtype = prompt_embeds.dtype

         # for cfg
@@ -801,30 +792,16 @@ def forward(

         B = imgs.shape[0]

-        if self.use_embeds:
-            assert batch is not None, "batch must be provided when use_embeds is True"
+        with torch.no_grad():
+            # Convert images to latent space
+            imgs = imgs.to(self.unet.device)
+            latents = self.pipe.vae.encode(imgs).latent_dist.sample()
+            latents = latents * self.pipe.vae.config.scaling_factor

-            latents = batch["latents.npy"].to(self.unet.device)
+        # prompt dropout
+        prompts = ["" if random.random() < self.c_dropout else p for p in prompts]

-            add_time_ids = self.compute_time_ids(self.unet.device, torch.float32)
-            add_time_ids = add_time_ids.to(self.unet.device).repeat(B, 1)
-
-            c = {
-                "prompt_embeds": batch["prompt_embeds.npy"].to(self.unet.device),
-                "add_text_embeds": batch["pooled_prompt_embeds.npy"].to(self.unet.device),
-                "add_time_ids": add_time_ids,
-            }
-        else:
-            with torch.no_grad():
-                # Convert images to latent space
-                imgs = imgs.to(self.unet.device)
-                latents = self.pipe.vae.encode(imgs).latent_dist.sample()
-                latents = latents * self.pipe.vae.config.scaling_factor
-
-            # prompt dropout
-            prompts = ["" if random.random() < self.c_dropout else p for p in prompts]
-
-            c = self.get_conditioning(prompts, B, latents.device, latents.dtype)
+        c = self.get_conditioning(prompts, B, latents.device, latents.dtype)

         unet_added_conditions = {
             "time_ids": c["add_time_ids"],
@@ -922,13 +899,6 @@ def sample(
         prompt_embeds = None
         pooled_prompt_embeds = None

-        if self.use_embeds:
-            assert batch is not None, "batch must be provided when use_embeds is True"
-
-            prompt_embeds = batch["prompt_embeds.npy"].to(self.unet.device)
-            pooled_prompt_embeds = batch["pooled_prompt_embeds.npy"].to(self.unet.device)
-
-            prompt = None

         # we have to do two separate forward passes for the cfg with the loras
         # add our lora conditioning
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..632e1d5
--- /dev/null
+++ b/train.py
@@ -0,0 +1,297 @@
+import hydra
+import math
+from src.model import ModelBase
+from diffusers.optimization import get_scheduler
+import torch
+from accelerate import Accelerator
+from tqdm.auto import tqdm
+from pathlib import Path
+import numpy as np
+import torchvision.transforms.functional as TF
+from accelerate.logging import get_logger
+import signal
+import einops
+import os
+import traceback
+from functools import reduce
+
+
+from src.utils import add_lora_from_config, save_checkpoint
+
+
+torch.set_float32_matmul_precision("high")
+
+
+stop_training = False
+
+
+def signal_handler(sig, frame):
+    global stop_training
+    stop_training = True
+    print("got stop signal")
+
+
+@hydra.main(config_path="configs", config_name="train")
+def main(cfg):
+    signal.signal(signal.SIGUSR1, signal_handler)
+    # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    output_path = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
+
+    accelerator = Accelerator(
+        project_dir=output_path / "logs",
+        log_with="tensorboard",
+        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+        mixed_precision="bf16",
+    )
+
+    logger = get_logger(__name__)
+
+    logger.info("==================================")
+    logger.info(cfg)
+    logger.info(output_path)
+
+    cfg = hydra.utils.instantiate(cfg)
+    model: ModelBase = cfg.model
+
+    model = model.to(accelerator.device)
+    model.pipe.to(accelerator.device)
+    n_loras = len(cfg.lora.keys())
+
+    cfg_mask = add_lora_from_config(model, cfg, accelerator.device)
+
+    if cfg.get("gradient_checkpointing", False):
+        model.unet.enable_gradient_checkpointing()
+
+    dm = cfg.data
+
+    train_dataloader = dm.train_dataloader()
+    val_dataloader = dm.val_dataloader()
+
+    mappers_params = list(
+        filter(lambda p: p.requires_grad, reduce(lambda x, y: x + list(y.parameters()), model.mappers, []))
+    )
+    encoder_params = list(
+        filter(lambda p: p.requires_grad, reduce(lambda x, y: x + list(y.parameters()), model.encoders, []))
+    )
+
+    optimizer = torch.optim.AdamW(
+        model.params_to_optimize + mappers_params + encoder_params,
+        lr=cfg.learning_rate,
+    )
+
+    lr_scheduler = get_scheduler(
+        cfg.lr_scheduler,
+        optimizer=optimizer,
+    )
+
+    logger.info(f"Number params Mapper Network(s) {sum(p.numel() for p in mappers_params):,}")
+    logger.info(f"Number params Encoder Network(s) {sum(p.numel() for p in encoder_params):,}")
+    logger.info(f"Number params all LoRAs(s) {sum(p.numel() for p in model.params_to_optimize):,}")
+
+    logger.info("init trackers")
+    if accelerator.is_main_process:
+        accelerator.init_trackers("tensorboard")
+
+    logger.info("prepare network")
+
+    prepared = accelerator.prepare(
+        *model.mappers,
+        *model.encoders,
+        model.unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    )
+
+    mappers = prepared[: len(model.mappers)]
+    encoders = prepared[len(model.mappers) : len(model.mappers) + len(model.encoders)]
+    (unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) = prepared[
+        len(model.mappers) + len(model.encoders) :
+    ]
+    model.unet = unet
+    model.mappers = mappers
+    model.encoders = encoders
+
+    try:
+        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
+
+        if cfg.get("max_train_steps", None) is None:
+            max_train_steps = cfg.epochs * num_update_steps_per_epoch
+        else:
+            max_train_steps = cfg.max_train_steps
+    except:
+        max_train_steps = 10000000
+
+    global_step = 0
+    progress_bar = tqdm(
+        range(global_step, max_train_steps),
+        disable=not accelerator.is_main_process,
+    )
+    progress_bar.set_description("Steps")
+
+    logger.info("start training")
+    for epoch in range(cfg.epochs):
+        logger.info("new epoch")
+        unet.train()
+        map(lambda m: m.train(), mappers)
+        map(lambda m: m.train(), encoders)
+
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate(unet, *mappers, *encoders):
+                imgs = batch["jpg"]
+                imgs = imgs.to(accelerator.device)
+                imgs = imgs.clip(-1.0, 1.0)
+                B = imgs.shape[0]
+
+                cs = [imgs] * n_loras
+
+                if cfg.get("prompt", None) is not None:
+                    prompts = [cfg.prompt] * B
+                else:
+                    prompts = batch["caption"]
+
+                # cfg mask to always true such that the model always learns dropout
+                model_pred, loss, x0, _ = model.forward_easy(
+                    imgs,
+                    prompts,
+                    cs,
+                    cfg_mask=[True for _ in cfg_mask],
+                    batch=batch,
+                )
+
+                accelerator.backward(loss)
+
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs, refresh=False)
+            accelerator.log(logs, step=global_step)
+
+            # after every gradient update step
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+            if global_step % cfg.val_steps != 0 and not stop_training:
+                continue
+
+            # VALIDATION
+            with torch.no_grad():
+                try:
+                    unet.eval()
+                    map(lambda m: m.eval(), mappers)
+                    map(lambda m: m.eval(), encoders)
+
+                    generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed)
+
+                    val_prompts = []
+                    for i, val_batch in enumerate(val_dataloader):
+
+                        B = val_batch["jpg"].shape[0]
+
+                        if i >= cfg.get("val_batches", 4):
+                            break
+
+                        if cfg.get("prompt", None) is not None:
+                            prompts = [cfg.prompt] * B
+                        else:
+                            prompts = val_batch["caption"]
+
+                        val_prompts = prompts
+
+                        imgs = val_batch["jpg"]
+                        imgs = imgs.to(accelerator.device)
+                        imgs = imgs.clip(-1.0, 1.0)
+
+                        cs = [imgs] * n_loras
+
+                        pipeline_args = {
+                            "prompt": prompts,
+                            "num_images_per_prompt": 1,
+                            "cs": cs,
+                            "generator": generator,
+                            "cfg_mask": cfg_mask,
+                            "batch": val_batch,
+                        }
+
+                        preds = model.sample(**pipeline_args)
+
+                        if accelerator.is_main_process:
+                            # IMAGE saving
+                            if cfg.get("log_c", False):
+                                # ALWAYS in [0, 1]
+                                lp = model.encoders[0](cs[-1]).cpu()
+                            else:
+                                lp = (imgs.cpu() + 1) / 2
+
+                            lp = torch.nn.functional.interpolate(
+                                lp,
+                                size=(cfg.size, cfg.size),
+                                mode="bicubic",
+                                align_corners=False,
+                            )
+
+                            log_cond = TF.to_pil_image(einops.rearrange(lp, "b c h w -> c h (b w) "))
+                            log_cond = log_cond.convert("RGB")
+                            log_cond = np.asarray(log_cond)
+
+                            log_pred = np.concatenate(
+                                [np.asarray(img.resize((cfg.size, cfg.size))) for img in preds],
+                                axis=1,
+                            )
+
+                            for tracker in accelerator.trackers:
+                                if tracker.name == "tensorboard":
+                                    np_images = np.concatenate([log_cond, log_pred], axis=0)
+                                    tracker.writer.add_images(
+                                        "validation",
+                                        np_images,
+                                        global_step,
+                                        dataformats="HWC",
+                                    )
+                                    tracker.writer.add_scalar("lr", lr_scheduler.get_last_lr()[0], global_step)
+                                    tracker.writer.add_scalar("loss", loss.detach().item(), global_step)
+                                    tracker.writer.add_text(
+                                        "prompts",
+                                        "------------".join(val_prompts),
+                                        global_step,
+                                    )
+
+                except Exception as e:
+                    print("!!!!!!!!!!!!!!!!!!!")
+                    print("ERROR IN VALIDATION")
+                    print(e)
+                    print(traceback.format_exc())
+                    print("!!!!!!!!!!!!!!!!!!!")
+
+                finally:
+                    if accelerator.is_main_process:
+                        save_checkpoint(
+                            model.get_lora_state_dict(accelerator.unwrap_model(unet)),
+                            [accelerator.unwrap_model(m).state_dict() for m in mappers],
+                            None,
+                            output_path / f"checkpoint-{global_step}",
+                        )
+
+                    unet.train()
+                    map(lambda m: m.train(), mappers)
+                    map(lambda m: m.train(), encoders)
+
+            if stop_training:
+                break
+
+    accelerator.wait_for_everyone()
+    save_checkpoint(
+        model.get_lora_state_dict(accelerator.unwrap_model(unet)),
+        [accelerator.unwrap_model(m).state_dict() for m in mappers],
+        None,
+        output_path / f"checkpoint-{global_step}",
+    )
+
+
+if __name__ == "__main__":
+    main()
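
A minimal usage sketch follows; it is not part of the commit, and it assumes that the new train.py reads configs/train.yaml via @hydra.main, that the experiment group files above supply the required model, data, and lora settings, and that a sampling entry point built around configs/sample.yaml exists (its file name below is an assumption).

    # train the structure (depth) LoRA on SD 1.5
    python train.py experiment=train_struct_sd15

    # train the style LoRA on SD 1.5
    python train.py experiment=train_style_sd15

    # sample with a trained structure LoRA (hypothetical entry point name)
    python sample.py experiment=sample_struct_attn

Here experiment=... selects one of the new files under configs/experiment/, and keys such as tag or learning_rate can be overridden on the command line in the usual Hydra way.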