Add support for lumina2 #10642
base: main
Conversation
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

for layer in self.context_refiner:
I think context_refiner and noise_refiner should be moved inside forward for better readability.
Does it make sense to wrap the code that creates freqs_cis inside a class, like this?
class LTXVideoRotaryPosEmbed(nn.Module):
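For illustration only, a minimal sketch of such a wrapper, assuming a multi-axis position-id input (the class name and arguments here are hypothetical, not the API in this PR):

import torch
import torch.nn as nn


class Lumina2RotaryPosEmbed(nn.Module):
    # Hypothetical sketch: one axis-factorized RoPE table per axis (e.g. text index, height, width).
    def __init__(self, theta: float, axes_dim: list):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # position_ids: (batch, seq_len, num_axes) integer positions per axis.
        freqs_cis = []
        for i, dim in enumerate(self.axes_dim):
            inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
            angles = position_ids[..., i].unsqueeze(-1).to(torch.float64) * inv_freq
            freqs_cis.append(torch.polar(torch.ones_like(angles), angles))
        # (batch, seq_len, sum(axes_dim) // 2) complex tensor
        return torch.cat(freqs_cis, dim=-1)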
Thanks for the suggestion! I have moved the refiners into forward() for better readability. As for the freqs_cis, we already have a class Lumina2PosEmbed.
image_rotary_emb = image_rotary_emb.to(hidden_states.device)

for layer in self.context_refiner:
    cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
cap_freqs_cis is not returned from the method above, no?
""" | ||
|
||
temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states) | ||
cap_feats, padded_img_embed, img_size, l_effective_cap_len, l_effective_img_len, image_rotary_emb, max_seq_len, cap_mask, padded_img_mask = self.patchify_and_embed(hidden_states, encoder_hidden_states, encoder_mask, temb) |
Looking at this patchify_and_embed, I think cap_feats and cap_mask are not changed, and it's supposed to return cap_freqs_cis, no? But it was not included in the list of outputs that it returns.
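A minimal sketch of what an updated return could look like, assuming patchify_and_embed already builds cap_freqs_cis and img_freqs_cis internally (the tuple below reuses the names from the snippets above and is illustrative only):

# hypothetical sketch: also hand the refiner-specific freqs back to forward()
return (
    cap_feats,
    padded_img_embed,
    img_sizes,
    l_effective_cap_len,
    l_effective_img_len,
    cap_freqs_cis,   # consumed by self.context_refiner
    img_freqs_cis,   # consumed by self.noise_refiner
    image_rotary_emb,
    max_seq_len,
    cap_mask,
    padded_img_mask,
)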
    cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

for layer in self.noise_refiner:
    padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
What is t here? Is it temb?
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device

l_effective_cap_len = cap_mask.sum(dim=1).tolist()
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

max_seq_len = max(
    (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
I think these length variables can be created inside forward, no? They are used throughout the forward pass.
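For example, something along these lines could live directly in forward (a sketch assuming hidden_states is a list of per-sample latents and attention_mask is the caption mask; variable names are illustrative):

# hypothetical sketch: compute the sequence-length bookkeeping in forward()
p = self.config.patch_size
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(h // p) * (w // p) for (h, w) in img_sizes]
max_seq_len = max(c + i for c, i in zip(l_effective_cap_len, l_effective_img_len))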
return padded_full_embed, mask

def forward(
I'm not exactly sure I understand the code entirely here (I think the code as it is isn't able to run, no?), so let me know if my suggestion does not make sense.
Does it make sense to make a Lumina2PatchEmbed that wraps all the logic we have here to create image_rotary_emb and cap_freqs_cis (instead of the simple PosEmb that returns the freqs_cis based on position ids)?
The forward pass I have in mind is something like this:

temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states)
image_sizes = ...
list_text_seq_length = ...
list_num_patches = ...
max_seq_len = ...
image_rotary_emb, text_rotary_emb = self.patch_embed(...)
for layer in self.context_refiner:
    ...
padded_img_embed, padded_img_mask = self.patchify_image(...)
for layer in self.noise_refiner:
    ...
hidden_states, mask = ...
for layer in self.layers:
    ...
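To make the suggestion concrete, a rough skeleton of such a module could look like this (class, argument, and return names are illustrative only, not a final interface):

import torch
import torch.nn as nn


class Lumina2PatchEmbed(nn.Module):
    # Hypothetical sketch: owns patchification and the rotary-embedding bookkeeping,
    # so that the transformer's forward() only orchestrates refiners and main layers.
    def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * patch_size * patch_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        # A real implementation would:
        #   1. compute per-sample caption lengths and image patch counts,
        #   2. build joint position ids and the matching freqs_cis,
        #   3. split freqs_cis into text_rotary_emb / image_rotary_emb,
        #   4. patchify and project the latents.
        raise NotImplementedError("illustrative skeleton only")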
So sorry for the confusion! I have now pushed the latest update, including a Lumina2PatchEmbed following your suggestions. I have tested that the code correctly generates images.
I know we have support for pipe.enable_model_cpu_offload. I tried it out, and it is not currently supported.
Thanks for the PR @zhuole1025! It looks great and is almost ready to merge.
We need to add tests and docs apart from addressing the comments/questions here. For both, this PR would serve as a good example of which files need to be updated. In tests, a modeling test (test_models_transformer_lumina2.py) and a pipeline test (test_lumina2.py) will be needed. More than happy to help make any of the required changes to move the PR to completion 🤗
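As a rough starting point, the pipeline fast test usually follows the same shape as other DiT pipeline tests: tiny randomly initialized components plus deterministic dummy inputs. A heavily hedged sketch (the shared test mixins and component configs used in the actual test suite are not shown here):

import unittest

import torch

from diffusers import Lumina2Text2ImgPipeline  # name taken from the PR's doc example


class Lumina2Text2ImgPipelineFastTests(unittest.TestCase):
    # Sketch only: the real test would build tiny components via get_dummy_components()
    # and inherit the shared PipelineTesterMixin checks.
    def get_dummy_inputs(self, device="cpu"):
        return {
            "prompt": "a photo of a cat",
            "generator": torch.Generator(device="cpu").manual_seed(0),
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "output_type": "np",
        }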
@@ -612,7 +612,6 @@ def __init__(
    ffn_dim_multiplier: Optional[float] = None,
):
    super().__init__()
    inner_dim = int(2 * inner_dim / 3)
I think this is an unintended/backwards-breaking change, yes? If so, let's revert.
Edit: oh okay, it's not backwards-breaking. I see you've made the required changes in lumina_nextdit2d.py.
@@ -4192,6 +4192,102 @@ def __call__(
    return hidden_states


class Lumina2AttnProcessor2_0:
Let's move this attention processor implementation to the transformer_lumina2.py file, since all new models are now being added in single-file format.
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
I see that the rotary_emb being passed for both the query and key is the same. Not sure if I'm missing a case where the two would have to be different in the implementation. If they are not supposed to be different, let's use a single image_rotary_emb parameter here.
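i.e. something along these lines (a signature sketch under the assumption that the two values are always identical):

from typing import Optional

import torch


class Lumina2AttnProcessor2_0:
    # Sketch: one rotary embedding argument, applied to both query and key.
    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # ... project q/k/v, apply image_rotary_emb to both query and key, run SDPA ...
        ...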
if input_ndim == 4:
    batch_size, channel, height, width = hidden_states.shape
    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
I think we're passing ndim=3 tensors here, so we can simplify this code and remove it.
@@ -1766,6 +1766,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
    return conditioning


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
Let's move this module to the transformer_lumina2.py file as well.
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)

assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
Let's do the assertion in the RoPE layer
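For example, the check could live in the RoPE module's constructor (a sketch; the actual class name and arguments in the PR may differ):

import torch.nn as nn


class Lumina2PosEmbed(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        # 2D RoPE splits the head dim across axes, so it must be divisible by 4.
        if dim % 4 != 0:
            raise ValueError(f"2D RoPE requires the head dim to be divisible by 4, but got {dim}.")
        self.dim = dim
        self.theta = theta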
for layer in self.layers:
    hidden_states = layer(
        hidden_states,
        mask,
        joint_rotary_emb,
        temb=temb,
    )
Needs gradient checkpointing to be enabled, similar to:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
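i.e. roughly like this (a sketch following the pattern in other diffusers transformers; it assumes self.gradient_checkpointing is toggled by enable_gradient_checkpointing and that the helper accepts positional args):

# hypothetical sketch of the layer loop with gradient checkpointing
for layer in self.layers:
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        hidden_states = self._gradient_checkpointing_func(
            layer, hidden_states, mask, joint_rotary_emb, temb
        )
    else:
        hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb=temb)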
>>> from diffusers import Lumina2Text2ImgPipeline

>>> pipe = Lumina2Text2ImgPipeline.from_pretrained(
...     "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
Just marking this as a TODO for updating the checkpoint id.
""" | ||
|
||
|
||
def calculate_shift( |
Let's add a # Copied from comment here.
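For example, if the helper matches the Flux implementation verbatim, the annotation would look like this (the source path should be double-checked against the file it was actually copied from):

# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu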
prompt_max_length = prompt_embeds.shape[1]
negative_text_inputs = self.tokenizer(
    negative_prompt,
    padding="max_length",
    max_length=prompt_max_length,
    truncation=True,
    return_tensors="pt",
)
Could _get_gemma_prompt_embeds be leveraged here as well to compute the negative embeds, or not really?
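i.e. something like the sketch below inside encode_prompt (argument names are assumptions; the helper's exact signature in this PR may differ):

# hypothetical sketch: reuse the positive-prompt helper for the negative prompt
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
    prompt=negative_prompt,
    device=device,
)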
self,
x: list[torch.Tensor] | torch.Tensor,
cap_mask: torch.Tensor,
t: torch.Tensor,
Also, I think some of YiYi's comments (such as renaming this to temb) and her questions are yet to be addressed. In general, I think you could change any occurrences of x to hidden_states, cap_mask to attention_mask, and follow the similar naming conventions used in the latest diffusers implementations for consistency.
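Concretely, the signature quoted above could end up looking something like this after the renames (a sketch of the naming only, not the final argument list):

from typing import List, Union

import torch


def patchify_and_embed(
    self,
    hidden_states: Union[List[torch.Tensor], torch.Tensor],  # was: x
    attention_mask: torch.Tensor,                             # was: cap_mask
    temb: torch.Tensor,                                       # was: t
):
    ...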
What does this PR do?
This PR adds the official Lumina-Image 2.0 to diffusers. Lumina-Image 2.0 is the latest model in the Lumina family and will be released very soon (https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0). It is a 2B-parameter Diffusion Transformer that significantly improves instruction following and generates higher-quality, more diverse images. Our paper will be released soon, and we have finished the diffusers pipeline for Lumina-Image 2.0.
Core library: