Merge branch 'hongxiaob/permute_fusion' into 'main'
MoE permute/unpermute fusion

Closes #36

See merge request ADLR/megatron-lm!2224
ko3n1g committed Feb 10, 2025
2 parents 2481987 + 8a71e3b commit 044e2ad
Showing 9 changed files with 217 additions and 94 deletions.
56 changes: 37 additions & 19 deletions megatron/core/extensions/transformer_engine.py
@@ -5,7 +5,7 @@
import os
import pickle
import warnings
from typing import Callable
from typing import Any, Callable, Optional

import torch
import transformer_engine as te
@@ -113,13 +113,13 @@ def __init__(
input_size: int,
output_size: int,
*,
parallel_mode: str,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
):
self.config = config
@@ -275,7 +275,7 @@ def __init__(
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
):
self.config = config

@@ -440,7 +440,7 @@ def __init__(
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
@@ -523,7 +523,7 @@ def __init__(
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
):
if not input_is_parallel:
raise ValueError(
@@ -608,10 +608,10 @@ def __init__(
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
softmax_scale: float = None,
k_channels: int = None,
v_channels: int = None,
attention_dropout: Optional[float] = None,
softmax_scale: Optional[float] = None,
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
cp_comm_type: str = "p2p",
):
self.config = config
@@ -628,7 +628,7 @@ def __init__(
f"setting query key layer scaling via argument, so these two must match."
)

extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
@@ -827,13 +827,13 @@ def __init__(
input_size: int,
output_size: int,
*,
parallel_mode: str,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
):
self.config = config

@@ -1070,7 +1070,7 @@ def __init__(
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
):

super().__init__(
@@ -1115,7 +1115,7 @@ def __init__(
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
tp_comm_buffer_name: Optional[str] = None,
):

super().__init__(
@@ -1143,9 +1143,9 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):

else:

TEGroupedLinear = None
TEColumnParallelGroupedLinear = None
TERowParallelGroupedLinear = None
TEGroupedLinear = None # type: ignore[assignment, misc]
TEColumnParallelGroupedLinear = None # type: ignore[assignment, misc]
TERowParallelGroupedLinear = None # type: ignore[assignment, misc]


class TEDelayedScaling(te.common.recipe.DelayedScaling):
@@ -1279,7 +1279,7 @@ def get_cpu_offload_context(

except ImportError:

get_cpu_offload_context = None
get_cpu_offload_context = None # type: ignore[assignment, misc]

try:

@@ -1322,3 +1322,21 @@ def fused_apply_rotary_pos_emb_thd(

Fp8Padding = None
Fp8Unpadding = None

try:

from transformer_engine.pytorch.permutation import (
moe_permute,
moe_sort_chunks_by_index,
moe_unpermute,
)

fused_permute = moe_permute
fused_unpermute = moe_unpermute
fused_sort_chunks_by_index = moe_sort_chunks_by_index

except ImportError:

fused_permute = None
fused_unpermute = None
fused_sort_chunks_by_index = None
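
The new try/except mirrors the other optional Transformer Engine imports in this file: when TE is absent or predates its `transformer_engine.pytorch.permutation` module, the three fused names fall back to `None`, so downstream code can feature-detect at runtime instead of failing at import. Below is a minimal detection sketch; the `HAVE_FUSED_PERMUTATION` flag and the print are illustrative assumptions, not part of this commit (the commit itself performs the equivalent check inside `moe_utils.py`).

```python
# Illustrative feature detection (hypothetical caller, not part of this commit).
try:
    from megatron.core.extensions.transformer_engine import (
        fused_permute,
        fused_sort_chunks_by_index,
        fused_unpermute,
    )

    # The names are None when TE is installed but too old to ship the permutation kernels.
    HAVE_FUSED_PERMUTATION = fused_permute is not None
except ImportError:
    # transformer_engine.py imports TE at module scope, so the import itself
    # can fail when TE is not installed at all.
    HAVE_FUSED_PERMUTATION = False

if not HAVE_FUSED_PERMUTATION:
    print("TE permute fusion unavailable; falling back to the eager PyTorch path.")
```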
3 changes: 3 additions & 0 deletions megatron/core/transformer/moe/README.md
@@ -76,6 +76,7 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
| --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
| --moe-permute-fusion | Fuse token rearrangement ops during token dispatching. |
| --moe-shared-expert-intermediate-size | Set shared expert total ffn hidden size. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. |
| --moe-shared-expert-overlap | (Experimental, may change) If this is set, the communications/computations in the shared experts and the dispatcher will overlap (the `alltoall` dispatcher is required). Otherwise, the shared expert runs after the routed experts. |
| --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.|
@@ -90,6 +91,7 @@ To train a top-2 MoE model with 8 experts and auxiliary loss, include the follow
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-permute-fusion
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
@@ -241,6 +243,7 @@ MOE_ARGS=(
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
--moe-permute-fusion
)

DATA_ARGS=(
15 changes: 9 additions & 6 deletions megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py
@@ -1,5 +1,8 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# type: ignore
# This file will be deprecated soon. We won't fix the mypy type checks.

from typing import List, Optional, Tuple

import torch
@@ -60,13 +63,13 @@ def __init__(
self.num_global_tokens_per_local_expert_cpu = None
input_chunk_idxs = torch.arange(self.num_experts)
# [num_local_experts, ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = (
input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist()
)
self.sort_input_by_local_experts = input_chunk_idxs.reshape(
-1, self.num_local_experts
).T.ravel()
# [ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = (
input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist()
)
self.restore_output_by_local_experts = input_chunk_idxs.reshape(
self.num_local_experts, -1
).T.ravel()

# Token drop and padding.
# We need to keep track of the token num if we drop tokens without padding them.
45 changes: 42 additions & 3 deletions megatron/core/transformer/moe/moe_utils.py
@@ -8,6 +8,17 @@
from megatron.core import parallel_state
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region

try:
from megatron.core.extensions.transformer_engine import (
fused_permute,
fused_sort_chunks_by_index,
fused_unpermute,
)

HAVE_TE = True
except ImportError:
HAVE_TE = False


def switch_load_balancing_loss_func(
probs: torch.Tensor,
@@ -207,7 +218,13 @@ def set_loss_scale(scale: torch.Tensor):
MoEAuxLossAutoScaler.main_loss_backward_scale = scale


def permute(tokens, routing_map, num_out_tokens: int = None, drop_and_pad: bool = False):
def permute(
tokens,
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
@@ -221,11 +238,17 @@ def permute(tokens, routing_map, num_out_tokens: int = None, drop_and_pad: bool
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens.
fused (bool, optional): Whether to use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
"""
if fused:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens)

num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
@@ -262,6 +285,7 @@ def unpermute(
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
@@ -281,12 +305,18 @@ def unpermute(
probs (torch.Tensor, optional): The unpermuted probs tensor,
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
fused (bool, optional): Whether to use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns:
torch.Tensor: The tokens restored to their original order.
"""
if fused:
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape)

_, hidden = restore_shape

if probs is not None:
@@ -320,10 +350,19 @@ def unpermute(
return output_tokens


def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):
def sort_chunks_by_idxs(
input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False
):
"""Split and sort the input tensor based on the split_sizes and sorted indices."""
if fused:
if not HAVE_TE or fused_sort_chunks_by_index is None:
raise ValueError(
"fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
)
return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs)

input = torch.split(input, split_sizes.tolist(), dim=0)
output = torch.cat([input[i] for i in sorted_idxs], dim=0)
output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0)
return output
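
As a quick illustration of the eager path that the fused kernel replaces, the sketch below splits a tensor into chunks and reorders them; the sizes and ordering are made-up values, not from the commit.

```python
import torch

from megatron.core.transformer.moe.moe_utils import sort_chunks_by_idxs

# Illustrative input: 6 rows split into chunks of sizes 2, 1, 3, then
# concatenated back in the order chunk2, chunk0, chunk1.
x = torch.arange(12.0).reshape(6, 2)
split_sizes = torch.tensor([2, 1, 3])
sorted_idxs = torch.tensor([2, 0, 1])

out = sort_chunks_by_idxs(x, split_sizes, sorted_idxs)  # eager path (fused=False)
# out rows come from x[3:6], then x[0:2], then x[2:3].
# Passing fused=True instead dispatches to TE's moe_sort_chunks_by_index and
# raises a ValueError when TE >= 2.1.0 is not available.
```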


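For context on the new `fused` flag, here is a hedged round-trip sketch using `permute`/`unpermute`. It assumes a CUDA device and TE >= 2.1.0 for the fused path, and that `permute` returns the permuted tokens together with the sort indices (as the eager implementation does); the tensor sizes are arbitrary. Keep `fused` consistent between the two calls, since each path produces indices in its own format.

```python
import torch

from megatron.core.transformer.moe.moe_utils import permute, unpermute

# Illustrative sizes (assumptions, not from the commit).
num_tokens, hidden, num_experts, topk = 8, 16, 4, 2

tokens = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)

# Boolean routing map with exactly `topk` experts selected per token.
scores = torch.rand(num_tokens, num_experts, device="cuda")
routing_map = torch.zeros_like(scores, dtype=torch.bool)
routing_map.scatter_(1, scores.topk(topk, dim=1).indices, True)

# fused=True dispatches to TE's moe_permute/moe_unpermute and raises a
# ValueError if the kernels are unavailable; fused=False uses the eager path.
permuted, sorted_indices = permute(
    tokens, routing_map, num_out_tokens=num_tokens * topk, fused=True
)
restored = unpermute(
    permuted, sorted_indices, restore_shape=tokens.shape, routing_map=routing_map, fused=True
)
# With no probs given, each token's expert copies are summed on the way back,
# so `restored` should roughly equal `tokens * topk` (up to bf16 rounding).
```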
(The remaining 5 changed files are not shown.)
