Add support for MiniMax's MiniMax-Text-01 #35831
base: main
Conversation
Hi @geetu040, this looks quite good! You can ping us whenever it's ready for review. Also, code quality issues in the CI can be fixed with:
This reverts commit d8d3c40.
Eager to see this! 🔥 Feel free to ping @Cyrilvallez and me for a review!
cc @Cyrilvallez can you have a look!
@ArthurZucker, @Rocketknight1, @Cyrilvallez This PR is ready for review. Please see
Hey! Here is a first review! Very nice first modular, congrats! 🤗
Biggest point to work on currently is the Cache class, see my comments 👌
But should not be that much trouble, especially as the modular part itself is quite nice.
I must say I would prefer all classes and files to be named `MiniMax` and `mini_max`; `MiniMaxText01` feels weird to me. Are you one of the model creators? Do you feel strongly about it?
        attn_type_list (`List[int]`, *optional*, defaults to `[0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1]`):
            List of attention types for each layer. `0` for linear (lightning) attention
            and `1` for full (normal) attention.
        block_size (`int`, *optional*, defaults to 256):
            The length of each attention block, determining how queries, keys, and values
            are grouped and processed for intra- and inter-block attention.
        postnorm (`bool`, *optional*, defaults to `False`):
            Use residual connections post-normalization.
        layernorm_full_attention_alpha (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after normal attention.
        layernorm_full_attention_beta (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after normal attention.
        layernorm_linear_attention_alpha (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after lightning attention.
        layernorm_linear_attention_beta (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after lightning attention.
        layernorm_mlp_alpha (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after MLP.
        layernorm_mlp_beta (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after MLP.
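To make the alpha/beta arguments above concrete, they describe a weighted residual connection; here is a minimal sketch of the pattern, with illustrative names only (this is not the PR's code):

```python
def weighted_residual(hidden_states, sublayer, norm, alpha, beta, postnorm=False):
    # Pre-norm (postnorm=False): the un-normalized input is kept as the residual branch.
    # Post-norm (postnorm=True): the normalized input serves as the residual instead.
    normed = norm(hidden_states)
    residual = normed if postnorm else hidden_states
    return residual * alpha + sublayer(normed) * beta
```

Here `alpha` and `beta` would come from the matching `layernorm_*_alpha` / `layernorm_*_beta` pair, depending on whether the sublayer is full attention, lightning attention, or the MLP.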
Looks like all arguments up to those are similar to Mixtral! So the Config should be inherited from Mixtral, and then you can add the additional args 🤗
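A minimal sketch of what that inheritance could look like, reusing the argument names and defaults from the docstring above (the import path shown is the absolute one; in the modular file it would be relative, and the default used here for `attn_type_list` is only illustrative):

```python
from transformers.models.mixtral.configuration_mixtral import MixtralConfig


class MiniMaxText01Config(MixtralConfig):
    def __init__(
        self,
        attn_type_list=None,
        block_size=256,
        postnorm=False,
        layernorm_full_attention_alpha=1,
        layernorm_full_attention_beta=1,
        layernorm_linear_attention_alpha=1,
        layernorm_linear_attention_beta=1,
        layernorm_mlp_alpha=1,
        layernorm_mlp_beta=1,
        **super_kwargs,
    ):
        # All arguments shared with Mixtral are handled by the parent class.
        super().__init__(**super_kwargs)
        # MiniMax-specific additions, matching the docstring above.
        self.attn_type_list = attn_type_list if attn_type_list is not None else [0, 0, 1] * 4
        self.block_size = block_size
        self.postnorm = postnorm
        self.layernorm_full_attention_alpha = layernorm_full_attention_alpha
        self.layernorm_full_attention_beta = layernorm_full_attention_beta
        self.layernorm_linear_attention_alpha = layernorm_linear_attention_alpha
        self.layernorm_linear_attention_beta = layernorm_linear_attention_beta
        self.layernorm_mlp_alpha = layernorm_mlp_alpha
        self.layernorm_mlp_beta = layernorm_mlp_beta
```

The modular converter would then generate the full standalone configuration file from this subclass.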
    def __init__(self, config: MiniMaxText01Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.block_size = config.block_size
There are no weights here, so this should be a function, not a Module
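For illustration, a parameter-free helper along these lines could replace the Module; the name and the ALiBi-style slope formula here are assumptions, not necessarily what the PR computes:

```python
import torch


def get_decay_slopes(num_heads: int) -> torch.Tensor:
    # Hypothetical helper: one decay slope per head. No trainable parameters are
    # involved, so a plain function is enough instead of an nn.Module.
    base = 2.0 ** (-8.0 / num_heads)
    return base ** torch.arange(1, num_heads + 1, dtype=torch.float32)
```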
        return key_decay, query_decay, diagonal_decay, block_decay
Looks like `slope_rate`, `query_decay` and `diagonal_decay` could always be pre-computed directly in the Attention module.
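For example, those decay factors depend only on the head count and block size, so they could be computed once at construction time and registered as buffers; a sketch, where the tensor shapes and the exponential-decay formula are assumptions:

```python
import torch
from torch import nn


class LightningAttentionSketch(nn.Module):
    def __init__(self, num_heads: int, block_size: int):
        super().__init__()
        # Per-head decay slope and per-position decay factors depend only on the
        # config, so they can be registered once as (non-persistent) buffers.
        slope_rate = 2.0 ** (-8.0 * torch.arange(1, num_heads + 1, dtype=torch.float32) / num_heads)
        positions = torch.arange(1, block_size + 1, dtype=torch.float32)
        # Decay applied to queries, per head and per position within a block.
        query_decay = torch.exp(-slope_rate[:, None] * positions[None, :])
        # Causal, distance-dependent decay between positions inside the same block.
        diff = positions[:, None] - positions[None, :]
        diagonal_decay = torch.exp(-slope_rate[:, None, None] * diff.clamp(min=0.0))
        diagonal_decay = diagonal_decay.masked_fill(diff < 0, 0.0)
        self.register_buffer("slope_rate", slope_rate, persistent=False)
        self.register_buffer("query_decay", query_decay, persistent=False)
        self.register_buffer("diagonal_decay", diagonal_decay, persistent=False)
```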
class MiniMaxText01Cache(DynamicCache):
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
As you don't want to inherit most methods of `DynamicCache`, it should be easier to inherit from `Cache` instead.
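A rough shape of such a cache, inheriting from `Cache` and keeping standard key/value tensors for full-attention layers plus a recurrent state for lightning layers; the class name, attribute names, and linear-state handling are illustrative, not the PR's final API:

```python
import torch
from transformers.cache_utils import Cache


class MiniMaxText01CacheSketch(Cache):
    """Illustrative cache: regular key/value tensors for full-attention layers,
    a single recurrent state tensor for lightning-attention layers."""

    def __init__(self, config):
        super().__init__()
        self.attn_type_list = config.attn_type_list
        self.key_cache = [None] * config.num_hidden_layers
        self.value_cache = [None] * config.num_hidden_layers
        self.linear_state = [None] * config.num_hidden_layers

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # Standard concatenation along the sequence dimension for full-attention layers.
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_linear_state(self, layer_idx, kv_state):
        # Lightning-attention layers keep one fixed-size recurrent state per layer
        # instead of ever-growing key/value tensors.
        self.linear_state[layer_idx] = kv_state
        return self.linear_state[layer_idx]

    def get_seq_length(self, layer_idx: int = 0) -> int:
        if self.key_cache[layer_idx] is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]
```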
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.attn_type_list = config.attn_type_list
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
This looks like it should not be needed
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxText01ForCausalLM

        >>> model = MiniMaxText01ForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
Here this is all similar to Mixtral as well!
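In modular terms, code that is identical to Mixtral usually does not need to be repeated at all; a sketch of the idea (whether the MiniMax cache handling still forces an override of `forward` is a separate question):

```python
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM


class MiniMaxText01ForCausalLM(MixtralForCausalLM):
    # In a modular_*.py file, an empty subclass is enough: the modular converter
    # copies Mixtral's forward (docstring included) under the new model name.
    pass
```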
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MiniMaxText01Model(config)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
pass |
Should be removed
        os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"),
        os.path.join(MODEL_ROOT, "minimax_text_01", "modular_minimax_text_01.py"),
        os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),
This file should not be touched
docs/source/en/model_doc/mgp-str.md
docs/source/en/model_doc/minimax_text_01.md
docs/source/en/model_doc/mistral.md
This file should not be modified either
such as Mixtral, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated.
such as Mixtral, MiniMaxText01, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated.
Humm, it's okay to modify that, but please add your model name at the end of the list then 🤗 And for all languages
What does this PR do?
Fixes #35710
This PR adds MiniMaxAI's MiniMax-Text-01 model to Hugging Face Transformers.
Relevant Links
CC: @MiniMax-AI-Dev
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker, @Rocketknight1
Change Log
- Tokenization uses `GPT2Tokenizer`.
- Configuration uses `MixtralConfig` with a few additional parameters.
- The `MiniMax-Text-01` architecture matches and uses most of the `Mixtral` architecture, with a few changes in:
  - how `hidden_states` can be used as residual connections, before or after `layernorm` is applied
  - how `layer_idx` is used in `transformers`
To summarize the above, the main area of review is the `LightningAttention` implementation.

TODOs
- `LightningAttention`
- `.generate()` method
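For reviewers unfamiliar with lightning (linear) attention, the sketch below illustrates the intra/inter-block split described in the configuration docstring; the decay factors and normalization are omitted and all names are illustrative, so this is not the PR's implementation:

```python
import torch


def lightning_attention_block(q, k, v, kv_state):
    """Simplified single block of linear (lightning) attention.

    q, k, v:  (batch, num_heads, block_size, head_dim)
    kv_state: (batch, num_heads, head_dim, head_dim), running sum of k^T v
              accumulated over all previous blocks.
    """
    # Inter-block: every query attends to the compressed state of earlier blocks.
    inter = torch.matmul(q, kv_state)
    # Intra-block: plain causal attention restricted to the current block
    # (the per-head exponential decay used by the real kernel is omitted here).
    scores = torch.matmul(q, k.transpose(-1, -2))
    causal = torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=q.dtype, device=q.device))
    intra = torch.matmul(scores * causal, v)
    # Fold this block's keys/values into the recurrent state for the next block.
    new_kv_state = kv_state + torch.matmul(k.transpose(-1, -2), v)
    return inter + intra, new_kv_state
```

During generation, it is this fixed-size running `kv_state` (rather than ever-growing key/value tensors) that would need to be carried between decoding steps, which is presumably what connects the cache discussion above to the `.generate()` TODO.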