Add support for lumina2 #10642

Open · wants to merge 8 commits into base: main

Conversation

@zhuole1025 (Contributor)

What does this PR do?

This PR adds the official Lumina-Image 2.0 to diffusers. Lumina-Image 2.0 is the latest model in the Lumina family and will be released very soon (https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0). It is a 2B-parameter Diffusion Transformer that significantly improves instruction following and generates higher-quality, more diverse images. Our paper will be released soon, and we have finished the diffusers pipeline for Lumina-Image 2.0.

Core library:

cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

for layer in self.context_refiner:
@yiyixuxu (Collaborator) commented on Jan 29, 2025

I think context_refiner and noise_refiner should be moved inside forward for better readability.
Does it make sense to wrap the code that creates freqs_cis inside a class like this?

class LTXVideoRotaryPosEmbed(nn.Module):
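For reference, a minimal sketch of what such a wrapper could look like for Lumina 2. The class name, argument names, and the (batch, seq_len, dim) complex freqs_cis layout are assumptions, not the final implementation; the slicing mirrors the snippet quoted above.

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class Lumina2RotaryPosEmbed(nn.Module):
    """Hypothetical module that splits a joint freqs_cis into caption and image parts."""

    def forward(
        self,
        freqs_cis: torch.Tensor,          # (batch, max_seq_len, dim), caption tokens first
        l_effective_cap_len: List[int],   # per-sample caption token lengths
        l_effective_img_len: List[int],   # per-sample image token counts
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, _, dim = freqs_cis.shape
        # new_zeros keeps the dtype/device of freqs_cis (including complex dtypes)
        cap_freqs_cis = freqs_cis.new_zeros(batch_size, max(l_effective_cap_len), dim)
        img_freqs_cis = freqs_cis.new_zeros(batch_size, max(l_effective_img_len), dim)
        for i, (cap_len, img_len) in enumerate(zip(l_effective_cap_len, l_effective_img_len)):
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
        return cap_freqs_cis, img_freqs_cis
```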

@zhuole1025 (Contributor, author)

Thanks for the suggestion! I have moved the refiners into forward() for better readability. As for freqs_cis, we already have a Lumina2PosEmbed class.

image_rotary_emb = image_rotary_emb.to(hidden_states.device)

for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
@yiyixuxu (Collaborator)

cap_freqs_cis is not returned from the method above, no?

"""

temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states)
cap_feats, padded_img_embed, img_size, l_effective_cap_len, l_effective_img_len, image_rotary_emb, max_seq_len, cap_mask, padded_img_mask = self.patchify_and_embed(hidden_states, encoder_hidden_states, encoder_mask, temb)
@yiyixuxu (Collaborator)

Looking at patchify_and_embed, I think cap_feats and cap_mask are not changed, and it's supposed to return cap_freqs_cis, no? But it was not included in the list of outputs that it returns.

cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
@yiyixuxu (Collaborator)

what is t here? is it temb?

Comment on lines 289 to 301
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device

l_effective_cap_len = cap_mask.sum(dim=1).tolist()
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
@yiyixuxu (Collaborator)

I think these length variables can be created inside forward, no? They are used throughout the forward pass.


return padded_full_embed, mask

def forward(
@yiyixuxu (Collaborator)

I'm not exactly sure I understand the code entirely here (I think the code as it is cannot run, no?), so let me know if my suggestion does not make sense.

Does it make sense to have a Lumina2PatchEmbed wrap all the logic we have here to create image_rotary_emb and cap_freqs_cis (instead of the simple PosEmb that only returns the freqs_cis based on position ids)?

So the forward pass I have in mind is something like this:

temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states)

image_sizes = ...
list_text_seq_length = ...
list_num_patches = ...
max_seq_len = ...

image_rotary_emb, text_rotary_emb = self.patch_embed(...)
for layer in self.context_refiner:
    ...

padded_img_embed, padded_img_mask = self.patchify_image(...)

for layer in self.noise_refiner:
    ...

hidden_states, mask = ...

for layer in self.layers:
    ...

@zhuole1025 (Contributor, author)

So sorry for the confusion! I have now pushed the latest update, which includes a Lumina2PatchEmbed following your suggestion. I have tested that the code correctly generates images.

@rodjjo commented on Feb 3, 2025

I know we have support for pipe.enable_model_cpu_offload.
Do you intend to add pipe.enable_sequential_cpu_offload?

I tried it out; it is not currently supported.
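For context, both offloading methods come from DiffusionPipeline. A rough usage sketch of the difference, assuming the final weights land at Alpha-VLLM/Lumina-Image-2.0 (the repo linked in the description) and the pipeline keeps the Lumina2Text2ImgPipeline name; prompt and step count are only illustrative:

```python
import torch

from diffusers import Lumina2Text2ImgPipeline

# Checkpoint id is an assumption based on the repo linked in the PR description.
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)

# Works today: offloads whole sub-models (text encoder, transformer, VAE) one at a time.
pipe.enable_model_cpu_offload()

# Requested here: leaf-module-level offloading for a much lower VRAM floor, at the cost
# of speed. Reported above as not working for this pipeline yet.
# pipe.enable_sequential_cpu_offload()

image = pipe("a cute corgi wearing sunglasses", num_inference_steps=30).images[0]
image.save("lumina2.png")
```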

@a-r-r-o-w (Member) left a comment

Thanks for the PR @zhuole1025! It looks great and is almost ready to merge.

We need to add tests and docs, apart from addressing the comments/questions here. For both, this PR would serve as a good example of which files need to be updated. In tests, a modeling test (test_models_transformer_lumina2.py) and a pipeline test (test_lumina2.py) will be needed. More than happy to help make any of the required changes to move the PR to completion 🤗
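If it helps, here is a rough skeleton of what the pipeline test could look like; the file path, class name, and dummy-component construction are illustrative and only loosely follow the existing pipeline fast tests in the repo:

```python
# tests/pipelines/lumina2/test_lumina2.py -- skeleton only
import unittest

import torch

from diffusers import Lumina2Text2ImgPipeline


class Lumina2PipelineFastTests(unittest.TestCase):
    def get_dummy_components(self):
        # Return tiny randomly-initialized components (transformer, VAE, Gemma text
        # encoder, tokenizer, scheduler), mirroring other pipeline tests in the repo.
        raise NotImplementedError

    @unittest.skip("skeleton only: enable once get_dummy_components is filled in")
    def test_inference(self):
        pipe = Lumina2Text2ImgPipeline(**self.get_dummy_components())
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator("cpu").manual_seed(0)
        image = pipe(
            "A photo of a cat",
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images[0]
        assert image.ndim == 3  # (height, width, channels) numpy output
```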

@@ -612,7 +612,6 @@ def __init__(
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
inner_dim = int(2 * inner_dim / 3)
@a-r-r-o-w (Member)

I think this is an unintended/backwards-breaking change, yes? If so, let's revert.

Edit: oh okay, not backwards-breaking. I see you've made the required changes in lumina_nextdit2d.py.

@@ -4192,6 +4192,102 @@ def __call__(
return hidden_states


class Lumina2AttnProcessor2_0:
@a-r-r-o-w (Member)

Let's move this attention processor implementation to the transformer_lumina2.py file, since all new models are now being added in single-file format

Comment on lines +4211 to +4212
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
@a-r-r-o-w (Member)

I see that the rotary_emb being passed for the query and the key is the same. Not sure if I'm missing a case where the two would have to be different in the implementation. If they aren't supposed to differ, let's use a single image_rotary_emb parameter here.

Comment on lines +4219 to +4221
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
@a-r-r-o-w (Member)

Suggested change
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

I think we're passing ndim=3 tensors here, so we can simplify and remove this code.

@@ -1766,6 +1766,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
return conditioning


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
@a-r-r-o-w (Member)

Let's move this module to the transformer_lumina2.py file as well

Comment on lines +379 to +381
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)

assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
@a-r-r-o-w (Member)

Suggested change
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"

Let's do the assertion in the RoPE layer
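For example, a small sketch of what that could look like inside the RoPE module's __init__ (class and argument names are illustrative), raising a ValueError rather than using a bare assert, as diffusers generally prefers:

```python
import torch.nn as nn


class Lumina2RotaryPosEmbed(nn.Module):
    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        head_dim = hidden_size // num_attention_heads
        # Validate here so every user of the module gets the check, and so it is not
        # stripped out when Python runs with -O (unlike `assert`).
        if head_dim % 4 != 0:
            raise ValueError(f"2D RoPE requires the head dim to be divisible by 4, got {head_dim}.")
        self.head_dim = head_dim
```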

Comment on lines +422 to +428
for layer in self.layers:
hidden_states = layer(
hidden_states,
mask,
joint_rotary_emb,
temb=temb,
)
@a-r-r-o-w (Member)

Needs gradient checkpointing to be enabled, similar to:

hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
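Something like the following for the quoted loop; this is only a sketch of the pattern shown in the snippet above, and it assumes self.gradient_checkpointing and self._gradient_checkpointing_func are wired up by ModelMixin as in other recent transformer models:

```python
for layer in self.layers:
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        # Recompute activations in the backward pass to trade compute for memory.
        hidden_states = self._gradient_checkpointing_func(
            layer,
            hidden_states,
            mask,
            joint_rotary_emb,
            temb,
        )
    else:
        hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb=temb)
```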

>>> from diffusers import Lumina2Text2ImgPipeline

>>> pipe = Lumina2Text2ImgPipeline.from_pretrained(
... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
@a-r-r-o-w (Member)

Just marking as a TODO for updating checkpoint id
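For reference, the docstring example would then look roughly like this, assuming the released weights land at Alpha-VLLM/Lumina-Image-2.0 (the repo linked in the PR description); the prompt and sampling settings are only illustrative:

```python
>>> import torch
>>> from diffusers import Lumina2Text2ImgPipeline

>>> pipe = Lumina2Text2ImgPipeline.from_pretrained(
...     "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")

>>> prompt = "A photo of an astronaut riding a horse on the beach"
>>> image = pipe(prompt, num_inference_steps=30, guidance_scale=4.0).images[0]
>>> image.save("lumina2.png")
```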

"""


def calculate_shift(
@a-r-r-o-w (Member)

Let's add a # Copied from comment here.
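For example, if this helper is the same as the Flux one, the annotation would read as below (the body here reproduces the Flux version; worth double-checking the defaults match what this pipeline needs):

```python
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
```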

Comment on lines +352 to +359
prompt_max_length = prompt_embeds.shape[1]
negative_text_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=prompt_max_length,
truncation=True,
return_tensors="pt",
)
@a-r-r-o-w (Member)

Could _get_gemma_prompt_embeds be leveraged here as well to compute negative embeds, or not really?
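A heavily hedged sketch of what reusing it might look like; the helper's actual signature in this PR may differ, so the argument names below are guesses:

```python
# Sketch only: reuse the same Gemma encoding path for the negative prompt instead of
# re-tokenizing by hand. Argument names are hypothetical.
if do_classifier_free_guidance and negative_prompt_embeds is None:
    negative_prompt = negative_prompt or ""
    negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
        prompt=negative_prompt,
        device=device,
        max_sequence_length=prompt_embeds.shape[1],  # match the positive prompt length
    )
```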

self,
x: list[torch.Tensor] | torch.Tensor,
cap_mask: torch.Tensor,
t: torch.Tensor,
@a-r-r-o-w (Member)

Also, I think some of YiYi's comments and questions are yet to be addressed (such as renaming this to temb). In general, I think you could change any occurrences of x to hidden_states and cap_mask to attention_mask, and follow similar naming conventions to the latest diffusers implementations for consistency.

@a-r-r-o-w added the close-to-merge and roadmap (Add to current release roadmap) labels on Feb 4, 2025
Labels: close-to-merge, roadmap (Add to current release roadmap)
Projects: Status: In Progress
5 participants