diff --git a/library/lumina_models.py b/library/lumina_models.py
index 365453c1c..d86a9cb2b 100644
--- a/library/lumina_models.py
+++ b/library/lumina_models.py
@@ -880,8 +880,8 @@ def __init__(
         self.n_heads = n_heads
 
         self.gradient_checkpointing = False
-        self.cpu_offload_checkpointing = False
-        self.blocks_to_swap = None
+        self.cpu_offload_checkpointing = False  # TODO: not yet supported
+        self.blocks_to_swap = None  # TODO: not yet supported
 
     @property
     def device(self):
@@ -982,8 +982,8 @@ def patchify_and_embed(
 
         l_effective_cap_len = cap_mask.sum(dim=1).tolist()
 
         encoder_seq_len = cap_mask.shape[1]
-
        image_seq_len = (height // self.patch_size) * (width // self.patch_size)
+
         seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
         max_seq_len = max(seq_lengths)
diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
index 172d09eac..4aa48e8b2 100644
--- a/library/lumina_train_util.py
+++ b/library/lumina_train_util.py
@@ -4,7 +4,7 @@
 import os
 import numpy as np
 import time
-from typing import Callable, Dict, List, Optional, Tuple, Any, Union
+from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
 
 import torch
 from torch import Tensor
@@ -32,6 +32,59 @@
 
 # region sample images
 
 
+def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]:
+    """
+    Group prompt dictionaries into batches with configurable batch size.
+
+    Args:
+        prompt_dicts (list): List of dictionaries containing prompt parameters.
+        batch_size (int, optional): Number of prompts per batch. Defaults to None.
+
+    Yields:
+        list[dict[str, str]]: Batch of prompts.
+    """
+    # Validate batch_size
+    if batch_size is not None:
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError("batch_size must be a positive integer or None")
+
+    # Group prompts by their parameters
+    batches = {}
+    for prompt_dict in prompt_dicts:
+        # Extract parameters
+        width = int(prompt_dict.get("width", 1024))
+        height = int(prompt_dict.get("height", 1024))
+        height = max(64, height - height % 8)  # round to divisible by 8
+        width = max(64, width - width % 8)  # round to divisible by 8
+        guidance_scale = float(prompt_dict.get("scale", 3.5))
+        sample_steps = int(prompt_dict.get("sample_steps", 38))
+        seed = prompt_dict.get("seed", None)
+        seed = int(seed) if seed is not None else None
+
+        # Create a key based on the parameters
+        key = (width, height, guidance_scale, seed, sample_steps)
+
+        # Add the prompt_dict to the corresponding batch
+        if key not in batches:
+            batches[key] = []
+        batches[key].append(prompt_dict)
+
+    # Yield each batch with its parameters
+    for key in batches:
+        prompts = batches[key]
+        if batch_size is None:
+            # Yield the entire group as a single batch
+            yield prompts
+        else:
+            # Split the group into batches of size `batch_size`
+            start = 0
+            while start < len(prompts):
+                end = start + batch_size
+                batch = prompts[start:end]
+                yield batch
+                start = end
+
+
 @torch.no_grad()
 def sample_images(
@@ -39,9 +92,9 @@
     epoch: int,
     global_step: int,
     nextdit: lumina_models.NextDiT,
-    vae: torch.nn.Module,
+    vae: AutoEncoder,
     gemma2_model: Gemma2Model,
-    sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
+    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
     prompt_replacement: Optional[Tuple[str, str]] = None,
     controlnet=None,
 ):
@@ -54,11 +107,13 @@
         epoch (int): Current epoch number.
         global_step (int): Current global step number.
         nextdit (lumina_models.NextDiT): The NextDiT model instance.
-        vae (torch.nn.Module): The VAE module.
+        vae (AutoEncoder): The VAE module.
         gemma2_model (Gemma2Model): The Gemma2 model instance.
-        sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample.
-        prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None.
-        controlnet:: ControlNet model
+        sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]):
+            Dictionary mapping each prompt to a tuple of encoded prompt hidden states, input IDs, and attention masks.
+        prompt_replacement (Optional[Tuple[str, str]], optional):
+            Tuple containing the prompt and negative prompt replacements. Defaults to None.
+        controlnet: ControlNet model, not yet supported
 
     Returns:
         None
@@ -110,9 +165,12 @@
         except Exception:
             pass
 
+    batch_size = args.sample_batch_size or args.train_batch_size or 1
+
     if distributed_state.num_processes <= 1:
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
-        for prompt_dict in prompts:
+        # TODO: batch prompts together with buckets of image sizes
+        for prompt_dicts in batchify(prompts, batch_size):
             sample_image_inference(
                 accelerator,
                 args,
@@ -120,7 +178,7 @@
                 gemma2_model,
                 vae,
                 save_dir,
-                prompt_dict,
+                prompt_dicts,
                 epoch,
                 global_step,
                 sample_prompts_gemma2_outputs,
@@ -135,7 +193,8 @@
             per_process_prompts.append(prompts[i :: distributed_state.num_processes])
 
         with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
-            for prompt_dict in prompt_dict_lists[0]:
+            # TODO: batch prompts together with buckets of image sizes
+            for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
                 sample_image_inference(
                     accelerator,
                     args,
@@ -143,7 +202,7 @@
                     gemma2_model,
                     vae,
                     save_dir,
-                    prompt_dict,
+                    prompt_dicts,
                    epoch,
                     global_step,
                     sample_prompts_gemma2_outputs,
@@ -166,10 +225,10 @@
     gemma2_model: Gemma2Model,
     vae: AutoEncoder,
     save_dir: str,
-    prompt_dict: Dict[str, str],
+    prompt_dicts: list[Dict[str, str]],
     epoch: int,
     global_step: int,
-    sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]],
+    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
     prompt_replacement: Optional[Tuple[str, str]] = None,
     controlnet=None,
 ):
@@ -192,43 +251,6 @@
     Returns:
         None
     """
-    assert isinstance(prompt_dict, dict)
-    # negative_prompt = prompt_dict.get("negative_prompt")
-    sample_steps = int(prompt_dict.get("sample_steps", 38))
-    width = int(prompt_dict.get("width", 1024))
-    height = int(prompt_dict.get("height", 1024))
-    guidance_scale = float(prompt_dict.get("scale", 3.5))
-    seed = prompt_dict.get("seed", None)
-    controlnet_image = prompt_dict.get("controlnet_image")
-    prompt: str = prompt_dict.get("prompt", "")
-    negative_prompt: str = prompt_dict.get("negative_prompt", "")
-    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
-
-    seed = int(seed) if seed is not None else None
-    assert seed is None or seed > 0, f"Invalid seed {seed}"
-
-    if prompt_replacement is not None:
-        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-        if negative_prompt is not None:
-            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
-    generator = torch.Generator(device=accelerator.device)
-    if seed is not None:
-        generator.manual_seed(seed)
-
-    # if negative_prompt is None:
-    #     negative_prompt = ""
-    height = max(64, height - height % 8)  # round to divisible by 8
-    width = max(64, width - width % 8)  # round to divisible by 8
-    logger.info(f"prompt: {prompt}")
-    logger.info(f"negative_prompt: {negative_prompt}")
-    logger.info(f"height: {height}")
-    logger.info(f"width: {width}")
-    logger.info(f"sample_steps: {sample_steps}")
-    logger.info(f"scale: {guidance_scale}")
-    # logger.info(f"sample_sampler: {sampler_name}")
-    if seed is not None:
-        logger.info(f"seed: {seed}")
 
     # encode prompts
     tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
@@ -237,33 +259,86 @@
     assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
     assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
 
-    system_prompt = args.system_prompt or ""
-
-    # Apply system prompt to prompts
-    prompt = system_prompt + prompt
-    negative_prompt = system_prompt + negative_prompt
-
-    # Get sample prompts from cache
-    if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
-        gemma2_conds = sample_prompts_gemma2_outputs[prompt]
-        logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
+    text_conds = []
 
-    if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
-        neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
-        logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
+    # assuming seed, width, height, sample steps, guidance are the same
+    width = int(prompt_dicts[0].get("width", 1024))
+    height = int(prompt_dicts[0].get("height", 1024))
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8
 
-    # Load sample prompts from Gemma 2
-    if gemma2_model is not None:
-        logger.info(f"Encoding prompt with Gemma2: {prompt}")
-        tokens_and_masks = tokenize_strategy.tokenize(prompt)
-        gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+    guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
+    sample_steps = int(prompt_dicts[0].get("sample_steps", 38))
+    seed = prompt_dicts[0].get("seed", None)
+    seed = int(seed) if seed is not None else None
+    assert seed is None or seed > 0, f"Invalid seed {seed}"
+    generator = torch.Generator(device=accelerator.device)
+    if seed is not None:
+        generator.manual_seed(seed)
 
-        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
-        neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+    for prompt_dict in prompt_dicts:
+        controlnet_image = prompt_dict.get("controlnet_image")
+        prompt: str = prompt_dict.get("prompt", "")
+        negative_prompt = prompt_dict.get("negative_prompt", "")
+        # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+        if prompt_replacement is not None:
+            prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+            if negative_prompt is not None:
+                negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+
+        if negative_prompt is None:
+            negative_prompt = ""
+        logger.info(f"prompt: {prompt}")
+        logger.info(f"negative_prompt: {negative_prompt}")
+        logger.info(f"height: {height}")
+        logger.info(f"width: {width}")
+        logger.info(f"sample_steps: {sample_steps}")
+        logger.info(f"scale: {guidance_scale}")
+        # logger.info(f"sample_sampler: {sampler_name}")
+
+        system_prompt = args.system_prompt or ""
+
+        # Apply system prompt to prompts
+        prompt = system_prompt + prompt
+        negative_prompt = system_prompt + negative_prompt
+
+        # Get sample prompts from cache
+        if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
+            gemma2_conds = sample_prompts_gemma2_outputs[prompt]
+            logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
+
+        if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
+            neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
+            logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
+
+        # Load sample prompts from Gemma 2
+        if gemma2_model is not None:
+            logger.info(f"Encoding prompt with Gemma2: {prompt}")
+            tokens_and_masks = tokenize_strategy.tokenize(prompt)
+            gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+
+            tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
+            neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+
+        # Unpack Gemma2 outputs
+        gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
+        neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
+
+        text_conds.append(
+            (
+                gemma2_hidden_states.squeeze(0),
+                gemma2_attn_mask.squeeze(0),
+                neg_gemma2_hidden_states.squeeze(0),
+                neg_gemma2_attn_mask.squeeze(0),
+            )
+        )
-    # Unpack Gemma2 outputs
-    gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
-    neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds
+    # Stack conditioning
+    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device)
+    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device)
+    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device)
+    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device)
 
     # sample image
     weight_dtype = vae.dtype  # TOFO give dtype as argument
@@ -279,6 +354,7 @@
         dtype=weight_dtype,
         generator=generator,
     )
+    noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)
 
     scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
@@ -294,10 +370,10 @@
             scheduler,
             nextdit,
             noise,
-            gemma2_hidden_states,
-            gemma2_attn_mask.to(accelerator.device),
-            neg_gemma2_hidden_states,
-            neg_gemma2_attn_mask.to(accelerator.device),
+            cond_hidden_states,
+            cond_attn_masks,
+            uncond_hidden_states,
+            uncond_attn_masks,
             timesteps=timesteps,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
@@ -307,33 +383,43 @@
     clean_memory_on_device(accelerator.device)
 
     org_vae_device = vae.device  # will be on cpu
     vae.to(accelerator.device)  # distributed_state.device is same as accelerator.device
-    with accelerator.autocast():
-        x = vae.decode((x / vae.scale_factor) + vae.shift_factor)
-    vae.to(org_vae_device)
-    clean_memory_on_device(accelerator.device)
+    for img, prompt_dict in zip(x, prompt_dicts):
+
+        img = (img / vae.scale_factor) + vae.shift_factor
+
+        with accelerator.autocast():
+            # Add a single batch image for the VAE to decode
+            img = vae.decode(img.unsqueeze(0))
 
-    x = x.clamp(-1, 1)
-    x = x.permute(0, 2, 3, 1)
-    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
+        img = img.clamp(-1, 1)
+        img = img.permute(0, 2, 3, 1)  # B, H, W, C
+        # Scale images back to 0 to 255
+        img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8)
 
-    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
-    # but adding 'enum' to the filename should be enough
+        # Get single image
+        image = Image.fromarray(img[0])
 
-    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
-    seed_suffix = "" if seed is None else f"_{seed}"
-    i: int = int(prompt_dict.get("enum", 0))
-    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
-    image.save(os.path.join(save_dir, img_filename))
+        # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+        # but adding 'enum' to the filename should be enough
 
-    # send images to wandb if enabled
-    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
-        wandb_tracker = accelerator.get_tracker("wandb")
+        ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+        num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
+        seed_suffix = "" if seed is None else f"_{seed}"
+        i: int = int(prompt_dict.get("enum", 0))
+        img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+        image.save(os.path.join(save_dir, img_filename))
 
-        import wandb
+        # send images to wandb if enabled
+        if "wandb" in [tracker.name for tracker in accelerator.trackers]:
+            wandb_tracker = accelerator.get_tracker("wandb")
 
-        # not to commit images to avoid inconsistency between training and logging steps
-        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption
+            import wandb
+
+            # not to commit images to avoid inconsistency between training and logging steps
+            wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption
+
+    vae.to(org_vae_device)
+    clean_memory_on_device(accelerator.device)
 
 
 def time_shift(mu: float, sigma: float, t: torch.Tensor):
@@ -879,16 +965,22 @@
         "--discrete_flow_shift",
         type=float,
        default=6.0,
-        help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。",
+        help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0",
     )
     parser.add_argument(
         "--use_flash_attn",
         action="store_true",
-        help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。",
+        help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
     )
     parser.add_argument(
         "--system_prompt",
         type=str,
         default="",
-        help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。",
+        help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=None,
+        help="Batch size to use for sampling, defaults to --train_batch_size value. Sample batches are bucketed by width, height, guidance scale, seed, and sample steps / サンプリングに使用するバッチサイズ。デフォルトは --train_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シード、サンプルステップ数によってバケット化されます",
     )
diff --git a/train_network.py b/train_network.py
index 2cf11af73..07de30b3b 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1242,6 +1242,7 @@ def remove_model(old_ckpt_name):
         # For --sample_at_first
         optimizer_eval_fn()
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+        progress_bar.unpause()  # Reset progress bar timer to exclude time spent sampling images
         optimizer_train_fn()
         is_tracking = len(accelerator.trackers) > 0
         if is_tracking:
@@ -1344,6 +1345,7 @@ def remove_model(old_ckpt_name):
                     self.sample_images(
                         accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
                     )
+                    progress_bar.unpause()
 
                     # 指定ステップごとにモデルを保存
                     if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1531,6 +1533,7 @@ def remove_model(old_ckpt_name):
                 train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
 
             self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+            progress_bar.unpause()
             optimizer_train_fn()
 
         # end of epoch
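A minimal usage sketch of the new batchify bucketing follows (illustrative only, not part of the diff; the prompt dictionaries are hypothetical examples of what would be parsed from a sample-prompts file).

# Assumes this is run from the sd-scripts repository root with this PR applied.
from library.lumina_train_util import batchify

prompts = [
    {"prompt": "a cat", "width": 1024, "height": 1024, "seed": 42},
    {"prompt": "a dog", "width": 1024, "height": 1024, "seed": 42},
    {"prompt": "a boat", "width": 512, "height": 768, "seed": 42},
]

# Prompts sharing (width, height, scale, seed, sample_steps) land in the same bucket,
# so a batch never mixes image sizes; batch_size then splits each bucket.
for batch in batchify(prompts, batch_size=2):
    print([p["prompt"] for p in batch])
# -> ['a cat', 'a dog']
# -> ['a boat']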