From 8a71e3bc7ddf27b83f981b9751a115e48f31c5f0 Mon Sep 17 00:00:00 2001 From: Hongxiao Bai Date: Mon, 10 Feb 2025 02:25:55 -0800 Subject: [PATCH] ADLR/megatron-lm!2224 - MoE permute/unpermute fusion Co-authored-by: Zijie Yan --- .../core/extensions/transformer_engine.py | 56 +++++++---- megatron/core/transformer/moe/README.md | 3 + .../moe/legacy_a2a_token_dispatcher.py | 15 +-- megatron/core/transformer/moe/moe_utils.py | 45 ++++++++- .../core/transformer/moe/token_dispatcher.py | 92 +++++++++++-------- .../core/transformer/transformer_config.py | 63 ++++++++----- megatron/training/arguments.py | 2 + .../moe/test_a2a_token_dispatcher.py | 20 +++- .../transformer/moe/test_token_dispatcher.py | 15 ++- 9 files changed, 217 insertions(+), 94 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 388ce484f7..426b733524 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -5,7 +5,7 @@ import os import pickle import warnings -from typing import Callable +from typing import Any, Callable, Optional import torch import transformer_engine as te @@ -113,13 +113,13 @@ def __init__( input_size: int, output_size: int, *, - parallel_mode: str, + parallel_mode: Optional[str], config: ModelParallelConfig, init_method: Callable, bias: bool, skip_bias_add: bool, skip_weight_param_allocation: bool, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, is_expert: bool = False, ): self.config = config @@ -275,7 +275,7 @@ def __init__( skip_bias_add: bool, is_expert: bool, skip_weight_param_allocation: bool = False, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, ): self.config = config @@ -440,7 +440,7 @@ def __init__( skip_bias_add: bool, is_expert: bool, skip_weight_param_allocation: bool = False, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, ): if gather_output: raise ValueError('Transformer Engine linear layers do not support gather_output = True') @@ -523,7 +523,7 @@ def __init__( input_is_parallel: bool, skip_bias_add: bool, is_expert: bool, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, ): if not input_is_parallel: raise ValueError( @@ -608,10 +608,10 @@ def __init__( layer_number: int, attn_mask_type: AttnMaskType, attention_type: str, - attention_dropout: float = None, - softmax_scale: float = None, - k_channels: int = None, - v_channels: int = None, + attention_dropout: Optional[float] = None, + softmax_scale: Optional[float] = None, + k_channels: Optional[int] = None, + v_channels: Optional[int] = None, cp_comm_type: str = "p2p", ): self.config = config @@ -628,7 +628,7 @@ def __init__( f"setting query key layer scaling via argument, so these two must match." ) - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if is_te_min_version("0.11.0"): extra_kwargs["num_gqa_groups"] = self.config.num_query_groups elif self.config.num_query_groups != self.config.num_attention_heads: @@ -827,13 +827,13 @@ def __init__( input_size: int, output_size: int, *, - parallel_mode: str, + parallel_mode: Optional[str], config: ModelParallelConfig, init_method: Callable, bias: bool, skip_bias_add: bool, is_expert: bool = False, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, ): self.config = config @@ -1070,7 +1070,7 @@ def __init__( bias: bool, skip_bias_add: bool, is_expert: bool, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, ): super().__init__( @@ -1115,7 +1115,7 @@ def __init__( bias: bool, skip_bias_add: bool, is_expert: bool, - tp_comm_buffer_name: str = None, + tp_comm_buffer_name: Optional[str] = None, ): super().__init__( @@ -1143,9 +1143,9 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): else: - TEGroupedLinear = None - TEColumnParallelGroupedLinear = None - TERowParallelGroupedLinear = None + TEGroupedLinear = None # type: ignore[assignment, misc] + TEColumnParallelGroupedLinear = None # type: ignore[assignment, misc] + TERowParallelGroupedLinear = None # type: ignore[assignment, misc] class TEDelayedScaling(te.common.recipe.DelayedScaling): @@ -1279,7 +1279,7 @@ def get_cpu_offload_context( except ImportError: - get_cpu_offload_context = None + get_cpu_offload_context = None # type: ignore[assignment, misc] try: @@ -1322,3 +1322,21 @@ def fused_apply_rotary_pos_emb_thd( Fp8Padding = None Fp8Unpadding = None + +try: + + from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_sort_chunks_by_index, + moe_unpermute, + ) + + fused_permute = moe_permute + fused_unpermute = moe_unpermute + fused_sort_chunks_by_index = moe_sort_chunks_by_index + +except ImportError: + + fused_permute = None + fused_unpermute = None + fused_sort_chunks_by_index = None diff --git a/megatron/core/transformer/moe/README.md b/megatron/core/transformer/moe/README.md index f1d21864df..e649637a48 100644 --- a/megatron/core/transformer/moe/README.md +++ b/megatron/core/transformer/moe/README.md @@ -76,6 +76,7 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit | --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. | | --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. | | --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. | +| --moe-permute-fusion | Fuse token rearrangement ops during token dispatching. | | --moe-shared-expert-intermediate-size | Set shared expert total ffn hidden size. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. | | --moe-shared-expert-overlap | (Experimental, may changed) If this is set, the communications/computations in the shared experts and the dispatcher will overlap (The `alltoall` dispatcher is needed.) Otherwise, the shared expert runs after the routed experts. | | --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.| @@ -90,6 +91,7 @@ To train a top-2 MoE model with 8 experts and auxiliary loss, include the follow --num-experts 8 --expert-model-parallel-size 8 --moe-grouped-gemm +--moe-permute-fusion --moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss. --moe-router-topk 2 --moe-aux-loss-coeff 1e-2 @@ -241,6 +243,7 @@ MOE_ARGS=( --moe-router-topk 2 --moe-aux-loss-coeff 1e-2 --moe-grouped-gemm + --moe-permute-fusion ) DATA_ARGS=( diff --git a/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py b/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py index dd5f447dd3..0847b0ba42 100644 --- a/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +++ b/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py @@ -1,5 +1,8 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# type: ignore +# This file will be deprecated soon. We won't fix the mypy type checks. + from typing import List, Optional, Tuple import torch @@ -60,13 +63,13 @@ def __init__( self.num_global_tokens_per_local_expert_cpu = None input_chunk_idxs = torch.arange(self.num_experts) # [num_local_experts, ep_size]. Sort the input chunks by local experts. - self.sort_input_by_local_experts = ( - input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist() - ) + self.sort_input_by_local_experts = input_chunk_idxs.reshape( + -1, self.num_local_experts + ).T.ravel() # [ep_size, num_local_experts]. Restore the output chunks by local experts. - self.restore_output_by_local_experts = ( - input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist() - ) + self.restore_output_by_local_experts = input_chunk_idxs.reshape( + self.num_local_experts, -1 + ).T.ravel() # Token drop and padding. # We need to keep track of the token num if we drop tokens without padding them. diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index d22170ba03..a426973a25 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -8,6 +8,17 @@ from megatron.core import parallel_state from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +try: + from megatron.core.extensions.transformer_engine import ( + fused_permute, + fused_sort_chunks_by_index, + fused_unpermute, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + def switch_load_balancing_loss_func( probs: torch.Tensor, @@ -207,7 +218,13 @@ def set_loss_scale(scale: torch.Tensor): MoEAuxLossAutoScaler.main_loss_backward_scale = scale -def permute(tokens, routing_map, num_out_tokens: int = None, drop_and_pad: bool = False): +def permute( + tokens, + routing_map, + num_out_tokens: Optional[int] = None, + fused: bool = False, + drop_and_pad: bool = False, +): """Permute the tokens and probs based on the mask. Tokens with the same designated expert will be grouped together. The shape of mask is [tokens, num_experts], it indicates which experts were selected @@ -221,11 +238,17 @@ def permute(tokens, routing_map, num_out_tokens: int = None, drop_and_pad: bool routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. num_out_tokens (int, optional): The number of output tokens. If None, it's set to the number of input tokens. + fused (bool, optional): Whether use the fused permute function. drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity. If set to true, routing_map has a fixed number of non-zeros in each column. """ + if fused: + if not HAVE_TE or fused_permute is None: + raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.") + return fused_permute(tokens, routing_map, num_out_tokens) + num_tokens, hidden = tokens.shape num_experts = routing_map.shape[1] if drop_and_pad and not (num_out_tokens is None): @@ -262,6 +285,7 @@ def unpermute( restore_shape: torch.Size, probs: torch.Tensor = None, routing_map: torch.Tensor = None, + fused: bool = False, drop_and_pad: bool = False, ): """ @@ -281,12 +305,18 @@ def unpermute( probs (torch.Tensor, optional): The unpermuted probs tensor, routing_map (torch.Tensor, optional): Token to expert mapping, shape [num_tokens, num_experts]. + fused (bool, optional): Whether use the fused unpermute function. drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity. Returns: torch.Tensor: The tokens restored to their original order. """ + if fused: + if not HAVE_TE or fused_unpermute is None: + raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.") + return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape) + _, hidden = restore_shape if probs is not None: @@ -320,10 +350,19 @@ def unpermute( return output_tokens -def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor): +def sort_chunks_by_idxs( + input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False +): """Split and sort the input tensor based on the split_sizes and sorted indices.""" + if fused: + if not HAVE_TE or fused_sort_chunks_by_index is None: + raise ValueError( + "fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0." + ) + return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs) + input = torch.split(input, split_sizes.tolist(), dim=0) - output = torch.cat([input[i] for i in sorted_idxs], dim=0) + output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) return output diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 03301b1e33..56456a2685 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -125,9 +125,6 @@ def __init__( self.router_topk = config.moe_router_topk self.add_bias = config.add_bias_linear - # self.local_probs: probs of global token assignment to local experts. - self.local_probs = None - # self.global_local_map: 2D tensor. A mask of mapping between global and local tokens where # each element is True if it's between the local_expert_indices. Only useful when cross # device token permutation is enabled and **AllGahter** is performed. @@ -183,6 +180,7 @@ def token_permutation( self.local_map = routing_map[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 ].contiguous() + # probs of global token assignment to local experts. self.local_probs = probs[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 ].contiguous() @@ -190,7 +188,10 @@ def token_permutation( tokens_per_expert = self.local_map.sum(dim=0).long().cpu() (permuted_local_hidden_states, self.reversed_local_input_permutation_mapping) = permute( - hidden_states, self.local_map + hidden_states, + self.local_map, + num_out_tokens=tokens_per_expert.sum(), + fused=self.config.moe_permute_fusion, ) return permuted_local_hidden_states, tokens_per_expert @@ -220,6 +221,8 @@ def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = hidden_states, self.reversed_local_input_permutation_mapping, restore_shape=self.hidden_shape_before_permute, + routing_map=self.local_map, + fused=self.config.moe_permute_fusion, ) unpermuted_local_bias = None @@ -230,6 +233,8 @@ def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = bias, self.reversed_local_input_permutation_mapping, restore_shape=self.hidden_shape_before_permute, + routing_map=self.local_map, + fused=self.config.moe_permute_fusion, ) output_total = unpermuted_local_hidden @@ -279,8 +284,8 @@ def __init__( config (TransformerConfig): Configuration for the transformer model. """ super().__init__(config=config) - self.hidden_shape = None self.num_local_experts = num_local_experts + assert config.num_moe_experts is not None self.num_experts = config.num_moe_experts assert self.num_local_experts > 0, "Expected at least one expert" self.local_expert_indices = local_expert_indices @@ -291,7 +296,6 @@ def __init__( assert ( self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1 ), "local_expert_indices must be continous" - self.probs = None # [ep_size]. Represents the number of tokens sent by the current rank to other # EP ranks. @@ -302,26 +306,25 @@ def __init__( # [tp_size]. Represents the number of tokens received by the current rank from # other TP ranks. self.output_splits_tp = None - # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent - # to each local expert by all ranks. - self.num_global_tokens_per_local_expert_cpu = None - input_chunk_idxs = torch.arange(self.num_experts * self.tp_size) - # [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts. - self.sort_input_by_local_experts = ( - input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist() + self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None + input_chunk_idxs = torch.arange( + self.num_experts * self.tp_size, device=self.permute_idx_device ) + # [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts. + self.sort_input_by_local_experts = input_chunk_idxs.reshape( + -1, self.num_local_experts + ).T.ravel() # [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts. - self.restore_output_by_local_experts = ( - input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist() - ) + self.restore_output_by_local_experts = input_chunk_idxs.reshape( + self.num_local_experts, -1 + ).T.ravel() # Token drop and padding. - # We need to keep track of the token num if we drop tokens without padding them. - self.num_out_tokens = None # Drop and pad the input to capacity. self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity if self.drop_and_pad: assert self.config.moe_expert_capacity_factor is not None + self.moe_expert_capacity_factor = self.config.moe_expert_capacity_factor self.capacity = None # A cuda stream synchronization is needed in self.token_permutation() in some cases, @@ -357,7 +360,7 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: self.capacity = get_capacity( num_tokens=num_tokens, num_experts=self.num_experts, - capacity_factor=self.config.moe_expert_capacity_factor, + capacity_factor=self.moe_expert_capacity_factor, ) self.num_out_tokens = self.capacity * self.num_experts # [num_local_experts], number of tokens processed by each expert. @@ -366,9 +369,13 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: self.capacity * self.tp_size * self.ep_size, dtype=torch.long, ) - # [tp_size * ep_size, num_local_experts]. - self.num_global_tokens_per_local_expert_cpu = torch.full( - (self.num_experts * self.tp_size,), self.capacity, dtype=torch.long + # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = torch.full( + (self.num_experts * self.tp_size,), + self.capacity, + dtype=torch.long, + device=self.permute_idx_device, ) return num_tokens_per_local_expert elif self.config.moe_expert_capacity_factor is not None: @@ -395,6 +402,8 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: # =================================================== # Calculate input_splits, output_splits for alltoall/allgather in variable size. # =================================================== + # [ep_size]. Represents the number of tokens sent by the current rank to other + # EP ranks. self.input_splits = ( num_local_tokens_per_expert.reshape(self.ep_size, self.num_local_experts) .sum(axis=1) @@ -447,9 +456,15 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: ) if self.num_local_experts > 1: - self.num_global_tokens_per_local_expert_cpu = num_global_tokens_per_local_expert.view( + # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view( -1, self.num_local_experts - ).to(torch.device("cpu"), non_blocking=True) + ) + if not self.config.moe_permute_fusion: + self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.to( + torch.device("cpu"), non_blocking=False + ) return num_tokens_per_local_expert @@ -496,6 +511,7 @@ def token_permutation( hidden_states, routing_map, num_out_tokens=self.num_out_tokens, + fused=self.config.moe_permute_fusion, drop_and_pad=self.drop_and_pad, ) @@ -509,18 +525,17 @@ def token_permutation( self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) if self.tp_size > 1: + if self.output_splits_tp is None: + output_split_sizes = None + else: + output_split_sizes = self.output_splits_tp.tolist() global_input_tokens = gather_from_sequence_parallel_region( - global_input_tokens, - group=self.tp_group, - output_split_sizes=( - self.output_splits_tp.tolist() if self.output_splits_tp is not None else None - ), + global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes ) # Permutation 2: Sort tokens by local expert. if self.num_local_experts > 1: if self.drop_and_pad: - # Example: global_input_tokens = ( global_input_tokens.view( self.tp_size * self.ep_size, @@ -535,8 +550,9 @@ def token_permutation( else: global_input_tokens = sort_chunks_by_idxs( global_input_tokens, - self.num_global_tokens_per_local_expert_cpu.ravel(), + self.num_global_tokens_per_local_expert.ravel(), self.sort_input_by_local_experts, + fused=self.config.moe_permute_fusion, ) if self.cuda_sync_point == "before_finish": @@ -583,17 +599,18 @@ def token_unpermutation( else: hidden_states = sort_chunks_by_idxs( hidden_states, - self.num_global_tokens_per_local_expert_cpu.T.ravel(), + self.num_global_tokens_per_local_expert.T.ravel(), self.restore_output_by_local_experts, + fused=self.config.moe_permute_fusion, ) if self.tp_size > 1: + if self.output_splits_tp is None: + input_split_sizes = None + else: + input_split_sizes = self.output_splits_tp.tolist() hidden_states = reduce_scatter_to_sequence_parallel_region( - hidden_states, - group=self.tp_group, - input_split_sizes=( - self.output_splits_tp.tolist() if self.output_splits_tp is not None else None - ), + hidden_states, group=self.tp_group, input_split_sizes=input_split_sizes ) # Perform expert parallel AlltoAll communication @@ -612,6 +629,7 @@ def token_unpermutation( restore_shape=self.hidden_shape_before_permute, probs=self.probs, routing_map=self.routing_map, + fused=self.config.moe_permute_fusion, drop_and_pad=self.drop_and_pad, ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 94df70c526..a81fdb967b 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -54,17 +54,17 @@ class TransformerConfig(ModelParallelConfig): If attention backend is local we use the local pytorch implementation in mcore. Users can specify exact backend by changing this config. """ - softmax_scale: float = None + softmax_scale: Optional[float] = None """Softmax scale for attention scaling.""" - num_query_groups: int = None + num_query_groups: Optional[int] = None """Number of query groups for group query attention. If None, normal attention is used.""" - ffn_hidden_size: int = None + ffn_hidden_size: Optional[int] = None """Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided.""" - kv_channels: int = None + kv_channels: Optional[int] = None """Projection weights dimension in multi-head attention. This is set to hidden_size // num_attention_heads if not provided.""" @@ -105,7 +105,7 @@ class TransformerConfig(ModelParallelConfig): """Store the input of MLP activation function in FP8 for backprop to save memory. The stored input is casted back to the original precision before backprop compuatation.""" - num_moe_experts: int = None + num_moe_experts: Optional[int] = None """Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None for no MoE.""" @@ -117,7 +117,7 @@ class TransformerConfig(ModelParallelConfig): """If not None, then will use sliding window attention. The size of the window is specified by the numbers inside the tuple; -1 is special value meaning "infinite window size".""" - normalization: bool = "LayerNorm" + normalization: str = "LayerNorm" """Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`.""" qk_layernorm: bool = False @@ -136,13 +136,13 @@ class TransformerConfig(ModelParallelConfig): #################### # initialization #################### - init_method: Callable = None + init_method: Optional[Callable] = None """Method to initialize weights. Note that bias is always set to zero. Should be a function that takes a single Tensor and initializes it. If None, will be set to megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with mean=0.0 and std=init_method_std.""" - output_layer_init_method: Callable = None + output_layer_init_method: Optional[Callable] = None """Method to initialize weights of the output layer of both attention and MLP blocks. If None, will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers).""" @@ -188,7 +188,7 @@ class TransformerConfig(ModelParallelConfig): #################### # activation recomputation #################### - recompute_granularity: str = None + recompute_granularity: Optional[str] = None """Determines which type of activation recompute to use. Megatron-core supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. These memory intensive activations are also less compute intensive which makes activation @@ -198,7 +198,7 @@ class TransformerConfig(ModelParallelConfig): If set, must be 'selective' or 'full'. 'selective' always uses all layers. """ - recompute_method: str = None + recompute_method: Optional[str] = None """Determines which transformer layers will be recomputed. uniform will uniformly divide the total number of transformer layers in a transformer block and recompute the input activation of each divided chunk at the specified granularity. block will recompute the input activations for @@ -206,19 +206,19 @@ class TransformerConfig(ModelParallelConfig): pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all layers will do recomputation. If set, must be 'uniform' or 'block'.""" - recompute_num_layers: int = None + recompute_num_layers: Optional[int] = None """When recompute_method is uniform, recompute_num_layers is the number of transformer layers in each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is the number of transformer layers to recompute within each pipeline stage. Must be None for 'selective' activation checkpointing.""" - distribute_saved_activations: bool = None + distribute_saved_activations: Optional[bool] = None """If True, distribute recomputed activations across the model parallel group.""" #################### # fp8 related #################### - fp8: str = None + fp8: Optional[str] = None """If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8 activation and weight tensors and e5m2 for all FP8 output activation gradient tensors.""" @@ -257,7 +257,7 @@ class TransformerConfig(ModelParallelConfig): #################### # MoE related #################### - moe_shared_expert_intermediate_size: int = None + moe_shared_expert_intermediate_size: Optional[int] = None """Shared expert total ffn hidden size. It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if there are multiple shared experts. @@ -274,7 +274,7 @@ class TransformerConfig(ModelParallelConfig): "([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0] where 1 indicates an expert layer and 0 indicates a dense layer.""" - moe_ffn_hidden_size: int = None + moe_ffn_hidden_size: Optional[int] = None """MoE Feed-Forward Network hidden size""" moe_router_load_balancing_type: str = "aux_loss" @@ -286,12 +286,12 @@ class TransformerConfig(ModelParallelConfig): moe_router_topk: int = 2 """Number of experts to route to for each token.""" - moe_router_topk_limited_devices: int = None + moe_router_topk_limited_devices: Optional[int] = None """Number of EP ranks to consider for each token in group-limited routing, DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk. """ - moe_router_num_groups: int = None + moe_router_num_groups: Optional[int] = None """Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1. Experts are divided into 'moe_router_num_groups' equal-sized groups @@ -307,14 +307,14 @@ class TransformerConfig(ModelParallelConfig): (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437) """ - moe_router_group_topk: int = None + moe_router_group_topk: Optional[int] = None """Number of selected groups for group-limited routing.""" moe_router_pre_softmax: bool = False """Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.""" - moe_router_topk_scaling_factor: float = None + moe_router_topk_scaling_factor: Optional[float] = None """Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax enabled. Defaults to None, which means no scaling.""" @@ -345,10 +345,10 @@ class TransformerConfig(ModelParallelConfig): moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss. """Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended.""" - moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss + moe_z_loss_coeff: Optional[float] = None # 1e-3 would be a good start value for z-loss """Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended.""" - moe_input_jitter_eps: float = None + moe_input_jitter_eps: Optional[float] = None """Add noise to the input tensor by applying jitter with a specified epsilon value.""" moe_token_dropping: bool = False @@ -363,7 +363,7 @@ class TransformerConfig(ModelParallelConfig): moe_per_layer_logging: bool = False """Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.""" - moe_expert_capacity_factor: float = None + moe_expert_capacity_factor: Optional[float] = None """moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token will be dropped. The default is None.""" @@ -381,10 +381,13 @@ class TransformerConfig(ModelParallelConfig): moe_layer_recompute: bool = False """Memory optimization: checkpointing moe_layer to save actiavtion memory.""" + moe_permute_fusion: bool = False + """Fuse token rearrangement ops during token dispatching.""" + ################## # Context Parallel ################## - cp_comm_type: Union[str, List[str]] = None + cp_comm_type: Optional[Union[str, List[str]]] = None """Inter-gpu communication type for context parallelism. str: all layers share same communication type. List[str]: each layer has its separate communication type. @@ -828,6 +831,20 @@ def __post_init__(self): f"variable sequence length, please use alltoall dispatcher instead." ) + if self.moe_permute_fusion: + from megatron.core.transformer.moe.moe_utils import ( + fused_permute, + fused_sort_chunks_by_index, + fused_unpermute, + ) + + if ( + fused_permute is None + or fused_sort_chunks_by_index is None + or fused_unpermute is None + ): + raise ValueError("fused permutation is not available. Please install TE >= 2.1.0.") + if self.cp_comm_type is not None: if isinstance(self.cp_comm_type, list): assert len(self.cp_comm_type) == self.num_layers, ( diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 36a3861a9b..0cc0417857 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2241,6 +2241,8 @@ def _add_moe_args(parser): group.add_argument('--moe-use-upcycling', action='store_true', help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. ' 'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.') + group.add_argument('--moe-permute-fusion', action='store_true', + help='Fuse token rearrangement ops during token dispatching.') return parser diff --git a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py index dd218f57dd..69995c502f 100644 --- a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py @@ -4,7 +4,10 @@ import torch from tests.unit_tests.test_utilities import Utils -from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer +from tests.unit_tests.transformer.moe.test_token_dispatcher import ( + MoEModelTestContainer, + permute_fusion_params, +) def test_placeholder(): @@ -23,9 +26,10 @@ def teardown_method(self, method): @pytest.mark.internal @pytest.mark.timeout(120) @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) + @pytest.mark.parametrize("permute_fusion", permute_fusion_params) @pytest.mark.flaky @pytest.mark.flaky_in_dev - def test_forward_backward(self, tp_size, ep_size): + def test_forward_backward(self, tp_size, ep_size, permute_fusion): container = MoEModelTestContainer( tp_size=tp_size, ep_size=ep_size, @@ -34,6 +38,7 @@ def test_forward_backward(self, tp_size, ep_size): moe_router_topk=2, moe_router_load_balancing_type="aux_loss", moe_token_dispatcher_type="alltoall", + moe_permute_fusion=permute_fusion, ) container.dispatcher_dropless_test() @@ -52,6 +57,7 @@ def test_a2aseq_forward_backward(self, tp_size, ep_size): moe_router_topk=2, moe_router_load_balancing_type="aux_loss", moe_token_dispatcher_type="alltoall_seq", + moe_permute_fusion=False, ) container.dispatcher_dropless_test() @@ -59,9 +65,10 @@ def test_a2aseq_forward_backward(self, tp_size, ep_size): @pytest.mark.internal @pytest.mark.timeout(120) @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) + @pytest.mark.parametrize("permute_fusion", permute_fusion_params) @pytest.mark.flaky @pytest.mark.flaky_in_dev - def test_capacity_forward_backward(self, tp_size, ep_size): + def test_capacity_forward_backward(self, tp_size, ep_size, permute_fusion): container = MoEModelTestContainer( tp_size=tp_size, ep_size=ep_size, @@ -73,6 +80,7 @@ def test_capacity_forward_backward(self, tp_size, ep_size): moe_token_drop_policy="probs", moe_expert_capacity_factor=0.5, moe_pad_expert_input_to_capacity=False, + moe_permute_fusion=permute_fusion, ) container.dispatcher_capacity_test() @@ -80,7 +88,10 @@ def test_capacity_forward_backward(self, tp_size, ep_size): @pytest.mark.internal @pytest.mark.timeout(120) @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) - def test_capacity_padding_forward_backward(self, tp_size, ep_size): + @pytest.mark.parametrize("permute_fusion", permute_fusion_params) + @pytest.mark.flaky + @pytest.mark.flaky_in_dev + def test_capacity_padding_forward_backward(self, tp_size, ep_size, permute_fusion): container = MoEModelTestContainer( tp_size=tp_size, ep_size=ep_size, @@ -92,5 +103,6 @@ def test_capacity_padding_forward_backward(self, tp_size, ep_size): moe_token_drop_policy="probs", moe_expert_capacity_factor=0.6, moe_pad_expert_input_to_capacity=True, + moe_permute_fusion=permute_fusion, ) container.dispatcher_drop_and_pad_test() diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index f8463042b7..1bac502302 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -10,6 +10,7 @@ from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.moe_utils import permute, unpermute from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import is_te_min_version from megatron.training.initialize import _set_random_seed from tests.unit_tests.test_utilities import Utils @@ -69,6 +70,7 @@ def __init__( use_cpu_initialization=kwargs.get("use_cpu_initialization", True), sequence_parallel=tp_size > 1, add_bias_linear=kwargs.get("add_bias_linear", False), + moe_permute_fusion=kwargs.get("moe_permute_fusion", False), ) # init moe layer @@ -220,6 +222,11 @@ def destroy(self): Utils.destroy_model_parallel() +permute_fusion_params = [False] +if is_te_min_version("1.14.0"): + permute_fusion_params.append(True) + + class TestAllgatherDispatcher: def setup_method(self, method): pass @@ -231,9 +238,10 @@ def teardown_method(self, method): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal @pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4), (1, 1)]) + @pytest.mark.parametrize("permute_fusion", permute_fusion_params) @pytest.mark.flaky @pytest.mark.flaky_in_dev - def test_forward_backward(self, tp_size, ep_size): + def test_forward_backward(self, tp_size, ep_size, permute_fusion): container = MoEModelTestContainer( tp_size=tp_size, ep_size=ep_size, @@ -242,6 +250,7 @@ def test_forward_backward(self, tp_size, ep_size): moe_router_topk=2, moe_router_load_balancing_type="aux_loss", moe_token_dispatcher_type="allgather", + moe_permute_fusion=permute_fusion, ) container.dispatcher_dropless_test() @@ -249,12 +258,13 @@ def test_forward_backward(self, tp_size, ep_size): @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal + @pytest.mark.parametrize("permute_fusion", permute_fusion_params) @pytest.mark.parametrize( "tp_size,ep_size,moe_tp_size", [(1, 1, 8), (1, 2, 4), (1, 4, 2), (2, 2, 4)] ) @pytest.mark.flaky @pytest.mark.flaky_in_dev - def test_moe_tp_forward_backward(self, tp_size, ep_size, moe_tp_size): + def test_moe_tp_forward_backward(self, tp_size, ep_size, moe_tp_size, permute_fusion): container = MoEModelTestContainer( tp_size=tp_size, ep_size=ep_size, @@ -266,6 +276,7 @@ def test_moe_tp_forward_backward(self, tp_size, ep_size, moe_tp_size): moe_token_dispatcher_type="allgather", sequence_parallel=True, moe_grouped_gemm=True, + moe_permute_fusion=permute_fusion, use_cpu_initialization=False, )