Add support for lumina2 #10642

Open · wants to merge 8 commits into base: main · Changes from 7 commits
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -110,6 +110,7 @@
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LuminaNextDiT2DModel",
"Lumina2Transformer2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
@@ -329,6 +330,7 @@
"LTXImageToVideoPipeline",
"LTXPipeline",
"LuminaText2ImgPipeline",
"Lumina2Text2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
@@ -622,6 +624,7 @@
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
Lumina2Transformer2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
@@ -820,6 +823,7 @@
LTXImageToVideoPipeline,
LTXPipeline,
LuminaText2ImgPipeline,
Lumina2Text2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -60,6 +60,7 @@
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
@@ -139,6 +140,7 @@
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
Lumina2Transformer2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
1 change: 0 additions & 1 deletion src/diffusers/models/attention.py
@@ -612,7 +612,6 @@ def __init__(
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
- inner_dim = int(2 * inner_dim / 3)
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
97 changes: 97 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -4192,6 +4192,102 @@ def __call__(
return hidden_states


class Lumina2AttnProcessor2_0:
Review comment (Member):
Let's move this attention processor implementation to the transformer_lumina2.py file, since all new models are now being added in single-file format

r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Lumina2Transformer2DModel model. It applies normalization layers and rotary embeddings to the query and key vectors.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
Review comment on lines +4211 to +4212 (Member):
I see that the rotary_emb being passed for both the query and key are the same. Not sure if I'm missing a case where both would have to be different in the implementation. If not supposed to be different, let's use a single image_rotary_emb parameter here

base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
Review comment on lines +4219 to +4221 (Member), suggested change (delete the lines below):
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

I think we're passing ndim=3 tensors here, so can simplify this code and remove it


batch_size, sequence_length, _ = hidden_states.shape

# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype

# Get key-value heads
kv_heads = inner_dim // head_dim

query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)

# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
if key_rotary_emb is not None:
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

query, key = query.to(dtype), key.to(dtype)

# Apply proportional attention if a base sequence length is provided
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale

# Perform grouped-query attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)

return hidden_states
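
For reference, here is a standalone sketch of two mechanics used in the processor above: the grouped-query attention head expansion and the proportional softmax scale. Nothing in it is part of the PR; the sizes are made up purely for illustration.

import math

import torch

batch, seq_len, head_dim = 2, 16, 64
num_heads, kv_heads = 8, 2        # query heads vs. key/value heads
n_rep = num_heads // kv_heads     # each key/value head serves n_rep query heads

# 1) Grouped-query attention: expand the key/value heads to match the query heads.
key = torch.randn(batch, seq_len, kv_heads, head_dim)
key_expanded = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
assert key_expanded.shape == (batch, seq_len, num_heads, head_dim)

# 2) Proportional attention: the softmax scale grows with the sequence length
#    relative to a base length, so attention stays comparably sharp when the
#    token count exceeds the training resolution.
base_sequence_length = 256        # e.g. token count at the training resolution
scale = head_dim**-0.5            # the usual 1/sqrt(head_dim)
for cur_len in (256, 1024, 4096):
    softmax_scale = math.sqrt(math.log(cur_len, base_sequence_length)) * scale
    print(cur_len, round(softmax_scale / scale, 3))  # 1.0, 1.118, 1.225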


class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -6183,6 +6279,7 @@ def __call__(
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
Lumina2AttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
32 changes: 32 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -1766,6 +1766,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
return conditioning


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
Review comment (Member):
Let's move this module to the transformer_lumina2.py file as well

def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5):
super().__init__()

from .normalization import RMSNorm

self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
)

self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024))

self.caption_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
hidden_size,
bias=True,
),
)

def forward(self, timestep, caption_feat):
# timestep embedding:
time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

# caption condition embedding:
caption_embed = self.caption_embedder(caption_feat)

return time_embed, caption_embed
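
For reference, a rough usage sketch of the embedding module above, assuming this branch is installed. The import path matches this revision of the PR (the module may move to transformer_lumina2.py per the review comment above), and the batch and text lengths are made-up examples.

import torch

from diffusers.models.embeddings import Lumina2CombinedTimestepCaptionEmbedding

embed = Lumina2CombinedTimestepCaptionEmbedding(
    hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256
)

timestep = torch.rand(2)                 # (batch,)
caption_feat = torch.randn(2, 77, 2048)  # (batch, text_len, cap_feat_dim)

time_embed, caption_embed = embed(timestep, caption_feat)
print(time_embed.shape)     # torch.Size([2, 1024]); time_embed_dim = min(hidden_size, 1024)
print(caption_embed.shape)  # torch.Size([2, 77, 4096])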


class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
super().__init__()
3 changes: 1 addition & 2 deletions src/diffusers/models/normalization.py
@@ -219,14 +219,13 @@ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine:
4 * embedding_dim,
bias=True,
)
- self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)

def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -98,7 +98,7 @@ def __init__(

self.feed_forward = LuminaFeedForward(
dim=dim,
- inner_dim=4 * dim,
+ inner_dim=int(4 * 2 * dim / 3),
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
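
This pairs with the attention.py change above, which removes inner_dim = int(2 * inner_dim / 3) from LuminaFeedForward: the 2/3 shrink moves from the module into the call site, so the effective feed-forward width for Lumina-Next should be unchanged. A quick illustrative check, using an arbitrary example width rather than the model's real config:

# Illustrative only: moving the 2/3 factor out of LuminaFeedForward keeps the
# same inner dimension. `dim` is an arbitrary example width.
dim = 2304
old_inner = int(2 * (4 * dim) / 3)   # old: 4 * dim passed in, shrunk inside the module
new_inner = int(4 * 2 * dim / 3)     # new: shrunk value passed at the call site
assert old_inner == new_inner == 6144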