
Add support for MiniMax's MiniMax-Text-01 #35831

Open · wants to merge 39 commits into base: main

Conversation

@geetu040 (Contributor) commented Jan 22, 2025

What does this PR do?

Fixes #35710

This PR adds MiniMaxAI's MiniMax-Text-01 model to Hugging Face Transformers.

  • MiniMax-Text-01 is a powerful language model with 456 billion total parameters, of which 45.9 billion are activated per token.
  • To better unlock the long context capabilities of the model, MiniMax-Text-01 adopts a hybrid architecture that combines Lightning Attention, Softmax Attention and Mixture-of-Experts (MoE).
  • MiniMax-Text-01's training context length is extended to 1 million tokens, and it can handle a context of up to 4 million tokens during inference.
  • On various academic benchmarks, MiniMax-Text-01 also demonstrates top-tier performance.

Relevant Links

CC: @MiniMax-AI-Dev

Before submitting

Who can review?

@ArthurZucker, @Rocketknight1

Change Log

  • Tokenizer: It uses the existing GPT2Tokenizer
  • Config: Matches the MixtralConfig with a few additional parameters:
    • residual_post_norm, attn_type_list
    • layernorm_attention_alpha, layernorm_lightning_attention_alpha, layernorm_mlp_alpha
    • layernorm_attention_beta, layernorm_lightning_attention_beta, layernorm_mlp_beta
  • Weight Conversion Script: No script needed, original weights can be loaded directly into the new architecture
  • Model: The MiniMax-Text-01 architecture largely matches and reuses the Mixtral architecture, with a few changes in
    • DecoderLayer
      • hidden_states can be used as the residual connection either before or after layernorm is applied
      • a weighted sum is used in the residual connection
      • selection between softmax and lightning attention based on the layer_idx
    • LightningAttention
      • initially used in TransNormerLLM, upgraded in Lightning Attention-2 and adopted in MiniMax-01
      • every 8th decoder layer uses softmax attention; the rest of the layers use lightning attention, which has not previously been implemented in transformers

To summarize the above, the main area of review is the LightningAttention implementation (a rough sketch of the decoder-layer changes follows below).
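Since the weighted residual connection and the per-layer attention selection are the core DecoderLayer changes, here is a rough, illustrative sketch of that control flow. It is my reading of the description above, using the config fields quoted later in this thread (`attn_type_list`, `postnorm`, and the alpha/beta weights); it is not the PR's exact code.

```python
# Illustrative sketch only -- not the PR's implementation.
# Assumes config fields documented further down in this thread:
# attn_type_list (0 = lightning, 1 = softmax), postnorm, and the alpha/beta residual weights.
def decoder_layer_sketch(hidden_states, layer_idx, config, input_norm, attention, lightning_attention, mlp_norm, mlp):
    use_softmax = config.attn_type_list[layer_idx] == 1

    # the residual can be taken before or after layernorm, depending on `postnorm`
    normed = input_norm(hidden_states)
    residual = normed if config.postnorm else hidden_states

    if use_softmax:
        attn_out = attention(normed)
        alpha, beta = config.layernorm_full_attention_alpha, config.layernorm_full_attention_beta
    else:
        attn_out = lightning_attention(normed)
        alpha, beta = config.layernorm_linear_attention_alpha, config.layernorm_linear_attention_beta

    # weighted sum instead of a plain residual addition
    hidden_states = residual * alpha + attn_out * beta

    normed = mlp_norm(hidden_states)
    residual = normed if config.postnorm else hidden_states
    hidden_states = residual * config.layernorm_mlp_alpha + mlp(normed) * config.layernorm_mlp_beta

    return hidden_states
```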

TODOs

  • Update Documentation
  • Update Tests
  • Import Statements and Auto Modeling
  • Implement Model
    • Implement End-to-End Architecture
    • Implement LightningAttention
      • Work with available code
      • Refactor, Clean and Optimize
      • Implement Decays
      • Implement Caching
      • Implement attention_mask
      • Support .generate() method
  • Fix CI/CD tests

@Rocketknight1 (Member) commented:

Hi @geetu040, this looks quite good! You can ping us whenever it's ready for review. Also, the code quality issues in the CI can be fixed with `pip install -e .[quality]` in the transformers directory, followed by `make fixup`.

@ArthurZucker (Collaborator) commented Jan 27, 2025

Eager to see this! 🔥 Feel free to ping @Cyrilvallez and me for a review!

@ArthurZucker (Collaborator) commented:
cc @Cyrilvallez can you have a look!

geetu040 marked this pull request as ready for review on February 18, 2025 at 07:53
@geetu040 (Contributor, Author) commented:
@ArthurZucker, @Rocketknight1, @Cyrilvallez

This PR is ready for review.

Please see MiniMaxText01Cache: the lightning attention uses a cache that does not grow with the sequence length. Though I have implemented this to the best of my knowledge, I am still not sure whether this cache can be traced back for beam search or low-memory contrastive search, where we need to crop the cache back to a given sequence length.
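For reviewers less familiar with linear-attention caching, here is a minimal, self-contained illustration of why this state stays constant-size (this is not the PR's `MiniMaxText01Cache` code; the decay is a single scalar and batch/head dimensions are omitted):

```python
import torch


def lightning_step(q_t, k_t, v_t, state, decay=0.9):
    # Fold the new token into a running (head_dim x head_dim) summary instead of
    # appending keys/values, so the "cache" never grows with the sequence length.
    state = decay * state + k_t.unsqueeze(-1) @ v_t.unsqueeze(-2)  # rank-1 update
    out_t = (q_t.unsqueeze(-2) @ state).squeeze(-2)
    return out_t, state


head_dim = 4
state = torch.zeros(head_dim, head_dim)
for _ in range(10):  # ten decoding steps; the state shape never changes
    q_t, k_t, v_t = torch.randn(3, head_dim)
    out_t, state = lightning_step(q_t, k_t, v_t, state)
print(state.shape)  # torch.Size([4, 4]), independent of how many tokens were processed
```

This is also why the cropping concern above is real: past tokens are folded into the running state, so the linear-attention cache cannot simply be truncated back to an earlier length the way a standard key/value cache can.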

@Cyrilvallez (Member) left a comment:

Hey! Here is a first review! Very nice first modular, congrats! 🤗
Biggest point to work on currently is the Cache class, see my comments 👌
But should not be that much trouble, especially as the modular part itself is quite nice.

I must say I would prefer all classes and files to be named MiniMax and mini_max, however; MiniMaxText01 feels weird to me. Are you one of the model creators? Do you feel strongly about it?

Comment on lines +132 to +151
        attn_type_list (`List[int]`, *optional*, defaults to `[0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1]`):
            List of attention types for each layer. `0` for linear (lightning) attention
            and `1` for full (normal) attention.
        block_size (`int`, *optional*, defaults to 256):
            The length of each attention block, determining how queries, keys, and values
            are grouped and processed for intra- and inter-block attention.
        postnorm (`bool`, *optional*, defaults to `False`):
            Use residual connections post-normalization.
        layernorm_full_attention_alpha (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after normal attention.
        layernorm_full_attention_beta (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after normal attention.
        layernorm_linear_attention_alpha (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after lightning attention.
        layernorm_linear_attention_beta (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after lightning attention.
        layernorm_mlp_alpha (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after MLP.
        layernorm_mlp_beta (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after MLP.

Looks like all arguments up to those are similar to Mixtral! So the Config should be inherited from Mixtral, and then you can add the additional args 🤗
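To illustrate the suggestion (a sketch only, with the extra parameter names taken from the docstring above and illustrative defaults; not the final modular code):

```python
# Hypothetical sketch of the modular config inheriting from Mixtral and adding the new args.
from transformers.models.mixtral.configuration_mixtral import MixtralConfig


class MiniMaxText01Config(MixtralConfig):
    def __init__(
        self,
        attn_type_list=None,
        block_size=256,
        postnorm=False,
        layernorm_full_attention_alpha=1,
        layernorm_full_attention_beta=1,
        layernorm_linear_attention_alpha=1,
        layernorm_linear_attention_beta=1,
        layernorm_mlp_alpha=1,
        layernorm_mlp_beta=1,
        **super_kwargs,
    ):
        super().__init__(**super_kwargs)
        self.attn_type_list = attn_type_list if attn_type_list is not None else [0, 0, 1] * 4
        self.block_size = block_size
        self.postnorm = postnorm
        self.layernorm_full_attention_alpha = layernorm_full_attention_alpha
        self.layernorm_full_attention_beta = layernorm_full_attention_beta
        self.layernorm_linear_attention_alpha = layernorm_linear_attention_alpha
        self.layernorm_linear_attention_beta = layernorm_linear_attention_beta
        self.layernorm_mlp_alpha = layernorm_mlp_alpha
        self.layernorm_mlp_beta = layernorm_mlp_beta
```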

Comment on lines +347 to +354
    def __init__(self, config: MiniMaxText01Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.block_size = config.block_size

There are no weights here, so this should be a function, not a Module
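As an illustration of the suggested refactor (a sketch with placeholder decay formulas roughly in the spirit of lightning attention, not the PR's actual math), the computation can live in a plain module-level function since it has no trainable parameters:

```python
import torch


# Sketch only: no weights are involved, so a function is enough.
# The formulas below are placeholders, not the PR's exact decay definitions.
def compute_decays(slope_rate: float, block_size: int):
    positions = torch.arange(block_size, dtype=torch.float32)
    query_decay = torch.exp(-slope_rate * positions)                  # decay measured from the block start
    key_decay = torch.exp(-slope_rate * (block_size - positions))     # decay measured to the block end
    diff = positions[:, None] - positions[None, :]                    # pairwise token distances
    diagonal_decay = torch.where(diff >= 0, torch.exp(-slope_rate * diff), torch.zeros_like(diff))
    block_decay = torch.exp(torch.tensor(-slope_rate * block_size))   # carry-over factor between blocks
    return key_decay, query_decay, diagonal_decay, block_decay
```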

Comment on lines +388 to +389

        return key_decay, query_decay, diagonal_decay, block_decay

Looks like slope_rate, query_decay and diagonal_decay could always be pre-computed directly in the Attention module
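Building on the hypothetical `compute_decays` helper sketched above, the precomputation could then happen once at construction time (again a sketch, assuming the decays depend only on config values):

```python
from torch import nn


class LightningAttentionSketch(nn.Module):
    # Sketch: build the config-only decay tensors once instead of on every forward pass.
    def __init__(self, slope_rate: float, block_size: int):
        super().__init__()
        key_decay, query_decay, diagonal_decay, block_decay = compute_decays(slope_rate, block_size)
        # Buffers follow .to(device)/.half() but are neither trained nor saved in the checkpoint.
        self.register_buffer("key_decay", key_decay, persistent=False)
        self.register_buffer("query_decay", query_decay, persistent=False)
        self.register_buffer("diagonal_decay", diagonal_decay, persistent=False)
        self.register_buffer("block_decay", block_decay, persistent=False)
```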

Comment on lines +248 to +250

class MiniMaxText01Cache(DynamicCache):
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):

As you don't want to inherit most methods of DynamicCache, should be easier to inherit from Cache instead
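A rough sketch of that direction (not the PR's final `MiniMaxText01Cache`): subclass `Cache` directly and keep a constant-size linear-attention state for lightning layers alongside ordinary key/value tensors for softmax layers.

```python
from transformers.cache_utils import Cache


class MiniMaxText01CacheSketch(Cache):
    """Sketch only: constant-size states for lightning layers, growing KV for softmax layers."""

    def __init__(self, attn_type_list):
        super().__init__()
        self.attn_type_list = attn_type_list                # 0 = lightning, 1 = softmax
        self.linear_states = [None] * len(attn_type_list)   # one (head_dim x head_dim) state per lightning layer
        self.key_cache = [None] * len(attn_type_list)
        self.value_cache = [None] * len(attn_type_list)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # Only softmax layers track a growing sequence length.
        for idx, attn_type in enumerate(self.attn_type_list):
            if attn_type == 1 and self.key_cache[idx] is not None:
                return self.key_cache[idx].shape[-2]
        return 0
```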

Comment on lines +252 to +256
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.attn_type_list = config.attn_type_list
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads

This looks like it should not be needed

Comment on lines +756 to +875
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxText01ForCausalLM

        >>> model = MiniMaxText01ForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

Here this is all similar to Mixtral as well!

Suggested change: remove the `_tied_weights_keys`, `__init__`, and full `forward` definitions shown above (all similar to Mixtral, as noted), leaving only `pass` in the class body.
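As background on the modular pattern being suggested (a hedged sketch, not necessarily the final file contents): when a class is effectively identical to its Mixtral counterpart, the modular file can simply inherit it and the converter regenerates the full modeling code.

```python
# Hypothetical modular-file sketch: inherit the Mixtral implementation unchanged.
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM


class MiniMaxText01ForCausalLM(MixtralForCausalLM):
    pass
```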


Should be removed

Comment on lines 21 to 23
os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"),
os.path.join(MODEL_ROOT, "minimax_text_01", "modular_minimax_text_01.py"),
os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),

This file should not be touched

Comment on lines 165 to 167
docs/source/en/model_doc/mgp-str.md
docs/source/en/model_doc/minimax_text_01.md
docs/source/en/model_doc/mistral.md

This file should not be modified either

- such as Mixtral, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated.
+ such as Mixtral, MiniMaxText01, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated.

Humm, it's okay to modify that, but please add your model name at the end of the list then 🤗 And for all languages

Development

Successfully merging this pull request may close these issues.

Add support for MiniMax-Text-01 and MiniMax-VL-01 from MiniMaxAI
5 participants