MixtralMlpMixin(): the MoE here only computes the expert logits, but I don't see the dispatch logic #170

Open
AlenjandroWang opened this issue Feb 9, 2024 · 1 comment

Comments

@AlenjandroWang

https://github.com/THUDM/SwissArmyTransformer/blob/main/sat/model/official/mixtral_model.py

@1049451037
Member

It's here:

def mlp_forward_default(self, hidden_states, expert_id=-1, **kw_args):
    if self.transformer.num_experts == 1 or expert_id > -1:
        self = self.transformer.layers[kw_args['layer_id']].mlp
        suffix = f"_{expert_id}" if expert_id > 0 else ""
        if self.is_gated_mlp:
            intermediate_parallel = getattr(self, "dense_h_to_4h"+suffix)(hidden_states)
            gated_intermediate_parallel = getattr(self, "dense_h_to_4h_gate"+suffix)(hidden_states)
            intermediate_parallel = self.activation_func(gated_intermediate_parallel) * intermediate_parallel
            output = getattr(self, "dense_4h_to_h"+suffix)(intermediate_parallel)
        else:
            intermediate_parallel = getattr(self, "dense_h_to_4h"+suffix)(hidden_states)
            intermediate_parallel = self.activation_func(intermediate_parallel)
            output = getattr(self, "dense_4h_to_h"+suffix)(intermediate_parallel)
        return output
    else:
        mlp_forward = self.hooks.get('mlp_forward', partial(mlp_forward_default, self))
        routing_forward = self.hooks.get('routing_forward', partial(routing_forward_default, self))
        self = self.transformer.layers[kw_args['layer_id']].mlp
        fwd_weight, fwd_idx = routing_forward(hidden_states, **kw_args)
        # Adapted from mixtral-8x7b https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # One-hot encode the selected experts to create an expert mask;
        # this is used to easily index which tokens each expert should process.
        expert_mask = torch.nn.functional.one_hot(fwd_idx, num_classes=self.num_experts).permute(2, 1, 0)
        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.shape[0] == 0:
                continue
            # In torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[top_x_list]  # not sure why the HF version uses hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = mlp_forward(current_state, expert_id=expert_idx, **kw_args) * fwd_weight[top_x_list, idx_list, None]
            # However `index_add_` only supports torch tensors for indexing, so we use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        output = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return output
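So the dispatch is the `else` branch: `routing_forward` produces the per-token routing weights `fwd_weight` and expert indices `fwd_idx`, the one-hot `expert_mask` selects which tokens go to which expert, and `index_add_` scatters each expert's weighted outputs back into `final_hidden_states`.

To make that pattern easier to see in isolation, here is a minimal standalone sketch of the same routing-plus-dispatch logic with toy tensors and dummy expert MLPs. The names (`gate`, `experts`) and the dimensions are illustrative only and not taken from SwissArmyTransformer; the routing here (softmax, top-k, renormalize) is the common Mixtral-style recipe, not necessarily identical to `routing_forward_default`.

import torch

# Toy settings (illustrative only)
batch_size, sequence_length, hidden_dim = 2, 4, 8
num_experts, top_k = 4, 2

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)
experts = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))

# Routing: softmax over the gate logits, keep the top-k experts per token,
# then renormalize the kept weights so they sum to 1.
flat = hidden_states.view(-1, hidden_dim)                       # (tokens, hidden_dim)
probs = torch.softmax(gate(flat), dim=-1)                       # (tokens, num_experts)
fwd_weight, fwd_idx = torch.topk(probs, top_k, dim=-1)          # both (tokens, top_k)
fwd_weight = fwd_weight / fwd_weight.sum(dim=-1, keepdim=True)

# Dispatch: same pattern as the code above -- build a one-hot expert mask,
# loop over experts, gather each expert's tokens, scale by the routing weight,
# and scatter the results back with index_add_.
final = torch.zeros_like(flat)
expert_mask = torch.nn.functional.one_hot(fwd_idx, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])           # idx: top-k slot, top_x: token index
    if top_x.shape[0] == 0:
        continue
    current = experts[expert_idx](flat[top_x]) * fwd_weight[top_x, idx, None]
    final.index_add_(0, top_x, current.to(flat.dtype))

output = final.view(batch_size, sequence_length, hidden_dim)
print(output.shape)  # torch.Size([2, 4, 8])

In other words, there is no explicit all-to-all "send tokens to experts" step; the dispatch is done in-place by masking and gathering token indices per expert and accumulating the weighted expert outputs back into the flattened hidden states.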
