
Fix Mochi Quality Issues #10033

Open · wants to merge 41 commits into base: main
Conversation

@DN6 (Collaborator) commented Nov 27, 2024

What does this PR do?

We're seeing some quality issues with Mochi due to missing upcasts and differences from how attention is handled in the original repo.

This PR:

  1. Matches the transformer implementation 1:1 so that norms are upcast and run in the same precision as the original repo.
  2. Changes the MochiAttnProcessor to match the original approach of dropping padding tokens.
  3. Runs the CFG and sampling steps in FP32 (a rough sketch follows below).
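
A rough sketch of what item 3 refers to (hypothetical and simplified, not the pipeline's exact code): the classifier-free guidance combination is done in FP32 before casting back to the working dtype.

import torch

# Hedged sketch of item 3 (not the pipeline's exact code): do the CFG combination
# (and, in the pipeline, the scheduler step) in float32, then cast back to the
# working dtype. Shapes and values here are toys.
guidance_scale = 4.5
latents_dtype = torch.bfloat16

noise_pred = torch.randn(2, 4, 8, 8, dtype=latents_dtype)  # [uncond, text] stacked on dim 0
noise_pred = noise_pred.to(torch.float32)
noise_uncond, noise_text = noise_pred.chunk(2)
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

# ... the scheduler step would also run on the float32 tensors here ...
noise_pred = noise_pred.to(latents_dtype)
print(noise_pred.dtype)  # torch.bfloat16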

Once this PR is merged, I'll update the docs PR (#9934) with a guide on how to exactly reproduce the results of the original repo.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 DN6 requested a review from a-r-r-o-w November 27, 2024 07:05
@a-r-r-o-w (Member) left a comment

Thank you for all the fixes Dhruv!

Very grateful to @YanzuoLu and @Ednaordinary for their continuous help with testing different things out - thank you!

logger = logging.get_logger(__name__) # pylint: disable=invalid-n


class MochiModulatedRMSNorm(nn.Module):
Member

Love seeing the single file format for modeling!



logger = logging.get_logger(__name__) # pylint: disable=invalid-name
logger = logging.get_logger(__name__) # pylint: disable=invalid-n
Member

Suggested change
logger = logging.get_logger(__name__) # pylint: disable=invalid-n
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Comment on lines 46 to 47
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
Member

I think we could do one cast for the hidden states instead of two here?

Comment on lines 70 to 71
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
Member

Same here, it's a bit neater to cast hidden_states before the following statements (not really a problem though)

Comment on lines +104 to +110
input_dtype = x.dtype

# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))

return x.to(input_dtype)
Member

I think this pattern is very common for some kinds of layers. Do you think we could work on a refactor in the future where we decorate the forward methods with something like @upcast_to_fp32 so that the type conversions occur outside of the forward and they look like clean mathematical input-to-output mapping? This way, it could also be disabled with a global context manager if upcasting is not required for certain models, but be enabled by default.
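
A rough sketch of the kind of decorator being suggested (hypothetical; upcast_to_fp32 is not an existing diffusers utility):

import functools

import torch

# Hypothetical sketch of the suggested decorator, not an existing diffusers utility.
# It assumes the wrapped forward takes tensors and returns a single tensor.
_UPCAST_ENABLED = True  # a global context manager could toggle this


def upcast_to_fp32(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        if not _UPCAST_ENABLED:
            return forward(self, *args, **kwargs)
        # Remember the incoming dtype, run the forward in float32, cast the result back.
        input_dtype = next(a.dtype for a in args if torch.is_tensor(a))
        args = tuple(a.to(torch.float32) if torch.is_tensor(a) else a for a in args)
        kwargs = {k: v.to(torch.float32) if torch.is_tensor(v) else v for k, v in kwargs.items()}
        return forward(self, *args, **kwargs).to(input_dtype)

    return wrapper


class ToyNorm(torch.nn.Module):
    @upcast_to_fp32
    def forward(self, hidden_states):
        # The body now reads like a clean FP32 input-to-output mapping.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return hidden_states * torch.rsqrt(variance + 1e-6)


print(ToyNorm()(torch.randn(2, 4, dtype=torch.bfloat16)).dtype)  # torch.bfloat16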

return hidden_states, gate_msa, scale_mlp, gate_mlp


class MochiAttention(nn.Module):
Member

This is interesting to see! I did not know we wanted to move from the central Attention class as well -- only thought we would be breaking up BasicTransformerBlock. Super cool and clean!

hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
Member

I think I'm probably missing something here, but can you point me to where we handle the downcast for this upcast? Because otherwise all successive computations would be in float32, no?

Collaborator Author

Downcast here

hidden_states = hidden_states.to(hidden_states_dtype)

Comment on lines 276 to 293
for idx in range(batch_size):
    mask = attention_mask[idx][None, :]
    valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

    valid_encoder_query = torch.index_select(encoder_query[idx][None, :], 2, valid_prompt_token_indices)
    valid_encoder_key = torch.index_select(encoder_key[idx][None, :], 2, valid_prompt_token_indices)
    valid_encoder_value = torch.index_select(encoder_value[idx][None, :], 2, valid_prompt_token_indices)

    valid_query = torch.cat([query[idx][None, :], valid_encoder_query], dim=2)
    valid_key = torch.cat([key[idx][None, :], valid_encoder_key], dim=2)
    valid_value = torch.cat([value[idx][None, :], valid_encoder_value], dim=2)

    attn_output = F.scaled_dot_product_attention(
        valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
    )
    valid_sequence_length = attn_output.size(2)
    attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
    attn_outputs.append(attn_output)
Member

Awesome catch and fix!

I think this might not play well with data-parallel implementations due to the loop. I will profile this and we can try a different implementation in the future.

Member

For data parallelism, we're replicating the model on each worker, though.

Member

Maybe I used the wrong terminology; what I meant was parallelizing across the batch dimension.

@DN6 (Collaborator Author) Nov 29, 2024

Torch 2.5 has nested tensors as a prototype feature that allows variable sequence lengths in a batch:
https://pytorch.org/docs/stable/nested.html

We could do a version check and use that here to parallelize across the batch dimension. The docs say the API is subject to change, so I avoided using it.
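
A rough sketch of what that could look like (hedged; the nested tensor API is a prototype, and SDPA support for nested inputs may depend on the torch version, device, and backend):

import torch
import torch.nn.functional as F

# Hedged sketch, not the PR's implementation: pack per-sample sequences of different
# lengths into nested tensors and call SDPA once instead of looping over the batch.
num_heads, head_dim = 2, 8
seq_lens = [5, 3]  # valid (unpadded) token counts for two samples

query = torch.nested.nested_tensor([torch.randn(num_heads, s, head_dim) for s in seq_lens])
key = torch.nested.nested_tensor([torch.randn(num_heads, s, head_dim) for s in seq_lens])
value = torch.nested.nested_tensor([torch.randn(num_heads, s, head_dim) for s in seq_lens])

attn_output = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print([t.shape for t in attn_output.unbind()])  # each sample keeps its own sequence length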

@jzhang38

@DN6 Hey, can you give example code to run your fixed branch? I am really curious about the quality changes with those precision fixes.

@yiyixuxu (Collaborator)

@DN6 can we get this merged now? :) The tests need to be fixed.

@yiyixuxu (Collaborator) left a comment

OK, I went over the PR, and it turns out there are a lot of refactorings (in addition to the fix). I left some comments; I think it is not ready to merge yet and there is some cleanup to be done :)

Love that we are starting to take action on having model-specific attention classes now! And +1 on @a-r-r-o-w's comment here about thinking of a better way to manage precision (can be future PRs, does not have to be done here): https://github.com/huggingface/diffusers/pull/10033/files#r1860070716

Also, do we know what actually caused the issue? All of them together?

return hidden_states, gate_msa, scale_mlp, gate_mlp


class MochiAttention(nn.Module):
Collaborator

Should this have a basic attention class (e.g. AttentionMixin) that everything inherits from, no? So that some methods are available: set_processor, get_processor, etc.

Collaborator Author

Agreed. Should I add it in this PR or in a follow-up? I don't think Mochi needs those methods since it just has a single processor.

Collaborator

A follow-up PR is OK! But let's have a rough plan for the follow-up PR before merging. We only have one attention processor for Mochi, but these methods are for people to use custom attention processors. Also, methods like attn_processors depend on this, and they are useful:

def attn_processors(self) -> Dict[str, AttentionProcessor]:

Also, let's make sure we don't have something we need that relies on this kind of logic:

if isinstance(module, Attention):

out_dim: int = None,
out_context_dim: int = None,
out_bias: bool = True,
context_pre_only: bool = False,
Collaborator

We can probably clean up the arguments a bit now, e.g. remove the ones not used by Mochi?

@DN6 (Collaborator Author) Nov 29, 2024

I think these arguments are all used in the Mochi Attention layer.

variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)

if scale is not None:
Collaborator

I looked at how this layer is used; I don't think scale should be passed down here. This should just be a regular RMSNorm; the scale part should be part of another operation (e.g. a variation of AdaLayerNorm).

@DN6 (Collaborator Author) Nov 29, 2024

The way it's used is a bit different in Mochi. The outputs from the linear layer in MochiRMSNormZero (gate_msa, gate_mlp etc) are passed to the modulation after attention.

norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

There is also a difference in which parts of the modulation are upcast. The parts that have tanh applied to them remain in BF16, while those that don't are upcast to FP32. The hidden states are always upcast to FP32 before norming and multiplying by the modulation, and then the output is cast back to BF16.
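
A simplified sketch of the dtype pattern described above (hypothetical, not the PR's exact code):

import torch

# Simplified sketch of the dtype pattern described above, not the PR's exact code.
hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)
scale_msa = torch.randn(1, 8, dtype=torch.bfloat16)
gate_msa = torch.randn(1, 8, dtype=torch.bfloat16)

# Scale path: hidden states and (1 + scale) are upcast to FP32 for the norm/multiply,
# then the result is cast back to BF16.
hs = hidden_states.to(torch.float32)
hs = hs * torch.rsqrt(hs.pow(2).mean(-1, keepdim=True) + 1e-6)
modulated = (hs * (1.0 + scale_msa.unsqueeze(1).to(torch.float32))).to(torch.bfloat16)

# Gate path: tanh(gate) stays in BF16 and is applied right before the residual add.
hidden_states = hidden_states + modulated * torch.tanh(gate_msa).unsqueeze(1)
print(hidden_states.dtype)  # torch.bfloat16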

Collaborator

for this comment

The way it's used is a bit different in Mochi. The outputs from the linear layer in MochiRMSNormZero (gate_msa, gate_mlp etc) are passed to the modulation after attention.

I don't think it's different; I think this is the case for all our DiT models, maybe with a slight difference in implementation from model to model, but we always apply the "modulate" part explicitly inside the transformer blocks, and we should do the same for Mochi too. See some examples here:
flux:

norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

sd3:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

cogvideo
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

MochiModulatedRMSNorm you defined from scratch, but it is actually just MochiRMSNorm with a scale, and can be written as:

class MochiModulatedRMSNorm(nn.Module):
    def __init__(self, eps: float):
        super().__init__()
        self.norm = MochiRMSNorm(dim=None, eps=eps, elementwise_affine=False)

    def forward(self, hidden_states, scale=None):
        hidden_states = self.norm(hidden_states)
        
        if scale is not None:
            hidden_states = hidden_states * scale

        return hidden_states

So for layers that are currently using MochiModulatedRMSNorm, we should be able to easily rewrite them with RMSNorm too.

Collaborator Author

So in this example, the output of self.norm would be torch.bfloat16.

When applying modulation, the hidden state has to be in FP32. Just upcasting the output of the norm unfortunately isn't enough to reproduce the original result.

Downcasting to BF16 only happens after modulation
https://github.com/genmoai/mochi/blob/f3a800aea5862b4af13e66ff77eea1967c8c3a7f/src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py#L15

And in the case of tanh modulation, right before adding to the residual
https://github.com/genmoai/mochi/blob/f3a800aea5862b4af13e66ff77eea1967c8c3a7f/src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py#L6

Collaborator Author

Oh, I see what you mean. Upcast the hidden states in MochiModulatedRMSNorm and the output is equivalent. Got it.

return hidden_states


class MochiRMSNorm(nn.Module):
Collaborator

This looks identical to RMSNorm, no? Did I miss something?

Collaborator Author

In our RMSNorm the hidden_states are not upcast for the entire operation:

input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

We only upcast when computing the variance. For Mochi the hidden_states are in FP32 throughout and only downcast at the end.

Changing RMSNorm to upcast throughout might affect other models, so I added a new class here.
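
For contrast, a minimal sketch of the upcast-throughout behaviour described here (simplified; the learnable weight is omitted):

import torch

def mochi_style_rms_norm(hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Sketch of the behaviour described above (not the exact class): keep the hidden
    # states in FP32 for the whole operation and only downcast at the very end.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return hidden_states.to(input_dtype)

print(mochi_style_rms_norm(torch.randn(2, 3, dtype=torch.bfloat16)).dtype)  # torch.bfloat16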

Collaborator

Why can't we upcast the input instead?

@yiyixuxu (Collaborator) Nov 29, 2024

In fact, for this operation

hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

if the variance is in float32 and hidden_states is in a lower precision, PyTorch will automatically upcast to float32 anyway.

A little demo script:

import torch

# Create a float16 tensor (simulating hidden_states)
hidden_states = torch.ones(2, 3, dtype=torch.float16)
print("Initial hidden_states dtype:", hidden_states.dtype)  # float16

# Calculate variance in float32 (like the code)
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
print("Variance dtype:", variance.dtype)  # float32

# Perform the multiplication (this is the line we're testing)
eps = 1e-5
result = hidden_states * torch.rsqrt(variance + eps)
print("Result dtype:", result.dtype)  # float32

# Print all values to verify
print("\nValues:")
print("hidden_states:", hidden_states)
print("variance:", variance)
print("result:", result)
Output:

Initial hidden_states dtype: torch.float16
Variance dtype: torch.float32
Result dtype: torch.float32

Values:
hidden_states: tensor([[1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float16)
variance: tensor([[1.],
        [1.]])
result: tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]])

Collaborator

So I really don't think we need a new class here.

Collaborator Author

Oh yes, the RMSNorm in the Attentions can be replaced.


# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
Collaborator

I commented on MochiModulatedRMSNorm - we should apply scale here, instead of passing it to MochiModulatedRMSNorm

Collaborator Author

So the reason for passing this in is how upcasts are handled in Mochi. The input hidden state is always upcast, but the scaling parameter isn't always upcast, and the scaling/modulation happens between the upcast hidden state and the scale parameter.

emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)

hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32)))
Collaborator

Same comment here; scale should be applied here.

@@ -202,7 +478,10 @@ def _get_positions(
return positions

def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
with torch.autocast(freqs.device.type, enabled=False):
Collaborator

What does this do here? Are we disabling autocast here in case it is enabled?
It's OK to keep here if this is what was causing the quality issue, but we should come up with something better :)

@DN6 (Collaborator Author) Nov 29, 2024

Hmm actually I think with all the manual casts we have in place, just setting torch_dtype=torch.bfloat16 should allow us to effectively run Mochi as if autocast is enabled. We could just remove this autocast context manager. Will test.

The reason I put this here is that under the autocast context, einsum would always return bfloat16 (which causes numerical differences in the output).
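
A small demo of the behaviour being described (assumes a CUDA device is available, where einsum is on autocast's low-precision op list):

import torch

# Demo of the behaviour described above; assumes a CUDA device is available.
pos = torch.randn(16, 3, device="cuda", dtype=torch.float32)
freqs = torch.randn(3, 4, 8, device="cuda", dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    out_autocast = torch.einsum("nd,dhf->nhf", pos, freqs)
    with torch.autocast("cuda", enabled=False):
        out_fp32 = torch.einsum("nd,dhf->nhf", pos.float(), freqs.float())

print(out_autocast.dtype)  # torch.bfloat16 -- autocast runs the einsum in low precision
print(out_fp32.dtype)      # torch.float32  -- disabling autocast preserves full precision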

@DN6 DN6 closed this Nov 28, 2024
@DN6 DN6 reopened this Nov 28, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
