Support MoE for GPTModelPipe (microsoft#373)
* MOE: Support MoE layer creation for GPTModelPipe

Signed-off-by: Moshe Island <[email protected]>

* MOE: Support MoE aux loss for GPTModelPipe

Propagate the aux loss along GPTModelPipe layers by forwarding the aggregated loss
from each transformer layer to the next.

In addition, add a layer to GPTModelPipe, after the last transformer layer, that
captures the final aggregated aux loss and caches it for use in the loss function
(see the sketch below).

Signed-off-by: Moshe Island <[email protected]>
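
A minimal sketch of that pattern (illustration only, with toy names; the real flow goes through the DeepSpeed pipeline layer specs in the diff below):

import torch

def pipe_layer(hidden, aggregated_moe_loss, this_layer_moe_loss):
    # Each transformer pipe layer adds its own MoE aux loss to the running sum and
    # forwards both the activations and the aggregated loss to the next stage.
    return hidden, aggregated_moe_loss + this_layer_moe_loss

hidden = torch.zeros(16, 2, 64)                  # [s, b, h]
aggregated = torch.tensor(0.0)
for layer_loss in (torch.tensor(0.10), torch.tensor(0.25)):
    hidden, aggregated = pipe_layer(hidden, aggregated, layer_loss)

# A layer placed after the last transformer layer caches the aggregated loss so
# the pipeline loss function can add it (scaled by moe_loss_coeff) to the LM loss.
cached_moe_loss = aggregated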

* MOE: Support display of MoE loss for GPTModelPipe

Signed-off-by: Moshe Island <[email protected]>

* MOE: Verify no pipe/grad partitioning when using MoE

Currently, PipelineEngine supports partitioning only a single grad-requiring tensor.
An MoE model needs to forward both the activations and the aux_loss with grad.
Therefore, until this PipelineEngine limitation is removed, verify that no pipe/grad
partitioning is enabled when using MoE (see the sketch below).

Signed-off-by: Moshe Island <[email protected]>
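
The check amounts to the following sketch (mirroring the gpt_model.py change further down; the helper name is made up, and deepspeed_config_dict is assumed to be the parsed DeepSpeed config):

def verify_no_pipe_grad_partitioning(deepspeed_config_dict, is_moe_model,
                                     pipeline_parallel_size, tensor_parallel_size):
    # Both the activations and the aux_loss require grads, and PipelineEngine cannot
    # partition more than one grad-requiring tensor, so reject the config up front.
    if is_moe_model and pipeline_parallel_size > 1 and tensor_parallel_size > 1:
        pipe_cfg = deepspeed_config_dict.get('pipeline', {})
        assert not pipe_cfg.get('pipe_partitioned', False) and \
               not pipe_cfg.get('grad_partitioned', False), \
            'Pipe and/or Grad partitioning are not supported for MoE model'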

---------

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
mosheisland authored Apr 9, 2024
1 parent 3c5f475 commit bcedecd
Showing 3 changed files with 121 additions and 27 deletions.
61 changes: 56 additions & 5 deletions megatron/model/gpt_model.py
@@ -3,6 +3,7 @@
"""GPT-2 model."""

import torch
from collections import OrderedDict

from megatron import get_args
from megatron.core import mpu, tensor_parallel, sequence_parallel
@@ -16,7 +17,7 @@

from megatron.model import LayerNorm, RMSNorm
from .language_model import EmbeddingPipe
from .transformer import ParallelTransformerLayerPipe, LMHeadPipe
from .transformer import ParallelTransformerLayerPipe, LMHeadPipe, get_num_experts_per_layer
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec


@@ -360,12 +361,33 @@ def _to_float16(inputs):
embedding_weights_in_fp32=args.embedding_weights_in_fp32,
tied_weight_attr='word_embeddings_weight'))

experts_per_layer = get_num_experts_per_layer(args.num_experts, args.num_layers, args.expert_interval)
self.is_moe_model = any(n_experts > 1 for n_experts in experts_per_layer)

# Currently, PipelineEngine does not support partitioning more than one tensor that
# requires grads (pipe and/or grad partitioned).
# When using MoE, two tensors are passed along the pipeline stages and both require grads.
# Therefore, verify that neither pipe_partitioned nor grad_partitioned is enabled.
if self.is_moe_model and args.pipeline_model_parallel_size > 1 and args.tensor_model_parallel_size > 1:
pipe_partitioned_enabled = args.deepspeed_config_dict.get('pipeline', {}).get('pipe_partitioned', False)
grad_partitioned_enabled = args.deepspeed_config_dict.get('pipeline', {}).get('grad_partitioned', False)
assert not pipe_partitioned_enabled and not grad_partitioned_enabled, \
'Pipe and/or Grad partitioning are not supported for MoE model'

for layer_idx in range(args.num_layers):
self.specs.append(
LayerSpec(ParallelTransformerLayerPipe,
config,
layer_number=layer_idx,
self_attn_mask_type=AttnMaskType.causal))
config,
layer_number=layer_idx,
self_attn_mask_type=AttnMaskType.causal,
num_experts=experts_per_layer[layer_idx],
input_aggregated_moe_loss=(self.is_moe_model and layer_idx > 0),
return_aggregated_moe_loss=self.is_moe_model))

# if model has experts, add a layer to get and cache the aggregated moe loss from the
# last transformer layer
if self.is_moe_model:
self.specs.append(self._calculate_moe_loss)

# Final layernorm after transformer layers
if args.normalization == 'layernorm':
@@ -404,6 +426,11 @@ def _logits_helper(embedding, lm_output):
if args.fp16 or args.bf16:
self.specs.append(float16_to_fp32)

# Cache losses
self.moe_loss = None
self.last_lm_loss = None # detached, for display only
self.last_moe_loss = None # detached, for display only

if args.checkpoint_activations:
interval = args.checkpoint_num_layers
elif args.recompute_granularity == "full" and args.recompute_method == 'uniform':
@@ -418,10 +445,34 @@ def _logits_helper(embedding, lm_output):
num_dp=mpu.get_data_parallel_world_size())

super().__init__(layers=self.specs,
loss_fn=CrossEntropy,
loss_fn=self.loss_func,
topology=topo,
activation_checkpoint_interval=interval,
partition_method='type:transformer')

def _calculate_moe_loss(self, inputs):
""" Calculate MoE auxiliary loss """
assert isinstance(inputs, tuple) and len(inputs) == 2
hidden, aggregated_moe_loss = inputs[0], inputs[1]
args = get_args()
self.moe_loss = aggregated_moe_loss * args.moe_loss_coeff
return hidden

def loss_func(self, output, labels):
loss = CrossEntropy(output, labels)
self.last_lm_loss = loss.clone().detach()
if self.moe_loss is not None:
loss += self.moe_loss
self.last_moe_loss = self.moe_loss.clone().detach()
return loss

def universal_checkpoint_info(self):
return UniversalCheckpointInfo(using_model_pipe=True).get()

def get_additional_losses(self):
if not self.is_moe_model:
return None
return OrderedDict({
'lm loss': self.last_lm_loss,
'moe loss': self.last_moe_loss
})
79 changes: 58 additions & 21 deletions megatron/model/transformer.py
@@ -1229,7 +1229,8 @@ def forward(self, hidden_states, attention_mask=None,
retriever_output=None,
retriever_attn_mask=None,
inference_params=None,
rotary_pos_emb=None):
rotary_pos_emb=None,
aggregated_moe_loss=None):
# hidden_states: [s, b, h]

# Layer norm at the beginning of the transformer layer.
@@ -1321,6 +1322,10 @@ def forward(self, hidden_states, attention_mask=None,
else:
mlp_output, moe_loss, _ = self.mlp(layernorm_output)

# when aggregated_moe_loss is received, the returned moe_loss is the aggregated moe loss
if aggregated_moe_loss is not None:
moe_loss += aggregated_moe_loss

# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
@@ -1381,23 +1386,51 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer):
If no mask is provided, the module will query `self._args.attn_mask`
for the mask and only return `super().forward(...)`
"""
def __init__(self, config,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0., num_experts=1,
input_aggregated_moe_loss=False, return_aggregated_moe_loss=False):
self.input_aggregated_moe_loss = input_aggregated_moe_loss
self.return_aggregated_moe_loss = return_aggregated_moe_loss
super().__init__(config, layer_number, layer_type, self_attn_mask_type, drop_path_rate, num_experts)

def forward(self, inputs, **kwargs):
assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
if not hasattr(self, '_args'):
self._args = get_args()
rotary_pos_emb = self._args.rotary_pos_emb if self._args.use_rotary_position_embeddings else None
if torch.is_tensor(inputs) or len(inputs) == 1:
assert not self.input_aggregated_moe_loss, f'Expecting an input tuple of size >= 2'
# No attention mask forwarded, search for args.attn_mask
hidden_states, attention_mask = inputs, self._args.attn_mask
# HACK: currently MoE model does not support pipeline parallel, so
# here we just ignore the moe_loss returned by forward()
return super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)[0]
elif len(inputs) == 2:
# Attention mask is an activation.
hidden_states, attention_mask = inputs[0], inputs[1]
# HACK: currently MoE model does not support pipeline parallel, so
# here we just ignore the moe_loss returned by forward()
return super().forward(*inputs, **kwargs, rotary_pos_emb=rotary_pos_emb)[0], attention_mask
output, moe_loss = super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)
return (output, moe_loss) if self.return_aggregated_moe_loss else output
elif len(inputs) in (2, 3):
# Attention mask and aggregated_moe_loss can both be activations.
return_attention_mask = False
if len(inputs) == 2:
if self.input_aggregated_moe_loss:
hidden_states, aggregated_moe_loss = inputs[0], inputs[1]
attention_mask = self._args.attn_mask
else:
hidden_states, attention_mask = inputs[0], inputs[1]
return_attention_mask = True
else:
hidden_states, attention_mask, aggregated_moe_loss = inputs[0], inputs[1], inputs[2]

# Forward aggregated_moe_loss to ParallelTransformerLayer for further accumulation
if self.input_aggregated_moe_loss:
kwargs.update({'aggregated_moe_loss': aggregated_moe_loss})

output, moe_loss = super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)

ret = (output, )
if return_attention_mask:
ret += (attention_mask, )
if self.return_aggregated_moe_loss:
ret += (moe_loss, )
return ret
else:
raise RuntimeError('Received more inputs than understood.')

@@ -1499,6 +1532,19 @@ def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
return default_layer_type


def get_num_experts_per_layer(num_experts: list, num_layers: int, expert_interval: int, offset: int = 0) -> list:
assert len(num_experts) == 1 or len(num_experts) == num_layers // expert_interval, \
'num_experts must be either a single value or a list of the same length as the number of MoE layers'
if len(num_experts) == 1:
num_experts = num_experts * (num_layers // expert_interval)
experts_per_layer = []
for i in range(num_layers):
layer_num = i + 1 + offset
n_e = num_experts[(layer_num-1) // expert_interval] if layer_num % expert_interval == 0 else 1
experts_per_layer.append(n_e)
return experts_per_layer
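
For illustration, a quick check of the mapping this helper produces (example values chosen here, not taken from the change):

# 4 layers, experts on every second layer, a single num_experts value of 8:
# layers 2 and 4 get 8 experts each, the remaining layers stay dense (1 expert).
assert get_num_experts_per_layer([8], num_layers=4, expert_interval=2) == [1, 8, 1, 8]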


class ParallelTransformer(MegatronModule):
"""Transformer class."""

@@ -1682,21 +1728,12 @@ def build_layer(layer_number, n_e):
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
assert len(num_experts) == 1 or len(num_experts) == args.num_layers // args.expert_interval, \
'num_experts must be either a single value or a list of the same length as the number of MoE layers'

# Create the list of MoE experts
if len(num_experts) == 1:
num_experts = num_experts * (args.num_layers // args.expert_interval)

# Build the layers
self.layers = []
experts_per_layer = get_num_experts_per_layer(num_experts, self.num_layers, args.expert_interval, offset)
for i in range(self.num_layers):
layer_num = i + 1 + offset
if layer_num % args.expert_interval == 0:
n_e = num_experts[(layer_num-1) // args.expert_interval]
else:
n_e = 1
n_e = experts_per_layer[i]
self.layers.append(build_layer(layer_num, n_e))
self.layers = torch.nn.ModuleList(self.layers)

8 changes: 7 additions & 1 deletion megatron/training.py
@@ -10,6 +10,7 @@
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from collections import OrderedDict
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
@@ -667,8 +668,13 @@ def train_step(forward_step_func, data_iterator,
num_zeros_in_grad = 0
assert isinstance(model[0], deepspeed.PipelineEngine)
loss = model[0].train_batch(data_iter=data_iterator)
additional_losses = model[0].get_additional_losses()
loss_key = 'lm loss' if additional_losses is None else 'loss' # use "lm loss" for backward compatibility
loss_dict = OrderedDict({loss_key: loss})
if additional_losses is not None:
loss_dict.update(additional_losses)
grad_norm = model[0].get_global_grad_norm()
return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
return loss_dict, skipped_iter, grad_norm, num_zeros_in_grad

# Set grad to zero.
if not args.deepspeed:
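
With MoE active, the dict returned for logging then looks roughly like the following (illustration only; tensor values are made up):

from collections import OrderedDict
import torch

# Without MoE: get_additional_losses() returns None and the key stays 'lm loss'.
loss_dict = OrderedDict({'lm loss': torch.tensor(2.31)})

# With MoE: the total training loss is reported under 'loss', followed by the
# detached per-component losses returned by get_additional_losses().
loss_dict = OrderedDict({'loss': torch.tensor(2.35)})
loss_dict.update(OrderedDict({'lm loss': torch.tensor(2.31),
                              'moe loss': torch.tensor(0.04)}))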
