Add support for lumina2 #10642
base: main
@@ -4192,6 +4192,102 @@ def __call__(
        return hidden_states


class Lumina2AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Lumina2Transformer2DModel model. It applies normalization and rotary embeddings to the query and key
    vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        query_rotary_emb: Optional[torch.Tensor] = None,
        key_rotary_emb: Optional[torch.Tensor] = None,
Comment on lines +4211 to +4212: I see that the
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb
zhuole1025 marked this conversation as resolved.

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
Comment on lines +4219 to +4221 (suggested change): I think we're passing ndim=3 tensors here, so we can simplify this code and remove it.
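A minimal sketch of the simplification the reviewer suggests (an assumption, not code from this PR): if callers always pass 3-D tensors, the reshape branch can be replaced by a plain shape check.

# hidden_states is assumed to always be (batch_size, sequence_length, dim)
assert hidden_states.ndim == 3, "Lumina2AttnProcessor2_0 expects (batch, sequence, dim) inputs"
batch_size, sequence_length, _ = hidden_states.shape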

        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if query_rotary_emb is not None:
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
        if key_rotary_emb is not None:
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if a base sequence length is provided
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Perform grouped-query attention (GQA)
        n_rep = attn.heads // kv_heads
        if n_rep >= 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        if attention_mask is not None:
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
            attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)

        return hidden_states

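As a sanity check on the shapes involved, here is a standalone sketch (illustrative sizes only, not part of the diff) of the grouped-query head repetition and the proportional softmax scale used by the processor above:

import math
import torch
import torch.nn.functional as F

batch, seq_len, heads, kv_heads, head_dim = 2, 16, 8, 2, 64

query = torch.randn(batch, seq_len, heads, head_dim)
key = torch.randn(batch, seq_len, kv_heads, head_dim)
value = torch.randn(batch, seq_len, kv_heads, head_dim)

# GQA: each of the 2 key/value heads is shared by 8 // 2 = 4 query heads, so key/value
# are repeated along a new axis and flattened back to `heads` heads.
n_rep = heads // kv_heads
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)      # (2, 16, 8, 64)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # (2, 16, 8, 64)

# Proportional attention: the usual 1/sqrt(head_dim) scale is multiplied by
# sqrt(log_base(sequence_length)), where base_sequence_length is presumably the
# training sequence length (value below is only for illustration).
base_sequence_length = 4
scale = head_dim**-0.5
softmax_scale = math.sqrt(math.log(seq_len, base_sequence_length)) * scale

out = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), scale=softmax_scale
)
print(out.shape)  # torch.Size([2, 8, 16, 64])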
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is

@@ -6183,6 +6279,7 @@ def __call__(
    PAGHunyuanAttnProcessor2_0,
    PAGCFGHunyuanAttnProcessor2_0,
    LuminaAttnProcessor2_0,
    Lumina2AttnProcessor2_0,
    FusedAttnProcessor2_0,
    CustomDiffusionXFormersAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
@@ -569,6 +569,35 @@ def forward(self, latent):
        return (latent + pos_embed).to(latent.dtype)


class Lumina2PosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta)

    def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
        freqs_cis = []
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(
                d,
                e,
                theta=self.theta,
                freqs_dtype=torch.float64,
            )
            freqs_cis.append(emb)
        return freqs_cis

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        result = []
        for i in range(len(self.axes_dim)):
            freqs = self.freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1)
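To make the gather in forward concrete, here is a standalone shape sketch; the frequency tables are dummy stand-ins for get_1d_rotary_pos_embed, and the axis sizes are chosen for illustration rather than taken from the PR:

import torch

axes_dim = [8, 4, 4]         # per-axis rotary dimensions (illustrative)
axes_lens = [300, 512, 512]  # maximum position per axis, as in the default above
freqs_cis = [torch.randn(l, d) for d, l in zip(axes_dim, axes_lens)]  # stand-in frequency tables

batch, seq_len = 2, 10
# ids[..., i] holds the position of each token along axis i
ids = torch.randint(0, 300, (batch, seq_len, 3))

result = []
for i in range(len(axes_dim)):
    freqs = freqs_cis[i]                                               # (axes_lens[i], axes_dim[i])
    index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
    result.append(torch.gather(freqs.unsqueeze(0).repeat(batch, 1, 1), dim=1, index=index))
out = torch.cat(result, dim=-1)
print(out.shape)  # torch.Size([2, 10, 16]): one row per token gathered from each axis table, concatenated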


class LuminaPatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding with support for Lumina-T2X
@@ -1766,6 +1795,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
        return conditioning


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
Comment: Let's move this module to the
    def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5):
        super().__init__()

        from .normalization import RMSNorm
zhuole1025 marked this conversation as resolved.

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(
                cap_feat_dim,
                hidden_size,
                bias=True,
            ),
        )

    def forward(self, timestep, caption_feat):
        # timestep embedding:
        time_freq = self.time_proj(timestep)
        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

        # caption condition embedding:
        caption_embed = self.caption_embedder(caption_feat)

        return time_embed, caption_embed

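A hypothetical usage sketch for this module, using the default sizes from the signature above; the shapes in the comments are what the code implies, assuming diffusers' Timesteps, TimestepEmbedding, and RMSNorm behave as in the current library and that the class above is in scope:

import torch

embedder = Lumina2CombinedTimestepCaptionEmbedding()  # hidden_size=4096, cap_feat_dim=2048

timestep = torch.rand(2)                 # (batch,)
caption_feat = torch.randn(2, 77, 2048)  # (batch, text_seq_len, cap_feat_dim)

time_embed, caption_embed = embedder(timestep, caption_feat)
# time_embed:    (2, 1024)      -> min(hidden_size, 1024) channels from the timestep MLP
# caption_embed: (2, 77, 4096)  -> caption features RMS-normalized and projected to hidden_size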

class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
        super().__init__()
Comment: Let's move this attention processor implementation to the transformer_lumina2.py file, since all new models are now being added in single-file format.
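For reference, a rough sketch of what that single-file layout could look like; the path, class grouping, and base classes below are assumptions based on how other recent diffusers models are organized, not something stated in this PR:

# src/diffusers/models/transformers/transformer_lumina2.py (hypothetical path)
import torch.nn as nn


class Lumina2AttnProcessor2_0:  # attention processor lives next to the model
    ...


class Lumina2PosEmbed(nn.Module):  # rotary position embedding helper
    ...


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):  # timestep + caption conditioning
    ...


class Lumina2Transformer2DModel(nn.Module):  # the transformer itself, wiring the pieces above
    ...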