diff --git a/library/config_util.py b/library/config_util.py index a2e07dc6c..ca14dfb13 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,6 +75,7 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 + system_prompt: Optional[str] = None @dataclass @@ -106,6 +107,7 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 + system_prompt: Optional[str] = None @dataclass @@ -196,6 +198,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, + "system_prompt": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -241,6 +244,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "system_prompt": str, } # options handled by argparse but not handled by user config @@ -526,6 +530,7 @@ def print_info(_datasets, dataset_type: str): batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} + system_prompt: {dataset.system_prompt} """) if dataset.enable_bucket: @@ -559,6 +564,7 @@ def print_info(_datasets, dataset_type: str): token_warmup_step: {subset.token_warmup_step}, alpha_mask: {subset.alpha_mask} custom_attributes: {subset.custom_attributes} + system_prompt: {subset.system_prompt} """), " ") if is_dreambooth: diff --git a/library/lumina_models.py b/library/lumina_models.py new file mode 100644 index 000000000..1a441a69d --- /dev/null +++ b/library/lumina_models.py @@ -0,0 +1,1250 @@ +# Copyright Alpha VLLM/Lumina Image 2.0 and contributors +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math +from typing import List, Optional, Tuple +from dataclasses import dataclass + +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import torch.nn as nn +import torch.nn.functional as F + +try: + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + # flash_attn may not be available but it is not required + pass + +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except: + import warnings + + warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + + ############################################################################# + # RMSNorm # + ############################################################################# + + class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x) -> Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor): + """ + Apply RMSNorm to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. 
+ """ + x_dtype = x.dtype + # To handle float8 we need to convert the tensor to float + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) + + + +@dataclass +class LuminaParams: + """Parameters for Lumina model configuration""" + + patch_size: int = 2 + in_channels: int = 4 + dim: int = 4096 + n_layers: int = 30 + n_refiner_layers: int = 2 + n_heads: int = 24 + n_kv_heads: int = 8 + multiple_of: int = 256 + axes_dims: List[int] = None + axes_lens: List[int] = None + qk_norm: bool = False + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + scaling_factor: float = 1.0 + cap_feat_dim: int = 32 + + def __post_init__(self): + if self.axes_dims is None: + self.axes_dims = [36, 36, 36] + if self.axes_lens is None: + self.axes_lens = [300, 512, 512] + + @classmethod + def get_2b_config(cls) -> "LuminaParams": + """Returns the configuration for the 2B parameter model""" + return cls( + patch_size=2, + in_channels=16, # VAE channels + dim=2304, + n_layers=26, + n_heads=24, + n_kv_heads=8, + axes_dims=[32, 32, 32], + axes_lens=[300, 512, 512], + qk_norm=True, + cap_feat_dim=2304, # Gemma 2 hidden_size + ) + + @classmethod + def get_7b_config(cls) -> "LuminaParams": + """Returns the configuration for the 7B parameter model""" + return cls( + patch_size=2, + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[64, 64, 64], + axes_lens=[300, 512, 512], + ) + + +class GradientCheckpointMixin(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = False + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + + +def modulate(x, scale): + return x * (1 + scale.unsqueeze(1)) + + +############################################################################# +# Embedding Layers for Timesteps and Class Labels # +############################################################################# + + +class TimestepEmbedder(GradientCheckpointMixin): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + hidden_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + hidden_size, + hidden_size, + bias=True, + ), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.zeros_(self.mlp[0].bias) + nn.init.normal_(self.mlp[2].weight, std=0.02) + nn.init.zeros_(self.mlp[2].bias) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def _forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + return t_emb + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +############################################################################# +# Core NextDiT Model # +############################################################################# + + +class JointAttention(nn.Module): + """Multi-head attention module.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + use_flash_attn=False, + ): + """ + Initialize the Attention module. + + Args: + dim (int): Number of input dimensions. + n_heads (int): Number of heads. + n_kv_heads (Optional[int]): Number of kv heads, if using GQA. + qk_norm (bool): Whether to use normalization for queries and keys. + + """ + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.qkv = nn.Linear( + dim, + (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.qkv.weight) + + self.out = nn.Linear( + n_heads * self.head_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.out.weight) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + else: + self.q_norm = self.k_norm = nn.Identity() + + self.use_flash_attn = use_flash_attn + + # self.attention_processor = xformers.ops.memory_efficient_attention + self.attention_processor = F.scaled_dot_product_attention + + def set_attention_processor(self, attention_processor): + self.attention_processor = attention_processor + + def forward( + self, + x: Tensor, + x_mask: Tensor, + freqs_cis: Tensor, + ) -> Tensor: + """ + Args: + x: + x_mask: + freqs_cis: + """ + bsz, seqlen, _ = x.shape + dtype = x.dtype + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = apply_rope(xq, freqs_cis=freqs_cis) + xk = apply_rope(xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + softmax_scale = math.sqrt(1 / self.head_dim) + + if self.use_flash_attn: + output = 
self.flash_attn(xq, xk, xv, x_mask, softmax_scale) + else: + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = ( + self.attention_processor( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), + scale=softmax_scale, + ) + .permute(0, 2, 1, 3) + .to(dtype) + ) + + output = output.flatten(-2) + return self.out(output) + + # copied from huggingface modeling_llama.py + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def flash_attn( + self, + q: Tensor, + k: Tensor, + v: Tensor, + x_mask: Tensor, + softmax_scale, + ) -> Tensor: + bsz, seqlen, _, _ = q.shape + + try: + # begin var_len flash attn + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(q, k, v, x_mask, seqlen) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) + # end var_len_flash_attn + + return output + except NameError as e: + raise RuntimeError( + f"Could not load flash attention. Please install flash_attn. 
/ フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}" + ) + + +def apply_rope( + x_in: torch.Tensor, + freqs_cis: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency + tensor. + + This function applies rotary embeddings to the given query 'xq' and + key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The + input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors + contain rotary embeddings and are returned as real tensors. + + Args: + x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex + exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor + and key tensor with rotary embeddings. + """ + with torch.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + + return x_out.type_as(x_in) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + """ + super().__init__() + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w1.weight) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w2.weight) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w3.weight) + + # @torch.compile + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class JointTransformerBlock(GradientCheckpointMixin): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + multiple_of: int, + ffn_dim_multiplier: Optional[float], + norm_eps: float, + qk_norm: bool, + modulation=True, + use_flash_attn=False, + ) -> None: + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + dim (int): Embedding dimension of the input features. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): Number of multiple of the hidden dimension. + ffn_dim_multiplier (Optional[float]): Dimension multiplier for the + feedforward layer. + norm_eps (float): Epsilon value for normalization. + qk_norm (bool): Whether to use normalization for queries and keys. + modulation (bool): Whether to use modulation for the attention + layer. 
+ """ + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(dim, 1024), + 4 * dim, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def _forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + pe: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (Tensor): Input tensor. + pe (Tensor): Rope position embedding. + + Returns: + Tensor: Output tensor after applying attention and + feedforward layers. + + """ + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + self.attention( + modulate(self.attention_norm1(x), scale_msa), + x_mask, + pe, + ) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + self.feed_forward( + modulate(self.ffn_norm1(x), scale_mlp), + ) + ) + else: + assert adaln_input is None + x = x + self.attention_norm2( + self.attention( + self.attention_norm1(x), + x_mask, + pe, + ) + ) + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + return x + + +class FinalLayer(GradientCheckpointMixin): + """ + The final layer of NextDiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + """ + Initialize the FinalLayer. + + Args: + hidden_size (int): Hidden size of the input features. + patch_size (int): Patch size of the input features. + out_channels (int): Number of output channels. 
+ """ + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + ) + self.linear = nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(hidden_size, 1024), + hidden_size, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + scale = self.adaLN_modulation(c) + x = modulate(self.norm_final(x), scale) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 10000.0, + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], + ): + super().__init__() + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + + def __call__(self, ids: torch.Tensor): + device = ids.device + self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] + result = [] + for i in range(len(self.axes_dims)): + freqs = self.freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1) + + +class NextDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 4, + dim: int = 4096, + n_layers: int = 32, + n_refiner_layers: int = 2, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + qk_norm: bool = False, + cap_feat_dim: int = 5120, + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], + use_flash_attn=False, + ) -> None: + """ + Initialize the NextDiT model. + + Args: + patch_size (int): Patch size of the input features. + in_channels (int): Number of input channels. + dim (int): Hidden size of the input features. + n_layers (int): Number of Transformer layers. + n_refiner_layers (int): Number of refiner layers. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): Multiple of the hidden size. + ffn_dim_multiplier (Optional[float]): Dimension multiplier for the + feedforward layer. + norm_eps (float): Epsilon value for normalization. + qk_norm (bool): Whether to use query key normalization. + cap_feat_dim (int): Dimension of the caption features. + axes_dims (List[int]): List of dimensions for the axes. + axes_lens (List[int]): List of lengths for the axes. 
+ + Returns: + None + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + + self.t_embedder = TimestepEmbedder(min(dim, 1024)) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + dim, + bias=True, + ), + ) + + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + use_flash_attn=use_flash_attn, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=dim, + bias=True, + ) + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + self.noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + use_flash_attn=use_flash_attn, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) + # nn.init.zeros_(self.cap_embedder[1].weight) + nn.init.zeros_(self.cap_embedder[1].bias) + + self.layers = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + use_flash_attn=use_flash_attn, + ) + for layer_id in range(n_layers) + ] + ) + self.norm_final = RMSNorm(dim, eps=norm_eps) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels) + + assert (dim // n_heads) == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens) + self.dim = dim + self.n_heads = n_heads + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False # TODO: not yet supported + self.blocks_to_swap = None # TODO: not yet supported + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.t_embedder.enable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + self.final_layer.enable_gradient_checkpointing() + + print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.t_embedder.disable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.disable_gradient_checkpointing() + + self.final_layer.disable_gradient_checkpointing() + + print("Lumina: Gradient checkpointing disabled.") + + def unpatchify( + self, + x: Tensor, + width: int, + height: int, + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> Tensor: + """ + Unpatchify the input tensor and embed the caption features. + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + + Args: + x (Tensor): Input tensor. + width (int): Width of the input tensor. + height (int): Height of the input tensor. + encoder_seq_lengths (List[int]): List of encoder sequence lengths. 
+ seq_lengths (List[int]): List of sequence lengths + + Returns: + output: (N, C, H, W) + """ + pH = pW = self.patch_size + + output = [] + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + output.append( + x[i][encoder_seq_len:seq_len] + .view(height // pH, width // pW, pH, pW, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + output = torch.stack(output, dim=0) + + return output + + def patchify_and_embed( + self, + x: Tensor, + cap_feats: Tensor, + cap_mask: Tensor, + t: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: + """ + Patchify and embed the input image and caption features. + + Args: + x: (N, C, H, W) image latents + cap_feats: (N, C, D) caption features + cap_mask: (N, C, D) caption attention mask + t: (N), T timesteps + + Returns: + Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: + + return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths + """ + bsz, channels, height, width = x.shape + pH = pW = self.patch_size + device = x.device + + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + encoder_seq_len = cap_mask.shape[1] + image_seq_len = (height // self.patch_size) * (width // self.patch_size) + + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] + max_seq_len = max(seq_lengths) + + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + H_tokens, W_tokens = height // pH, width // pW + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len:seq_len, 0] = cap_len + + row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + + position_ids[i, cap_len:seq_len, 1] = row_ids + position_ids[i, cap_len:seq_len, 2] = col_ids + + # Get combined rotary embeddings + freqs_cis = self.rope_embedder(position_ids) + + # Create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + bsz, + encoder_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + img_freqs_cis = torch.zeros( + bsz, + image_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len] + + # Refine caption context + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) + + x = self.x_embedder(x) + + # Refine image context + for layer in self.noise_refiner: + x = layer(x, x_mask, img_freqs_cis, t) + + joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype) + attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len] + joint_hidden_states[i, cap_len:seq_len] = x[i] + + x = joint_hidden_states + + return x, attention_mask, 
freqs_cis, l_effective_cap_len, seq_lengths + + def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor: + """ + Forward pass of NextDiT. + Args: + x: (N, C, H, W) image latents + t: (N,) tensor of diffusion timesteps + cap_feats: (N, L, D) caption features + cap_mask: (N, L) caption attention mask + + Returns: + x: (N, C, H, W) denoised latents + """ + _, _, height, width = x.shape # B, C, H, W + t = self.t_embedder(t) # (N, D) + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + + x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t) + + for layer in self.layers: + x = layer(x, mask, freqs_cis, t) + + x = self.final_layer(x, t) + x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths) + + return x + + def forward_with_cfg( + self, + x: Tensor, + t: Tensor, + cap_feats: Tensor, + cap_mask: Tensor, + cfg_scale: float, + cfg_trunc: float = 0.25, + renorm_cfg: float = 1.0, + ): + """ + Forward pass of NextDiT, but also batches the unconditional forward pass + for classifier-free guidance. + """ + # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + if t[0] < cfg_trunc: + combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128] + assert ( + cap_mask.shape[0] == combined.shape[0] + ), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}" + model_out = self.forward(x, t, cap_feats, cap_mask) # [2, 16, 128, 128] + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if float(renorm_cfg) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True) + max_new_norm = ori_pos_norm * float(renorm_cfg) + new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True) + if new_pos_norm >= max_new_norm: + half_eps = half_eps * (max_new_norm / new_pos_norm) + else: + combined = half + model_out = self.forward( + combined, + t[: len(x) // 2], + cap_feats[: len(x) // 2], + cap_mask[: len(x) // 2], + ) + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + half_eps = eps + + output = torch.cat([half_eps, half_eps], dim=0) + return output + + @staticmethod + def precompute_freqs_cis( + dim: List[int], + end: List[int], + theta: float = 10000.0, + ) -> List[Tensor]: + """ + Precompute the frequency tensor for complex exponentials (cis) with + given dimensions. + + This function calculates a frequency tensor with complex exponentials + using the given dimension 'dim' and the end index 'end'. The 'theta' + parameter scales the frequencies. The returned tensor contains complex + values in complex64 data type. + + Args: + dim (list): Dimension of the frequency tensor. + end (list): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. + Defaults to 10000.0. + + Returns: + List[torch.Tensor]: Precomputed frequency tensor with complex + exponentials. 
+ """ + freqs_cis = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + for i, (d, e) in enumerate(zip(dim, end)): + pos = torch.arange(e, dtype=freqs_dtype, device="cpu") + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d)) + freqs = torch.outer(pos, freqs) + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2] + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def parameter_count(self) -> int: + total_params = 0 + + def _recursive_count_params(module): + nonlocal total_params + for param in module.parameters(recurse=False): + total_params += param.numel() + for submodule in module.children(): + _recursive_count_params(submodule) + + _recursive_count_params(self) + return total_params + + def get_fsdp_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + +############################################################################# +# NextDiT Configs # +############################################################################# + + +def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs): + if params is None: + params = LuminaParams.get_2b_config() + + return NextDiT( + patch_size=params.patch_size, + in_channels=params.in_channels, + dim=params.dim, + n_layers=params.n_layers, + n_heads=params.n_heads, + n_kv_heads=params.n_kv_heads, + axes_dims=params.axes_dims, + axes_lens=params.axes_lens, + qk_norm=params.qk_norm, + ffn_dim_multiplier=params.ffn_dim_multiplier, + norm_eps=params.norm_eps, + cap_feat_dim=params.cap_feat_dim, + **kwargs, + ) + + +def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2592, + n_layers=30, + n_heads=24, + n_kv_heads=8, + axes_dims=[36, 36, 36], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2880, + n_layers=32, + n_heads=24, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=3840, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py new file mode 100644 index 000000000..20df7eef6 --- /dev/null +++ b/library/lumina_train_util.py @@ -0,0 +1,983 @@ +import inspect +import argparse +import math +import os +import numpy as np +import time +from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator + +import torch +from torch import Tensor +from accelerate import Accelerator, PartialState +from transformers import Gemma2Model +from tqdm import tqdm +from PIL import Image +from safetensors.torch import save_file + +from library import lumina_models, strategy_base, strategy_lumina, train_util +from library.flux_models import AutoEncoder +from library.device_utils import init_ipex, clean_memory_on_device +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +init_ipex() + +from .utils import setup_logging, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sample images + + +def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]: + """ + Group prompt dictionaries into batches with configurable batch 
size. + + Args: + prompt_dicts (list): List of dictionaries containing prompt parameters. + batch_size (int, optional): Number of prompts per batch. Defaults to None. + + Yields: + list[dict[str, str]]: Batch of prompts. + """ + # Validate batch_size + if batch_size is not None: + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size must be a positive integer or None") + + # Group prompts by their parameters + batches = {} + for prompt_dict in prompt_dicts: + # Extract parameters + width = int(prompt_dict.get("width", 1024)) + height = int(prompt_dict.get("height", 1024)) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + guidance_scale = float(prompt_dict.get("scale", 3.5)) + sample_steps = int(prompt_dict.get("sample_steps", 38)) + cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0)) + seed = prompt_dict.get("seed", None) + seed = int(seed) if seed is not None else None + + # Create a key based on the parameters + key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg) + + # Add the prompt_dict to the corresponding batch + if key not in batches: + batches[key] = [] + batches[key].append(prompt_dict) + + # Yield each batch with its parameters + for key in batches: + prompts = batches[key] + if batch_size is None: + # Yield the entire group as a single batch + yield prompts + else: + # Split the group into batches of size `batch_size` + start = 0 + while start < len(prompts): + end = start + batch_size + batch = prompts[start:end] + yield batch + start = end + + +@torch.no_grad() +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch: int, + global_step: int, + nextdit: lumina_models.NextDiT, + vae: AutoEncoder, + gemma2_model: Gemma2Model, + sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], + prompt_replacement: Optional[Tuple[str, str]] = None, + controlnet=None, +): + """ + Generate sample images using the NextDiT model. + + Args: + accelerator (Accelerator): Accelerator instance. + args (argparse.Namespace): Command-line arguments. + epoch (int): Current epoch number. + global_step (int): Current global step number. + nextdit (lumina_models.NextDiT): The NextDiT model instance. + vae (AutoEncoder): The VAE module. + gemma2_model (Gemma2Model): The Gemma2 model instance. + sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): + Dictionary of tuples containing the encoded prompts, text masks, and timestep for each sample. + prompt_replacement (Optional[Tuple[str, str]], optional): + Tuple containing the prompt and negative prompt replacements. Defaults to None. + controlnet (): ControlNet model, not yet supported + + Returns: + None + """ + if global_step == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + assert ( + args.sample_prompts is not None + ), "No sample prompts found. 
Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください" + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap nextdit and gemma2_model + nextdit = accelerator.unwrap_model(nextdit) + if gemma2_model is not None: + gemma2_model = accelerator.unwrap_model(gemma2_model) + # if controlnet is not None: + # controlnet = accelerator.unwrap_model(controlnet) + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + batch_size = args.sample_batch_size or args.train_batch_size or 1 + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + # TODO: batch prompts together with buckets of image sizes + for prompt_dicts in batchify(prompts, batch_size): + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + vae, + save_dir, + prompt_dicts, + epoch, + global_step, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. 
+ per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + # TODO: batch prompts together with buckets of image sizes + for prompt_dicts in batchify(prompt_dict_lists[0], batch_size): + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + vae, + save_dir, + prompt_dicts, + epoch, + global_step, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +@torch.no_grad() +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + nextdit: lumina_models.NextDiT, + gemma2_model: Gemma2Model, + vae: AutoEncoder, + save_dir: str, + prompt_dicts: list[Dict[str, str]], + epoch: int, + global_step: int, + sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], + prompt_replacement: Optional[Tuple[str, str]] = None, + controlnet=None, +): + """ + Generates sample images + + Args: + accelerator (Accelerator): Accelerator object + args (argparse.Namespace): Arguments object + nextdit (lumina_models.NextDiT): NextDiT model + gemma2_model (Gemma2Model): Gemma2 model + vae (AutoEncoder): VAE model + save_dir (str): Directory to save images + prompt_dict (Dict[str, str]): Prompt dictionary + epoch (int): Epoch number + steps (int): Number of steps to run + sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing Gemma 2 outputs + prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None. 
+ + Returns: + None + """ + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) + assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) + + text_conds = [] + + # assuming seed, width, height, sample steps, guidance are the same + width = int(prompt_dicts[0].get("width", 1024)) + height = int(prompt_dicts[0].get("height", 1024)) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + + guidance_scale = float(prompt_dicts[0].get("scale", 3.5)) + cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0)) + sample_steps = int(prompt_dicts[0].get("sample_steps", 36)) + seed = prompt_dicts[0].get("seed", None) + seed = int(seed) if seed is not None else None + assert seed is None or seed > 0, f"Invalid seed {seed}" + generator = torch.Generator(device=accelerator.device) + if seed is not None: + generator.manual_seed(seed) + + for prompt_dict in prompt_dicts: + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + negative_prompt = prompt_dict.get("negative_prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if negative_prompt is None: + negative_prompt = "" + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {guidance_scale}") + logger.info(f"trunc: {cfg_trunc_ratio}") + logger.info(f"renorm: {renorm_cfg}") + # logger.info(f"sample_sampler: {sampler_name}") + + system_prompt = args.system_prompt or "" + + # Apply system prompt to prompts + prompt = system_prompt + prompt + negative_prompt = system_prompt + negative_prompt + + # Get sample prompts from cache + if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: + gemma2_conds = sample_prompts_gemma2_outputs[prompt] + logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") + + if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: + neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] + logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + + # Load sample prompts from Gemma 2 + if gemma2_model is not None: + tokens_and_masks = tokenize_strategy.tokenize(prompt) + gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + + # Unpack Gemma2 outputs + gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds + neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds + + text_conds.append( + ( + gemma2_hidden_states.squeeze(0), + gemma2_attn_mask.squeeze(0), + neg_gemma2_hidden_states.squeeze(0), + neg_gemma2_attn_mask.squeeze(0), + ) + ) + + # Stack conditioning + 
cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device) + cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device) + uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device) + uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device) + + # sample image + weight_dtype = vae.dtype # TOFO give dtype as argument + latent_height = height // 8 + latent_width = width // 8 + latent_channels = 16 + noise = torch.randn( + 1, + latent_channels, + latent_height, + latent_width, + device=accelerator.device, + dtype=weight_dtype, + generator=generator, + ) + noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1) + + scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0) + timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) + + # if controlnet_image is not None: + # controlnet_image = Image.open(controlnet_image).convert("RGB") + # controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + + with accelerator.autocast(): + x = denoise( + scheduler, + nextdit, + noise, + cond_hidden_states, + cond_attn_masks, + uncond_hidden_states, + uncond_attn_masks, + timesteps=timesteps, + guidance_scale=guidance_scale, + cfg_trunc_ratio=cfg_trunc_ratio, + renorm_cfg=renorm_cfg, + ) + + # Latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) # distributed_state.device is same as accelerator.device + for img, prompt_dict in zip(x, prompt_dicts): + + img = (img / vae.scale_factor) + vae.shift_factor + + with accelerator.autocast(): + # Add a single batch image for the VAE to decode + img = vae.decode(img.unsqueeze(0)) + + img = img.clamp(-1, 1) + img = img.permute(0, 2, 3, 1) # B, H, W, C + # Scale images back to 0 to 255 + img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8) + + # Get single image + image = Image.fromarray(img[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = int(prompt_dict.get("enum", 0)) + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + # the following implementation was original for t=0: clean / t=1: noise + # Since we adopt the reverse, the 1-t operations are needed + t = 1 - t + t = math.exp(mu) / 
(math.exp(mu) + (1 / t - 1) ** sigma) + t = 1 - t + return t + + +def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]: + """ + Get linear function + + Args: + image_seq_len, + x1 base_seq_len: int = 256, + y2 max_seq_len: int = 4096, + y1 base_shift: float = 0.5, + y2 max_shift: float = 1.15, + + Return: + Callable[[float], float]: linear function + """ + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + """ + Get timesteps schedule + + Args: + num_steps (int): Number of steps in the schedule. + image_seq_len (int): Sequence length of the image. + base_shift (float, optional): Base shift value. Defaults to 0.5. + max_shift (float, optional): Maximum shift value. Defaults to 1.15. + shift (bool, optional): Whether to shift the schedule. Defaults to True. + + Return: + List[float]: timesteps schedule + """ + timesteps = torch.linspace(1, 1 / num_steps, num_steps) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +) -> Tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def denoise( + scheduler, + model: lumina_models.NextDiT, + img: Tensor, + txt: Tensor, + txt_mask: Tensor, + neg_txt: Tensor, + neg_txt_mask: Tensor, + timesteps: Union[List[float], torch.Tensor], + guidance_scale: float = 4.0, + cfg_trunc_ratio: float = 0.25, + renorm_cfg: float = 1.0, +): + """ + Denoise an image using the NextDiT model. + + Args: + scheduler (): + Noise scheduler + model (lumina_models.NextDiT): The NextDiT model instance. + img (Tensor): + The input image latent tensor. + txt (Tensor): + The input text tensor. + txt_mask (Tensor): + The input text mask tensor. + neg_txt (Tensor): + The negative input txt tensor + neg_txt_mask (Tensor): + The negative input text mask tensor. + timesteps (List[Union[float, torch.FloatTensor]]): + A list of timesteps for the denoising process. + guidance_scale (float, optional): + The guidance scale for the denoising process. Defaults to 4.0. + cfg_trunc_ratio (float, optional): + The ratio of the timestep interval to apply normalization-based guidance scale. + renorm_cfg (float, optional): + The factor to limit the maximum norm after guidance. Default: 1.0 + Returns: + img (Tensor): Denoised latent tensor + """ + + for i, t in enumerate(tqdm(timesteps)): + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - t / scheduler.config.num_train_timesteps + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device) + + noise_pred_cond = model( + img, + current_timestep, + cap_feats=txt, # Gemma2的hidden states作为caption features + cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + + # compute whether to apply classifier-free guidance based on current timestep + if current_timestep[0] < cfg_trunc_ratio: + noise_pred_uncond = model( + img, + current_timestep, + cap_feats=neg_txt, # Gemma2的hidden states作为caption features + cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + # apply normalization after classifier-free guidance + if float(renorm_cfg) > 0.0: + cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True) + max_new_norm = cond_norm * float(renorm_cfg) + noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True) + if noise_norm >= max_new_norm: + noise_pred = noise_pred * (max_new_norm / noise_norm) + else: + noise_pred = noise_pred_cond + + img_dtype = img.dtype + + if img.dtype != img_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + img = img.to(img_dtype) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = -noise_pred + img = scheduler.step(noise_pred, t, img, return_dict=False)[0] + + return img + + +# endregion + + +# region train +def get_sigmas( + noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32 +) -> Tensor: + """ + Get sigmas for timesteps + + Args: + noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance. + timesteps (Tensor): A tensor of timesteps for the denoising process. + device (torch.device): The device on which the tensors are stored. + n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4. + dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32. + + Returns: + sigmas (Tensor): The sigmas tensor. + """ + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """ + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme (str): The weighting scheme to use. + batch_size (int): The batch size for the sampling process. + logit_mean (float, optional): The mean of the logit distribution. Defaults to None. + logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None. + mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None. + + Returns: + u (Tensor): The sampled timesteps. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor: + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme (str): The weighting scheme to use. + sigmas (Tensor, optional): The sigmas tensor. Defaults to None. + + Returns: + u (Tensor): The sampled timesteps. 
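# A minimal sketch of how get_sigmas, compute_density_for_timestep_sampling and
# compute_loss_weighting_for_sd3 fit together for weighted flow-matching training
# (illustrative only; `noise_scheduler`, `latents` and `device` are assumed to exist,
# and the scheme names mirror the branches implemented in these helpers):
u = compute_density_for_timestep_sampling("logit_normal", latents.shape[0], logit_mean=0.0, logit_std=1.0, mode_scale=None)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim)
weighting = compute_loss_weighting_for_sd3("cosmap", sigmas=sigmas)  # per-sample loss scale before reduction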
+ """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]: + """ + Get noisy model input and timesteps. + + Args: + args (argparse.Namespace): Arguments. + noise_scheduler (noise_scheduler): Noise scheduler. + latents (Tensor): Latents. + noise (Tensor): Latent noise. + device (torch.device): Device. + dtype (torch.dtype): Data type + + Return: + Tuple[Tensor, Tensor, Tensor]: + noisy model input + timesteps + sigmas + """ + bsz, _, h, w = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "nextdit_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + + +def apply_model_prediction_type( + args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor +) -> Tuple[Tensor, Optional[Tensor]]: + """ + Apply model prediction type to the model prediction and the sigmas. + + Args: + args (argparse.Namespace): Arguments. + model_pred (Tensor): Model prediction. + noisy_model_input (Tensor): Noisy model input. + sigmas (Tensor): Sigmas. 
+ + Return: + Tuple[Tensor, Optional[Tensor]]: + """ + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + lumina: lumina_models.NextDiT, + sai_metadata: Dict[str, Any], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + """ + Save the model to the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + lumina (lumina_models.NextDiT): NextDIT model. + sai_metadata (Optional[dict]): Metadata for the SAI model. + save_dtype (Optional[torch.dtype]): Data + + Return: + None + """ + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", lumina.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_lumina_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2" + ) + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_lumina_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator: Accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + lumina: lumina_models.NextDiT, +): + """ + Save the model to the checkpoint path. + + Args: + args (argparse.Namespace): Arguments. + save_dtype (torch.dtype): Data type. + epoch (int): Epoch. + global_step (int): Global step. + lumina (lumina_models.NextDiT): NextDIT model. 
+ + Return: + None + """ + + def sd_saver(ckpt_file: str, epoch_no: int, global_step: int): + sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_lumina_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--gemma2", + type=str, + help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--gemma2_max_token_length", + type=int, + default=None, + help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev" + " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=6.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + help="Use Flash Attention for the model / モデルにFlash Attentionを使用する", + ) + parser.add_argument( + "--system_prompt", + type=str, + default="", + help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=None, + help="Batch size to use for sampling, defaults to --training_batch_size value. 
Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --training_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます", + ) diff --git a/library/lumina_util.py b/library/lumina_util.py new file mode 100644 index 000000000..d9c899386 --- /dev/null +++ b/library/lumina_util.py @@ -0,0 +1,233 @@ +import json +import os +from dataclasses import replace +from typing import List, Optional, Tuple, Union + +import einops +import torch +from accelerate import init_empty_weights +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import Gemma2Config, Gemma2Model + +from library.utils import setup_logging +from library import lumina_models, flux_models +from library.utils import load_safetensors +import logging + +setup_logging() +logger = logging.getLogger(__name__) + +MODEL_VERSION_LUMINA_V2 = "lumina2" + + +def load_lumina_model( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: torch.device, + disable_mmap: bool = False, + use_flash_attn: bool = False, +): + """ + Load the Lumina model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (torch.device): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False. + + Returns: + model (lumina_models.NextDiT): The loaded model. + """ + logger.info("Building Lumina") + with torch.device("meta"): + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + info = model.load_state_dict(state_dict, strict=False, assign=True) + logger.info(f"Loaded Lumina: {info}") + return model + + +def load_ae( + ckpt_path: str, + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, +) -> flux_models.AutoEncoder: + """ + Load the AutoEncoder model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (Union[str, torch.device]): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + + Returns: + ae (flux_models.AutoEncoder): The loaded model. + """ + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_gemma2( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> Gemma2Model: + """ + Load the Gemma2 model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (Union[str, torch.device]): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + state_dict (Optional[dict], optional): The state dict to load. Defaults to None. 
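# A minimal loading sketch for the three loaders defined in this module (illustrative;
# the file names and dtype are placeholders, not values from the patch):
import torch
from library import lumina_util

nextdit = lumina_util.load_lumina_model("lumina_2.safetensors", torch.bfloat16, torch.device("cpu"), use_flash_attn=False)
gemma2 = lumina_util.load_gemma2("gemma_2_2b.safetensors", torch.bfloat16, "cpu")
ae = lumina_util.load_ae("ae.safetensors", torch.bfloat16, "cpu")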
+ + Returns: + gemma2 (Gemma2Model): The loaded model + """ + logger.info("Building Gemma2") + GEMMA2_CONFIG = { + "_name_or_path": "google/gemma-2-2b", + "architectures": ["Gemma2Model"], + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": 50.0, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": 1, + "final_logit_softcapping": 30.0, + "head_dim": 256, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2304, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 8192, + "model_type": "gemma2", + "num_attention_heads": 8, + "num_hidden_layers": 26, + "num_key_value_heads": 4, + "pad_token_id": 0, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "sliding_window": 4096, + "torch_dtype": "float32", + "transformers_version": "4.44.2", + "use_cache": True, + "vocab_size": 256000, + } + + config = Gemma2Config(**GEMMA2_CONFIG) + with init_empty_weights(): + gemma2 = Gemma2Model._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + for key in list(sd.keys()): + new_key = key.replace("model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + info = gemma2.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Gemma2: {info}") + return gemma2 + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + + +DIFFUSERS_TO_ALPHA_VLLM_MAP = { + # Embedding layers + "cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"], + "cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight", + "cap_embedder.1.bias": "text_embedder.1.bias", + "x_embedder.weight": "patch_embedder.proj.weight", + "x_embedder.bias": "patch_embedder.proj.bias", + # Attention modulation + "layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight", + "layers.().adaLN_modulation.1.bias": "transformer_blocks.().adaln_modulation.1.bias", + # Final layers + "final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight", + "final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias", + "final_layer.linear.weight": "final_linear.weight", + "final_layer.linear.bias": "final_linear.bias", + # Noise refiner + "noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight", + "noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias", + "noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight", + "noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight", + # Time embedding + "t_embedder.mlp.0.weight": "time_embedder.0.weight", + "t_embedder.mlp.0.bias": "time_embedder.0.bias", + "t_embedder.mlp.2.weight": "time_embedder.2.weight", + 
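# Note: the "()" placeholder in the keys of this map stands for a per-block index;
# convert_diffusers_sd_to_alpha_vllm() below substitutes the concrete block number when matching keys.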
"t_embedder.mlp.2.bias": "time_embedder.2.bias", + # Context attention + "context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight", + "context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight", + # Normalization + "layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight", + "layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight", + # FFN + "layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight", + "layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight", + "layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight", +} + + +def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict: + """Convert Diffusers checkpoint to Alpha-VLLM format""" + logger.info("Converting Diffusers checkpoint to Alpha-VLLM format") + new_sd = {} + + for key, value in sd.items(): + new_key = key + for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): + if "()." in pattern: + for block_idx in range(num_double_blocks): + if str(block_idx) in key: + converted = pattern.replace("()", str(block_idx)) + new_key = key.replace(converted, replacement.replace("()", str(block_idx))) + break + + if new_key == key: + logger.debug(f"Unmatched key in conversion: {key}") + new_sd[new_key] = value + + logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") + return new_sd diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 8896c047e..f5343924a 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -61,6 +61,8 @@ # ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" ARCH_FLUX_1_UNKNOWN = "flux-1" +ARCH_LUMINA_2 = "lumina-2" +ARCH_LUMINA_UNKNOWN = "lumina" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -69,6 +71,7 @@ IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" IMPL_FLUX = "https://github.com/black-forest-labs/flux" +IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -123,6 +126,7 @@ def build_metadata( clip_skip: Optional[int] = None, sd3: Optional[str] = None, flux: Optional[str] = None, + lumina: Optional[str] = None, ): """ sd3: only supports "m", flux: only supports "dev" @@ -146,6 +150,11 @@ def build_metadata( arch = ARCH_FLUX_1_DEV else: arch = ARCH_FLUX_1_UNKNOWN + elif lumina is not None: + if lumina == "lumina2": + arch = ARCH_LUMINA_2 + else: + arch = ARCH_LUMINA_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -167,6 +176,9 @@ def build_metadata( if flux is not None: # Flux impl = IMPL_FLUX + elif lumina is not None: + # Lumina + impl = IMPL_LUMINA elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI @@ -225,7 +237,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl or sd3 is not None or flux is not None: + if sdxl or sd3 is not None or flux is not None or lumina is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c40798846..6a4b39b3a 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -610,6 +610,21 @@ def encode_prompt(prpt): from diffusers.utils import BaseOutput +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + @dataclass class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ @@ -649,22 +664,49 @@ def __init__( self, num_train_timesteps: int = 1000, shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -690,6 +732,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -709,10 +754,31 @@ def scale_noise( `torch.FloatTensor`: A scaled input sample. 
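# A minimal usage sketch (illustrative; the constructor arguments correspond to the config fields added above):
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=6.0)
scheduler.set_shift(3.0)  # the discrete flow shift can also be changed later via set_shift()
# scale_noise() then mixes clean latents and noise as: sigma * noise + (1 - sigma) * sample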
""" - if self.step_index is None: - self._init_step_index(timestep) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) - sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -720,7 +786,37 @@ def scale_noise( def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -730,18 +826,49 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps - timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps.to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas self._step_index = None self._begin_index = None @@ -807,7 +934,11 @@ def step( returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" @@ -823,30 +954,10 @@ def step( sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 - - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) - - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - - # if self.config.prediction_type == "vector_field": - - denoised = sample - model_output * sigma - # 2. 
Convert to an ODE derivative - derivative = (sample - denoised) / sigma_hat - - dt = self.sigmas[self.step_index + 1] - sigma_hat + prev_sample = sample + (sigma_next - sigma) * model_output - prev_sample = sample + derivative * dt # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) @@ -858,6 +969,86 @@ def step( return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def __len__(self): return self.config.num_train_timesteps diff --git a/library/strategy_base.py b/library/strategy_base.py index 358e42f1d..fad79682f 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -2,7 +2,7 @@ import os import re -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import torch @@ -430,9 +430,21 @@ def _default_is_disk_cached_latents_expected( bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, - alpha_mask: bool, + apply_alpha_mask: bool, multi_resolution: bool = False, - ): + ) -> bool: + """ + Args: + latents_stride: stride of latents + bucket_reso: resolution of the bucket + npz_path: path to the npz file + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + multi_resolution: whether to use multi-resolution latents + + Returns: + bool + """ if not self.cache_to_disk: return False if not os.path.exists(npz_path): @@ -451,7 +463,7 @@ def _default_is_disk_cached_latents_expected( return False if flip_aug and "latents_flipped" + key_reso_suffix not in npz: return False - if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: return False except Exception as e: logger.error(f"Error loading file: {npz_path}") @@ -462,22 +474,35 @@ def _default_is_disk_cached_latents_expected( # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( self, - encode_by_vae, - vae_device, - vae_dtype, + encode_by_vae: Callable, + vae_device: torch.device, + vae_dtype: torch.dtype, image_infos: List, flip_aug: bool, - alpha_mask: bool, + apply_alpha_mask: bool, random_crop: bool, multi_resolution: bool = False, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + + Args: + encode_by_vae: function to encode images by VAE + vae_device: device to use for VAE + vae_dtype: dtype to use for VAE + image_infos: list of ImageInfo + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + random_crop: whether to random crop images + multi_resolution: whether to use multi-resolution latents + + Returns: + None """ from library import train_util # import here to avoid circular import img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop + image_infos, apply_alpha_mask, random_crop ) img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) @@ -519,12 +544,40 @@ def load_latents_from_disk( ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ for SD/SDXL + + Args: + npz_path (str): Path to the npz file. 
+ bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) def _default_load_latents_from_disk( self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + Args: + latents_stride (Optional[int]): Stride for latents. If None, load all latents. + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask + """ if latents_stride is None: key_reso_suffix = "" else: @@ -552,6 +605,19 @@ def save_latents_to_disk( alpha_mask=None, key_reso_suffix="", ): + """ + Args: + npz_path (str): Path to the npz file. + latents_tensor (torch.Tensor): Latent tensor + original_size (List[int]): Original size of the image + crop_ltrb (List[int]): Crop left top right bottom + flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor + alpha_mask (Optional[torch.Tensor]): Alpha mask + key_reso_suffix (str): Key resolution suffix + + Returns: + None + """ kwargs = {} if os.path.exists(npz_path): diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py new file mode 100644 index 000000000..c9e654236 --- /dev/null +++ b/library/strategy_lumina.py @@ -0,0 +1,363 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast +from library import train_util +from library.strategy_base import ( + LatentsCachingStrategy, + TokenizeStrategy, + TextEncodingStrategy, + TextEncoderOutputsCachingStrategy, +) +import numpy as np +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +GEMMA_ID = "google/gemma-2-2b" + + +class LuminaTokenizeStrategy(TokenizeStrategy): + def __init__( + self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None + ) -> None: + self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained( + GEMMA_ID, cache_dir=tokenizer_cache_dir + ) + self.tokenizer.padding_side = "right" + + if max_length is None: + self.max_length = 256 + else: + self.max_length = max_length + + def tokenize( + self, text: Union[str, List[str]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + token input ids, attention_masks + """ + text = [text] if isinstance(text, str) else text + encodings = self.tokenizer( + text, + max_length=self.max_length, + return_tensors="pt", + padding="max_length", + truncation=True, + pad_to_multiple_of=8, + ) + return (encodings.input_ids, encodings.attention_mask) + + def tokenize_with_weights( + self, text: str | List[str] + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + token input ids, 
attention_masks, weights + """ + # Gemma doesn't support weighted prompts, return uniform weights + tokens, attention_masks = self.tokenize(text) + weights = [torch.ones_like(t) for t in tokens] + return tokens, attention_masks, weights + + +class LuminaTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + super().__init__() + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: Tuple[torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ + text_encoder = models[0] + assert isinstance(text_encoder, Gemma2Model) + input_ids, attention_masks = tokens + + outputs = text_encoder( + input_ids=input_ids.to(text_encoder.device), + attention_mask=attention_masks.to(text_encoder.device), + output_hidden_states=True, + return_dict=True, + ) + + return outputs.hidden_states[-2], input_ids, attention_masks + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: Tuple[torch.Tensor, torch.Tensor], + weights: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + weights_list (List[torch.Tensor]): Currently unused + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ + # For simplicity, use uniform weighting + return self.encode_tokens(tokenize_strategy, models, tokens) + + +class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + ) -> None: + super().__init__( + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + is_partial, + ) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return ( + os.path.splitext(image_abs_path)[0] + + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + ) + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + """ + Args: + npz_path (str): Path to the npz file. + + Returns: + bool: True if the npz file is expected to be cached. 
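# A minimal encoding sketch (illustrative; assumes `gemma2` was loaded via lumina_util.load_gemma2):
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
encoding_strategy = LuminaTextEncodingStrategy()
tokens, masks = tokenize_strategy.tokenize("a photo of a cat")
with torch.no_grad():
    hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
        tokenize_strategy, [gemma2], (tokens, masks)
    )
# hidden_states are taken from the penultimate Gemma2 layer (hidden_states[-2] above)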
+ """ + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state" not in npz: + return False + if "attention_mask" not in npz: + return False + if "input_ids" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + """ + Load outputs from a npz file + + Returns: + List[np.ndarray]: hidden_state, input_ids, attention_mask + """ + data = np.load(npz_path) + hidden_state = data["hidden_state"] + attention_mask = data["attention_mask"] + input_ids = data["input_ids"] + return [hidden_state, input_ids, attention_mask] + + def cache_batch_outputs( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + batch: List[train_util.ImageInfo], + ) -> None: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + text_encoding_strategy (LuminaTextEncodingStrategy): + infos (List): List of image_info + + Returns: + None + """ + assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) + assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) + + captions = [info.system_prompt or "" + info.caption for info in batch] + + if self.is_weighted: + tokens, attention_masks, weights_list = ( + tokenize_strategy.tokenize_with_weights(captions) + ) + with torch.no_grad(): + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + models, + (tokens, attention_masks), + weights_list, + ) + ) + else: + tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens + ) + ) + + if hidden_state.dtype != torch.float32: + hidden_state = hidden_state.float() + + hidden_state = hidden_state.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() # (B, S) + input_ids = input_ids.cpu().numpy() # (B, S) + + for i, info in enumerate(batch): + hidden_state_i = hidden_state[i] + attention_mask_i = attention_mask[i] + input_ids_i = input_ids[i] + + assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}" + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state=hidden_state_i, + attention_mask=attention_mask_i, + input_ids=input_ids_i, + ) + else: + info.text_encoder_outputs = [ + hidden_state_i, + attention_mask_i, + input_ids_i, + ] + + +class LuminaLatentsCachingStrategy(LatentsCachingStrategy): + LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path( + self, absolute_path: str, image_size: Tuple[int, int] + ) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, 
+ ) -> bool: + """ + Args: + bucket_reso (Tuple[int, int]): The resolution of the bucket. + npz_path (str): Path to the npz file. + flip_aug (bool): Whether to flip the image. + alpha_mask (bool): Whether to apply + """ + return self._default_is_disk_cached_latents_expected( + 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True + ) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: + """ + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet + """ + return self._default_load_latents_from_disk( + 8, npz_path, bucket_reso + ) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents( + self, + vae, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + ): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, + vae_device, + vae_dtype, + image_infos, + flip_aug, + alpha_mask, + random_crop, + multi_resolution=True, + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a994..18aceaf7b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -148,10 +148,11 @@ def split_train_val( paths: List[str], + sizes: List[Optional[Tuple[int, int]]], is_training_dataset: bool, validation_split: float, validation_seed: int | None -) -> List[str]: +) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -172,10 +173,12 @@ def split_train_val( # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part - return paths[0:math.ceil(len(paths) * (1 - validation_split))] + split = math.ceil(len(paths) * (1 - validation_split)) + return paths[0:split], sizes[0:split] else: # Validation dataset we split to the second part - return paths[len(paths) - round(len(paths) * validation_split):] + split = len(paths) - round(len(paths) * validation_split) + return paths[split:], sizes[split:] class ImageInfo: @@ -192,7 +195,7 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.latents_flipped: Optional[torch.Tensor] = None self.latents_npz: Optional[str] = None # set in cache_latents self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = ( None # crop left top right bottom in original pixel size, not latents size ) self.cond_img_path: Optional[str] = None @@ -208,6 +211,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.system_prompt: Optional[str] = None + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -431,6 +436,7 @@ def __init__( custom_attributes: Optional[Dict[str, Any]] 
= None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -461,6 +467,8 @@ def __init__( self.validation_seed = validation_seed self.validation_split = validation_split + self.system_prompt = system_prompt + class DreamBoothSubset(BaseSubset): def __init__( @@ -492,6 +500,7 @@ def __init__( custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -519,6 +528,7 @@ def __init__( custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.is_reg = is_reg @@ -561,6 +571,7 @@ def __init__( custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -588,6 +599,7 @@ def __init__( custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.metadata_file = metadata_file @@ -626,6 +638,7 @@ def __init__( custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -653,6 +666,7 @@ def __init__( custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.conditioning_data_dir = conditioning_data_dir @@ -1683,8 +1697,9 @@ def __getitem__(self, index): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: + system_prompt = subset.system_prompt or "" caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension # if self.XTI_layers: # caption_layer = [] # for layer in self.XTI_layers: @@ -1854,6 +1869,7 @@ def __init__( debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + system_prompt: Optional[str], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1866,6 +1882,7 @@ def __init__( self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split + self.system_prompt = system_prompt self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1931,12 +1948,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): with open(info_cache_file, "r", encoding="utf-8") as f: metas = json.load(f) img_paths = list(metas.keys()) - sizes = [meta["resolution"] for meta in metas.values()] + sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - sizes = [None] * len(img_paths) + sizes: List[Optional[Tuple[int, int]]] = [None] * 
len(img_paths) # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() @@ -1969,7 +1986,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): w, h = None, None if w is not None and h is not None: - sizes[i] = [w, h] + sizes[i] = (w, h) size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") @@ -1987,11 +2004,13 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] + sizes = [] # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: - img_paths = split_train_val( + img_paths, sizes = split_train_val( img_paths, + sizes, self.is_training_dataset, self.validation_split, self.validation_seed @@ -2054,6 +2073,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_train_images = 0 num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] + for subset in subsets: num_repeats = subset.num_repeats if self.is_training_dataset else 1 if num_repeats < 1: @@ -2080,8 +2100,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): else: num_train_images += num_repeats * len(img_paths) + system_prompt = self.system_prompt or subset.system_prompt or "" for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: @@ -2962,7 +2983,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size @@ -3463,6 +3484,7 @@ def get_sai_model_spec( is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, flux: str = None, + lumina: str = None, ): timestamp = time.time() @@ -3498,6 +3520,7 @@ def get_sai_model_spec( clip_skip=args.clip_skip, # None or int sd3=sd3, flux=flux, + lumina=lumina, ) return metadata @@ -6165,6 +6188,16 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["controlnet_image"] = m.group(1) continue + m = re.match(r"ct (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["cfg_trunc_ratio"] = float(m.group(1)) + continue + + m = re.match(r"rc (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["renorm_cfg"] = float(m.group(1)) + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) diff --git a/lumina_train_network.py b/lumina_train_network.py new file mode 100644 index 000000000..0fd4da6b3 --- /dev/null +++ b/lumina_train_network.py @@ -0,0 +1,406 @@ +import argparse +import copy +from typing import Any, Tuple + +import torch + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +from torch import Tensor +from accelerate import Accelerator + + +import train_network +from library import ( + lumina_models, + flux_train_utils, + lumina_util, + lumina_train_util, + sd3_train_utils, + strategy_base, + strategy_lumina, + train_util, +) +from library.utils import 
setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LuminaNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group, val_dataset_group): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning("Enabling cache_text_encoder_outputs due to disk caching") + args.cache_text_encoder_outputs = True + + train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) + + self.train_gemma2 = not args.network_train_unet_only + + def load_target_model(self, args, weight_dtype, accelerator): + loading_dtype = None if args.fp8_base else weight_dtype + + model = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + loading_dtype, + torch.device("cpu"), + disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, + ) + + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 Lumina 2 model") + else: + logger.info( + "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.blocks_to_swap: + # logger.info(f'Enabling block swap: {args.blocks_to_swap}') + # model.enable_block_swap(args.blocks_to_swap, accelerator.device) + # self.is_swapping_blocks = True + + gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") + gemma2.eval() + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + + return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model + + def get_tokenize_strategy(self, args): + return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + + def get_text_encoding_strategy(self, args): + return strategy_lumina.LuminaTextEncodingStrategy() + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_gemma2] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_gemma2, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, + args, + accelerator: Accelerator, + unet, + vae, + text_encoders, + dataset, + weight_dtype, + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + 
clean_memory_on_device(accelerator.device)
+
+            # When the TE is not trained, it will not be prepared, so we need to use explicit autocast
+            logger.info("move text encoders to gpu")
+            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
+
+            if text_encoders[0].dtype == torch.float8_e4m3fn:
+                # if we load fp8 weights, the model is already fp8, so we use it as is
+                self.prepare_text_encoder_fp8(0, text_encoders[0], text_encoders[0].dtype, weight_dtype)
+            else:
+                # otherwise, we need to convert it to target dtype
+                text_encoders[0].to(weight_dtype)
+
+            with accelerator.autocast():
+                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
+
+            # cache sample prompts
+            if args.sample_prompts is not None:
+                logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
+
+                tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+                text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+                assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
+                assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
+
+                system_prompt = args.system_prompt or ""
+                sample_prompts = train_util.load_prompts(args.sample_prompts)
+                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
+                with accelerator.autocast(), torch.no_grad():
+                    for prompt_dict in sample_prompts:
+                        prompts = [
+                            prompt_dict.get("prompt", ""),
+                            prompt_dict.get("negative_prompt", ""),
+                        ]
+                        for prompt in prompts:
+                            prompt = system_prompt + prompt
+                            if prompt in sample_prompts_te_outputs:
+                                continue
+
+                            logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
+                            tokens_and_masks = tokenize_strategy.tokenize(prompt)
+                            sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
+                                tokenize_strategy,
+                                text_encoders,
+                                tokens_and_masks,
+                            )
+
+                self.sample_prompts_te_outputs = sample_prompts_te_outputs
+
+            accelerator.wait_for_everyone()
+
+            # move back to cpu
+            if not self.is_train_text_encoder(args):
+                logger.info("move Gemma 2 back to cpu")
+                text_encoders[0].to("cpu")
+            clean_memory_on_device(accelerator.device)
+
+            if not args.lowram:
+                logger.info("move vae and unet back to original device")
+                vae.to(org_vae_device)
+                unet.to(org_unet_device)
+        else:
+            # outputs are fetched from the Text Encoder every step, so keep it on the GPU
+            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+
+    def sample_images(
+        self,
+        accelerator,
+        args,
+        epoch,
+        global_step,
+        device,
+        vae,
+        tokenizer,
+        text_encoder,
+        lumina,
+    ):
+        lumina_train_util.sample_images(
+            accelerator,
+            args,
+            epoch,
+            global_step,
+            lumina,
+            vae,
+            self.get_models_for_text_encoding(args, accelerator, text_encoder),
+            self.sample_prompts_te_outputs,
+        )
+
+    # Remaining methods maintain similar structure to the flux implementation
+    # with Lumina-specific model calls and strategies
+
+    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
+        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
+        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+        return noise_scheduler
+
+    def encode_images_to_latents(self, args, accelerator, vae, images):
+        return vae.encode(images)
+
+    # Lumina 2 seems to use the same VAE as FLUX, so no shift/scale is applied here
+    def shift_scale_latents(self, args, latents):
+        return latents
+
+    def get_noise_pred_and_target(
+        self,
+        args,
+        accelerator: Accelerator,
+        noise_scheduler,
+        latents,
+        batch,
+        text_encoder_conds: Tuple[Tensor, Tensor, Tensor],  # (hidden_states, input_ids, attention_masks)
+        dit: lumina_models.NextDiT,
+        network,
+        weight_dtype,
+        train_unet,
+        is_train=True,
+    ):
+        assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+
+        # Sample a random timestep for each image
+        # for weighting schemes where we sample timesteps non-uniformly
+        u = lumina_train_util.compute_density_for_timestep_sampling(
+            weighting_scheme=args.weighting_scheme,
+            batch_size=bsz,
+            logit_mean=args.logit_mean,
+            logit_std=args.logit_std,
+            mode_scale=args.mode_scale,
+        )
+        indices = (u * noise_scheduler.config.num_train_timesteps).long()
+        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
+
+        def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+            sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+            schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+            timesteps = timesteps.to(accelerator.device)
+            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+            sigma = sigmas[step_indices].flatten()
+            while len(sigma.shape) < n_dim:
+                sigma = sigma.unsqueeze(-1)
+            return sigma
+
+        # Add noise according to flow matching.
+        # zt = (1 - texp) * x + texp * z1
+        # Lumina2 reverses the lerp, i.e. sigma of 1.0 should mean `latents`
+        sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
+        noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents
+
+        # ensure the hidden state will require grad
+        if args.gradient_checkpointing:
+            noisy_model_input.requires_grad_(True)
+            for t in text_encoder_conds:
+                if t is not None and t.dtype.is_floating_point:
+                    t.requires_grad_(True)
+
+        # Unpack Gemma2 outputs
+        gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
+
+        def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
+            with torch.set_grad_enabled(is_train), accelerator.autocast():
+                # NextDiT forward expects (x, t, cap_feats, cap_mask)
+                model_pred = dit(
+                    x=img,  # image latents (B, C, H, W)
+                    t=timesteps / 1000,  # timesteps are divided by 1000 to match the range the model expects
+                    cap_feats=gemma2_hidden_states,  # Gemma2 hidden states are used as the caption features
+                    cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2 attention mask
+                )
+            return model_pred
+
+        model_pred = call_dit(
+            img=noisy_model_input,
+            gemma2_hidden_states=gemma2_hidden_states,
+            gemma2_attn_mask=gemma2_attn_mask,
+            timesteps=timesteps,
+        )
+
+        # apply model prediction type
+        model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
+
+        # flow matching loss
+        target = latents - noise
+
+        # differential output preservation
+        if "custom_attributes" in batch:
+            diff_output_pr_indices = []
+            for i, custom_attributes in enumerate(batch["custom_attributes"]):
+                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
+                    diff_output_pr_indices.append(i)
+
+            if len(diff_output_pr_indices) > 0:
+                network.set_multiplier(0.0)
+                with torch.no_grad():
+                    model_pred_prior = call_dit(
+                        img=noisy_model_input[diff_output_pr_indices],
+                        gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
+                        timesteps=timesteps[diff_output_pr_indices],
+                        gemma2_attn_mask=gemma2_attn_mask[diff_output_pr_indices],
+                    )
+                network.set_multiplier(1.0)
+
+                # model_pred_prior = lumina_util.unpack_latents(
+                #     model_pred_prior, packed_latent_height, packed_latent_width
+                # )
+                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
+                    args,
+                    model_pred_prior,
+                    noisy_model_input[diff_output_pr_indices],
+                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
+                )
+                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+
+        return model_pred, target, timesteps, weighting
+
+    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
+        return loss
+
+    def get_sai_model_spec(self, args):
+        return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
+
+    def update_metadata(self, metadata, args):
+        metadata["ss_weighting_scheme"] = args.weighting_scheme
+        metadata["ss_logit_mean"] = args.logit_mean
+        metadata["ss_logit_std"] = args.logit_std
+        metadata["ss_mode_scale"] = args.mode_scale
+        metadata["ss_timestep_sampling"] = args.timestep_sampling
+        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
+        metadata["ss_model_prediction_type"] = args.model_prediction_type
+        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
+
+    def is_text_encoder_not_needed_for_training(self, args):
+        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
+
+    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+        text_encoder.embed_tokens.requires_grad_(True)
+
+    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+        logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
+        text_encoder.to(te_weight_dtype)  # fp8
+        text_encoder.embed_tokens.to(dtype=weight_dtype)
+
+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        if not self.is_swapping_blocks:
+            return super().prepare_unet_with_accelerator(args, accelerator, unet)
+
+        # when block swap is enabled, prepare the model without device placement and
+        # move everything except the swapped blocks to the device
+        nextdit = unet
+        assert isinstance(nextdit, lumina_models.NextDiT)
+        nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
+        accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+        accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
+
+        return nextdit
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = train_network.setup_parser()
+    train_util.add_dit_training_arguments(parser)
+    lumina_train_util.add_lumina_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+    args = parser.parse_args()
+    train_util.verify_command_line_training_args(args)
+    args = train_util.read_config_from_file(args, parser)
+
+    trainer = LuminaNetworkTrainer()
+    trainer.train(args)
diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py
new file mode 100644
index 000000000..3f6c9b417
--- /dev/null
+++ b/networks/lora_lumina.py
@@ -0,0 +1,1011 @@
+# temporary minimum implementation of LoRA
+# Lumina 2 (NextDiT) doesn't have Conv2d, so we ignore it
+# TODO commonize with the original implementation
+
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
+import numpy as np
+import torch
+import re
+from library.utils import setup_logging
+from library.sdxl_original_unet import SdxlUNet2DConditionModel
+
+setup_logging()
+import
logging + +logger = logging.getLogger(__name__) + + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of lumina as same as Diffusers + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * 
mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + 
# set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + lumina, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim for JointTransformerBlock + attn_dim = kwargs.get("attn_dim", None) # attention dimension + mlp_dim = kwargs.get("mlp_dim", None) # MLP dimension + mod_dim = kwargs.get("mod_dim", None) # modulation dimension + refiner_dim = kwargs.get("refiner_dim", None) # refiner blocks dimension + + if attn_dim is not None: + attn_dim = int(attn_dim) + if mlp_dim is not None: + mlp_dim = int(mlp_dim) + if mod_dim is not None: + mod_dim = int(mod_dim) + if refiner_dim is not None: + refiner_dim = int(refiner_dim) + + type_dims = [attn_dim, mlp_dim, mod_dim, refiner_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims for embedders + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] + assert len(in_dims) == 4, f"invalid in_dims: {in_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder, final_layer)" + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = 
kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + lumina, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + type_dims=type_dims, + in_dims=in_dims, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." 
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + lumina, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"] + LORA_PREFIX_LUMINA = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder + + def __init__( + self, + text_encoders, # Now this will be a single Gemma2 model + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + + self.type_dims = type_dims + self.in_dims = in_dims + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + + # create module instances + def create_modules( + is_lumina: bool, + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # for handling embedders + module = root_module + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_lumina and type_dims is not None: + identifier = [ + ("attention",), # attention layers + ("mlp",), # MLP layers + ("modulation",), # modulation layers + ("refiner",), # refiner blocks + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder (Gemma2) + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + + logger.info(f"create LoRA for Gemma2 Text Encoder:") + text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) + + # Handle embedders + if self.in_dims: + for filter, in_dim in zip(["x_embedder", "t_embedder", "cap_embedder", "final_layer"], self.in_dims): + loras, _ = create_modules(True, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + 
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # # split qkv + # for key in list(state_dict.keys()): + # if "double" in key and "qkv" in key: + # split_dims = [3072] * 3 + # elif "single" in key and "linear1" in key: + # split_dims = [3072] * 3 + [12288] + # else: + # continue + + # weight = state_dict[key] + # lora_name = key.split(".")[0] + + # if key not in state_dict: + # continue # already merged + + # # (rank, in_dim) * 3 + # down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # # (split dim, rank) * 3 + # up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + # alpha = state_dict.pop(f"{lora_name}.alpha") + + # # merge down weight + # down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # # merge up weight (sum of split_dim, rank*3) + # rank = up_weights[0].size(1) + # up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + # i = 0 + # for j in range(len(split_dims)): + # up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + # i += split_dims[j] + + # state_dict[f"{lora_name}.lora_down.weight"] = down_weight + # state_dict[f"{lora_name}.lora_up.weight"] = up_weight + # state_dict[f"{lora_name}.alpha"] = alpha + + # # print( + # # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # # ) + # print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + # merge qkv + state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # 
already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_LUMINA): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif 
isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te_loras = [lora for lora in self.text_encoder_loras] + if len(te_loras) > 0: + logger.info(f"Text Encoder: {len(te_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: 
List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/train_network.py b/train_network.py index c3879531d..07de30b3b 100644 --- a/train_network.py +++ b/train_network.py @@ -129,7 +129,7 @@ def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetG if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) - def load_target_model(self, args, weight_dtype, accelerator): + def load_target_model(self, args, weight_dtype, accelerator) -> tuple: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む @@ -354,12 +354,13 @@ def process_batch( if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions']) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, 
self.get_models_for_text_encoding(args, accelerator, text_encoders), @@ -1241,6 +1242,7 @@ def remove_model(old_ckpt_name): # For --sample_at_first optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + progress_bar.unpause() # Reset progress bar to before sampling images optimizer_train_fn() is_tracking = len(accelerator.trackers) > 0 if is_tracking: @@ -1343,6 +1345,7 @@ def remove_model(old_ckpt_name): self.sample_images( accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet ) + progress_bar.unpause() # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1498,7 +1501,7 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, @@ -1530,6 +1533,7 @@ def remove_model(old_ckpt_name): train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + progress_bar.unpause() optimizer_train_fn() # end of epoch
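
Usage sketch (illustrative, not part of the patch): the new `system_prompt` string is simply concatenated in front of each caption before tokenization, and in front of the sample prompts via `args.system_prompt`; the prompt-file parser additionally gains two Lumina-specific arguments, `ct` for `cfg_trunc_ratio` and `rc` for `renorm_cfg`. Assuming the `--`-prefixed sample-prompt syntax used by the other training scripts, a line in the `--sample_prompts` file could look like:

    a portrait photo of a girl, upper body --w 1024 --h 1024 --s 36 --ct 0.25 --rc 1.0

which `line_to_prompt_dict` would parse into, among other keys, `cfg_trunc_ratio=0.25` and `renorm_cfg=1.0`. The exact recommended system prompt for Lumina Image 2.0 is not defined in this patch; whatever string is configured is prepended verbatim.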