Fix Mochi Quality Issues #10033
base: main
Conversation
Thank you for all the fixes Dhruv!
Very grateful to @YanzuoLu and @Ednaordinary for their continuous help with testing different things out - thank you!
logger = logging.get_logger(__name__)  # pylint: disable=invalid-n


class MochiModulatedRMSNorm(nn.Module):
Love seeing the single file format for modeling!
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
logger = logging.get_logger(__name__)  # pylint: disable=invalid-n
Suggested change:
logger = logging.get_logger(__name__)  # pylint: disable=invalid-n
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
I think we could do one cast for the hidden states instead of two here?
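Something along these lines would be the single-cast variant being suggested (just a sketch, reusing the names from the snippet above):

hidden_states = hidden_states.to(torch.float32)  # upcast once, reuse below
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)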
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
Same here, a bit neater to cast hidden_states before the following statements (not really a problem though)
input_dtype = x.dtype

# convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))

return x.to(input_dtype)
I think this pattern is very common for some kinds of layers. Do you think we could work on a refactor in the future where we decorate the forward methods with something like @upcast_to_fp32, so that the type conversions occur outside of the forward and they look like a clean mathematical input-to-output mapping? This way, it could also be disabled with a global context manager if upcasting is not required for certain models, but be enabled by default.
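A rough sketch of what such a decorator could look like (the decorator name, the global flag, and the assumption that the first positional argument is the tensor to upcast are all hypothetical):

import functools

import torch

_UPCAST_TO_FP32 = True  # could be flipped by a global context manager


def upcast_to_fp32(forward):
    @functools.wraps(forward)
    def wrapper(self, hidden_states, *args, **kwargs):
        if not _UPCAST_TO_FP32:
            return forward(self, hidden_states, *args, **kwargs)
        input_dtype = hidden_states.dtype
        # run the wrapped forward in float32 and cast the result back afterwards
        output = forward(self, hidden_states.to(torch.float32), *args, **kwargs)
        return output.to(input_dtype)

    return wrapper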
return hidden_states, gate_msa, scale_mlp, gate_mlp


class MochiAttention(nn.Module):
This is interesting to see! I did not know we wanted to move away from the central Attention class as well -- I only thought we would be breaking up BasicTransformerBlock. Super cool and clean!
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
I think I'm probably missing something here, but can you point me to where we handle the downcast for this upcast? Because otherwise all successive computations would be in float32, no?
Downcast here
hidden_states = hidden_states.to(hidden_states_dtype)
for idx in range(batch_size):
    mask = attention_mask[idx][None, :]
    valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

    valid_encoder_query = torch.index_select(encoder_query[idx][None, :], 2, valid_prompt_token_indices)
    valid_encoder_key = torch.index_select(encoder_key[idx][None, :], 2, valid_prompt_token_indices)
    valid_encoder_value = torch.index_select(encoder_value[idx][None, :], 2, valid_prompt_token_indices)

    valid_query = torch.cat([query[idx][None, :], valid_encoder_query], dim=2)
    valid_key = torch.cat([key[idx][None, :], valid_encoder_key], dim=2)
    valid_value = torch.cat([value[idx][None, :], valid_encoder_value], dim=2)

    attn_output = F.scaled_dot_product_attention(
        valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
    )
    valid_sequence_length = attn_output.size(2)
    attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
    attn_outputs.append(attn_output)
Awesome catch and fix!
I think this might not play well with data parallel implementations due to the loop. I will profile this later and we can try a different implementation in the future.
For data parallelism, we're replicating the model on each worker, though.
Maybe I used the wrong terminology. What I meant was parallelizing across the batch dimension.
Torch 2.5 has nested tensors as a prototype feature that allows variable sequence lengths in a batch
https://pytorch.org/docs/stable/nested.html
We could do a version check and use that here for parallelizing across the batch dimension. The docs say the API is subject to change, so I avoided using it.
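A rough sketch of the idea, gated on a version check (the jagged layout is a prototype API, so the exact calls should be verified against the installed release; shapes here are made up for illustration). Per the docs, nested inputs are also meant to work with scaled_dot_product_attention, which is what would eventually replace the per-sample loop above:

import torch

# three samples with different numbers of valid prompt tokens, as in the loop above
seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]

if torch.__version__ >= "2.5":  # crude version gate for the prototype API
    nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
    # one "batch" holding variable-length sequences, no padding needed
    print(nt.shape)                         # (3, j1, 8) -- the middle dimension is ragged
    print([t.shape for t in nt.unbind()])   # the original per-sample shapes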
@DN6 Hey, can you give an example of code to run your fixed branch? I'm really curious about the quality changes with those precision fixes.
@DN6 can we get this merged now? :) The tests need to be fixed.
ok i went over the PR, so it turns out there are a lot of refactorings (in addition to the fix)
i left some comments; I think it is not ready to merge yet and there is some cleanup to be done :)
love that we're starting to take action on having model-specific attention/attention classes now! and +1 on @a-r-r-o-w's comments here to think about a better way to manage precision (can be future PRs, does not have to be done here) https://github.com/huggingface/diffusers/pull/10033/files#r1860070716
also, do we know what actually causes the issue? all of them together?
return hidden_states, gate_msa, scale_mlp, gate_mlp


class MochiAttention(nn.Module):
should have a basic attention class (e.g. AttentionMixin) everything inherits from, no? so that some methods are available, set_processor, get_processor etc?
Agreed. Should I add in this PR or follow up? I don't think Mochi needs those methods since it just has a single processor.
A follow-up PR is ok! But let's have a rough plan for the follow-up PR before merge.
We only have one attention processor for Mochi, but these methods are there for people to use custom attention processors. Also, methods like attn_processors depend on this, and they are useful:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
Also, let's make sure we don't have something we rely on with this kind of logic:
if isinstance(module, Attention):
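For reference, that registry walk looks roughly like the following paraphrase (not the exact diffusers source); whatever check it uses is what a standalone MochiAttention would need to satisfy:

import torch


def collect_attn_processors(model: torch.nn.Module) -> dict:
    # paraphrased illustration: find every submodule that exposes a processor
    processors = {}
    for name, module in model.named_modules():
        # an isinstance(module, Attention) check here would silently skip a
        # standalone MochiAttention class; a duck-typed check would not
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()
    return processors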
out_dim: int = None,
out_context_dim: int = None,
out_bias: bool = True,
context_pre_only: bool = False,
probably can clean up the arguments a bit now? e.g. remove those not used by Mochi
I think these arguments are all used in the Mochi Attention layer.
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)

if scale is not None:
I looked at how this layer is used; I don't think scale should be passed down here. This should just be a regular RMSNorm; the scale part should be part of another operation (e.g. a variation of AdaLayerNorm).
The way it's used is a bit different in Mochi. The outputs from the linear layer in MochiRMSNormZero (gate_msa, gate_mlp etc) are passed to the modulation after attention.
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
There is also a difference in which parts of the modulation are upcast. The parts that have tanh applied to them remain in BF16, while those that don't are upcast to FP32. The hidden states are always upcast to FP32 before norming and multiplying by the modulation, and then the output is cast back to BF16.
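A toy, self-contained sketch of that dtype flow (shapes and names are illustrative, not the actual block code):

import torch

hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)
scale_mlp = torch.randn(1, 8, dtype=torch.bfloat16)
gate_msa = torch.randn(1, 8, dtype=torch.bfloat16)
eps = 1e-6

# non-tanh modulation is upcast to FP32; tanh-gated modulation stays in BF16
scale = (1 + scale_mlp.unsqueeze(1)).to(torch.float32)
gate = torch.tanh(gate_msa).unsqueeze(1)

# hidden states are upcast before norming and modulating, then cast back to BF16
hs = hidden_states.to(torch.float32)
hs = hs * torch.rsqrt(hs.pow(2).mean(-1, keepdim=True) + eps)
modulated = (hs * scale).to(hidden_states.dtype)

# the BF16 tanh gate is applied in the lower precision on the residual path
out = hidden_states + modulated * gate
print(out.dtype)  # torch.bfloat16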
for this comment
The way it's used is a bit different in Mochi. The outputs from the linear layer in MochiRMSNormZero (gate_msa, gate_mlp etc) are passed to the modulation after attention.
I don't think it's different; I think this is the case for all our DiT models, maybe with slight differences in implementation from model to model, but we always apply the "modulate" part explicitly inside the transformer blocks. We should do the same for Mochi too. See some examples here:
flux:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
sd3:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
cogvideo:
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
You defined MochiModulatedRMSNorm from scratch, but it is actually just MochiRMSNorm with a scale, and can be written as:
class MochiModulatedRMSNorm(nn.Module):
    def __init__(self, eps: float):
        super().__init__()
        self.norm = MochiRMSNorm(dim=None, eps=eps, elementwise_affine=False)

    def forward(self, hidden_states, scale=None):
        hidden_states = self.norm(hidden_states)
        if scale is not None:
            hidden_states = hidden_states * scale
        return hidden_states
So for layers that are currently using MochiModulatedRMSNorm, we can easily rewrite them with RMSNorm too.
So in this example the output of the self.norm would be torch.bfloat16.
When applying modulation the hidden state has to be in FP32. Just upcasting the output of the norm unfortunately isn't enough to reproduce the original result.
Downcasting to BF16 only happens after modulation
https://github.com/genmoai/mochi/blob/f3a800aea5862b4af13e66ff77eea1967c8c3a7f/src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py#L15
And in the case of tanh modulation, right before adding to the residual
https://github.com/genmoai/mochi/blob/f3a800aea5862b4af13e66ff77eea1967c8c3a7f/src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py#L6
Oh. I see what you mean. Upcast hidden states in MochiModulatedNorm and the output is equivalent. Got it.
return hidden_states


class MochiRMSNorm(nn.Module):
This looks identical to RMSNorm, no? Did I miss something?
In our RMSNorm the hidden_states are not upcast for the entire operation:
diffusers/src/diffusers/models/normalization.py (lines 531 to 533 in 6b288ec)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
We only upcast when computing the variance. For Mochi the hidden_states are in FP32 throughout and only downcast at the end.
Changing RMSNorm to upcast throughout might affect other models, so I added a new class here.
Why can't we upcast the input instead?
In fact, for this operation
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if the variance is in float32 and hidden_states is in a lower precision, PyTorch will automatically upcast to float32 anyway.
A little demo script:
import torch

# Create a float16 tensor (simulating hidden_states)
hidden_states = torch.ones(2, 3, dtype=torch.float16)
print("Initial hidden_states dtype:", hidden_states.dtype)  # float16

# Calculate variance in float32 (like the code)
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
print("Variance dtype:", variance.dtype)  # float32

# Perform the multiplication (this is the line we're testing)
eps = 1e-5
result = hidden_states * torch.rsqrt(variance + eps)
print("Result dtype:", result.dtype)  # float32

# Print all values to verify
print("\nValues:")
print("hidden_states:", hidden_states)
print("variance:", variance)
print("result:", result)

Output:
Initial hidden_states dtype: torch.float16
Variance dtype: torch.float32
Result dtype: torch.float32

Values:
hidden_states: tensor([[1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float16)
variance: tensor([[1.],
        [1.]])
result: tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]])
So I really don't think we need a new class here.
Oh yes, the RMSNorm in the Attentions can be replaced.
# convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
I commented on MochiModulatedRMSNorm - we should apply scale here, instead of passing it to MochiModulatedRMSNorm.
So the reason to pass this in is because of how upcasts are handled in Mochi. The input hidden state is always upcast, but the scaling parameter isn't always upcast. And the scaling/modulation happens between the upcast hidden state and scale parameter.
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)

hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32)))
Same comment here, scale should be applied here.
@@ -202,7 +478,10 @@ def _get_positions(
    return positions

def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
    with torch.autocast(freqs.device.type, enabled=False):
What does this do here? Are we disabling autocast here in case it is enabled?
OK here if this is what's causing the quality issue - but we should come up with something better :)
Hmm, actually I think with all the manual casts we have in place, just setting torch_dtype=torch.bfloat16 should allow us to effectively run Mochi as if autocast is enabled. We could just remove this autocast context manager. Will test.
The reason I put this here was that, under the autocast context, einsum would always return bfloat16 (which causes numerical differences in the output).
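A small self-contained demo of that behaviour (the shapes are made up; the bfloat16 result under autocast is what was observed on CUDA, per the comment above):

import torch

pos = torch.randn(8, 16)       # float32 positions
freqs = torch.randn(16, 2, 4)  # float32 frequency table

device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.autocast(device_type, dtype=torch.bfloat16):
    under_autocast = torch.einsum("nd,dhf->nhf", pos, freqs.float())
    with torch.autocast(device_type, enabled=False):
        autocast_disabled = torch.einsum("nd,dhf->nhf", pos, freqs.float())

print(under_autocast.dtype)     # bfloat16 under autocast (observed on CUDA), hence the drift
print(autocast_disabled.dtype)  # torch.float32 -- the RoPE frequencies keep full precision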
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
We're seeing some quality issues with Mochi due to missing upcasts and differences between how attention is handled in the original repo.
This PR:
I'll update the docs PR: #9934 with a guide on how to reproduce the original repo results exactly once this PR is merged.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.