diff --git a/examples/llama/conf/config.yaml b/examples/llama/conf/config.yaml index 89ae03e9b..77c284745 100644 --- a/examples/llama/conf/config.yaml +++ b/examples/llama/conf/config.yaml @@ -49,4 +49,4 @@ action: run hydra: run: - dir: ${experiment.exp_dir}/hydra \ No newline at end of file + dir: ${experiment.exp_dir}/hydra diff --git a/examples/llama/conf/config_hetero.yaml b/examples/llama/conf/config_hetero.yaml new file mode 100644 index 000000000..1addf67cc --- /dev/null +++ b/examples/llama/conf/config_hetero.yaml @@ -0,0 +1,25 @@ +defaults: + - train: train_llama2_7b_tp_hetero + - _self_ + +experiment: + exp_name: llama2_tp_hetero + exp_dir: ./outputs_llama2_tp_hetero + task: + type: train + backend: megatron + entrypoint: ./flagscale/train/hetero/train_llama.py + runner: + backend: torchrun + nnodes: 1 + nproc_per_node: 8 + hostfile: hostfile + envs: + CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml b/examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml new file mode 100644 index 000000000..90995e2d9 --- /dev/null +++ b/examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml @@ -0,0 +1,67 @@ +system: + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 3 + disable_bias_linear: True + use_flash_attn: True + sequence_parallel: True + use_distributed_optimizer: True + hetero_mode: pp + hetero_device_types: A100 + hetero_current_device_type: A100 + hetero_pipeline_stages: [3,16,8,8] + process_meshes: [4,1,1,2,1,2] + precision: + bf16: True + initial_loss_scale: 16384 + min_loss_scale: 1.0 + logging: + log_interval: 1 + checkpoint: + save_interval: 100 + +model: + use_mcore_models: True + transformer_impl: transformer_engine + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 11008 + num_attention_heads: 32 + seq_length: 4096 + group_query_attention: False + num_query_groups: 8 + max_position_embeddings: 4096 + norm_epsilon: 1e-5 + use_rotary_position_embeddings: True + no_position_embedding: True + swiglu: True + normalization: RMSNorm + untie_embeddings_and_output_weights: True + init_method_std: 0.02 + attention_dropout: 0.0 + hidden_dropout: 0.0 + weight_decay: 0.1 + clip_grad: 1.0 + train_iters: 30 + eval_iters: 0 + eval_interval: 2000 + micro_batch_size: 1 + global_batch_size: 32 + + optimizer: + weight_decay: 1e-2 + adam_beta1: 0.9 + adam_beta2: 0.95 + lr_scheduler: + lr: 0.00015 + min_lr: 1.0e-5 + lr_warmup_fraction: .01 + lr_decay_iters: 1 + lr_decay_style: cosine + +data: + data_path: ${data_path:??} + split: 1 + tokenizer: + tokenizer_type: Llama2Tokenizer + tokenizer_model: examples/llama/tokenizer.model + vocab_size: 32000 diff --git a/flagscale/train/hetero/__init__.py b/flagscale/train/hetero/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flagscale/train/hetero/initialize.py b/flagscale/train/hetero/initialize.py new file mode 100644 index 000000000..4887f1909 --- /dev/null +++ b/flagscale/train/hetero/initialize.py @@ -0,0 +1,183 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
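[Editor note, not part of the diff] The heterogeneous layout in the new train_llama2_7b_tp_hetero.yaml above is compact and worth decoding. The reading below is inferred from the values and from how the schedule code later in this diff consumes them, so treat it as an assumption rather than documented semantics: process_meshes seems to list (tensor, data, pipeline) sizes per device mesh, and hetero_pipeline_stages a stage count followed by per-stage layer counts.

# Sketch (assumed semantics, see note above); plain Python, no FlagScale imports needed.
process_meshes = [4, 1, 1, 2, 1, 2]        # two meshes: (tp=4, dp=1, pp=1) and (tp=2, dp=1, pp=2)
hetero_pipeline_stages = [3, 16, 8, 8]     # 3 pipeline stages holding 16, 8 and 8 of the 32 layers

meshes = [tuple(process_meshes[i:i + 3]) for i in range(0, len(process_meshes), 3)]
num_stages, *layers_per_stage = hetero_pipeline_stages

assert sum(tp * dp * pp for tp, dp, pp in meshes) == 8    # matches nproc_per_node: 8
assert sum(pp for _, _, pp in meshes) == num_stages == 3  # matches pipeline_model_parallel_size: 3
assert sum(layers_per_stage) == 32                        # matches model.num_layers: 32

Under this reading, stage 0 runs with TP=4 and stages 1 and 2 run with TP=2, which is exactly the TP-heterogeneous pipeline that the Python files below implement.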
+ +import torch +from datetime import timedelta +from megatron.training import get_args +from megatron.core import parallel_state as mpu +from megatron.training.global_vars import set_hetero_context, get_hetero_context +from megatron.training.arguments import parse_args, validate_args +from megatron.training.yaml_arguments import validate_yaml +from megatron.training.checkpointing import load_args_from_checkpoint +from megatron.training.global_vars import set_global_variables, set_global_writers +from megatron.training.initialize import _set_random_seed, _init_autoresume, _compile_dependencies, _initialize_tp_communicators +from megatron.training.utils import save_checkpoint_info + +def initialize_megatron( + extra_args_provider=None, + args_defaults={}, + ignore_unknown_args=False, + allow_no_cuda=False, + skip_mpu_initialization=False, +): + """Set global variables, initialize distributed, and + set autoresume and random seeds. + `allow_no_cuda` should not be set unless using megatron for cpu only + data processing. In general this arg should not be set unless you know + what you are doing. + Returns a function to finalize distributed env initialization + (optionally, only when args.lazy_mpu_init == True) + """ + if not allow_no_cuda: + # Make sure cuda is available. + assert torch.cuda.is_available(), "Megatron requires CUDA." + + # Parse arguments + args = parse_args(extra_args_provider, ignore_unknown_args) + + if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): + assert args.load is not None, "--use-checkpoints-args requires --load argument" + load_args_from_checkpoint(args) + + if args.yaml_cfg is not None: + args = validate_yaml(args, args_defaults) + else: + validate_args(args, args_defaults) + + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args) + + # torch.distributed initialization + def finish_mpu_init(): + args = get_args() + # Pytorch distributed. + _initialize_distributed() + + # Random seeds for reproducibility. + if args.rank == 0: + print("> setting random seeds to {} ...".format(args.seed)) + _set_random_seed(args.seed, args.data_parallel_random_init) + + # Set tensorboard writer and wandb writer. + set_global_writers(args) + + + if skip_mpu_initialization: + return None + + args = get_args() + if args.lazy_mpu_init: + # TODO is this still a necessary option? + args.use_cpu_initialization = True + # delayed initialization of DDP-related stuff + # We only set basic DDP globals + mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) + # and return function for external DDP manager + # to call when it has DDP initialized + mpu.set_tensor_model_parallel_rank(args.rank) + return finish_mpu_init + else: + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. 
+ _compile_dependencies() + + save_checkpoint_info(args.save) + + if args.tp_comm_overlap: + _initialize_tp_communicators() + + # No continuation function + return None + +def _initialize_distributed(): + """Initialize torch.distributed and core model parallel.""" + args = get_args() + + device_count = torch.cuda.device_count() + if torch.distributed.is_initialized(): + + if args.rank == 0: + print( + "torch distributed is already initialized, " + "skipping initialization ...", + flush=True, + ) + args.rank = torch.distributed.get_rank() + args.world_size = torch.distributed.get_world_size() + + else: + + if args.rank == 0: + print("> initializing torch distributed ...", flush=True) + # Manually set the device ids. + if device_count > 0: + device = args.rank % device_count + if args.local_rank is not None: + assert ( + args.local_rank == device + ), "expected local-rank to be the same as rank % device-count." + else: + args.local_rank = device + torch.cuda.set_device(device) + # Call the init process + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, + rank=args.rank, + timeout=timedelta(minutes=args.distributed_timeout_minutes), + ) + + if args.num_process_meshes == None: + if args.hetero_mode is not None: + # Build the heterogenous context after torch.distributed is initialized and + # before model parallel is initialized. + set_hetero_context(args) + if torch.distributed.get_rank() == 0: + print(get_hetero_context(), flush=True) + + # Set the tensor model-parallel, pipeline model-parallel, and + # data-parallel communicators. + if device_count > 0: + if mpu.model_parallel_is_initialized(): + print("model parallel is already initialized") + else: + if args.num_process_meshes != None: + mpu.initialize_hetero_model_parallel( + args, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + context_parallel_size=args.context_parallel_size, + expert_model_parallel_size=args.expert_model_parallel_size, + distributed_timeout_minutes=args.distributed_timeout_minutes, + nccl_communicator_config_path=args.nccl_communicator_config_path, + ) + else: + mpu.initialize_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + context_parallel_size=args.context_parallel_size, + expert_model_parallel_size=args.expert_model_parallel_size, + distributed_timeout_minutes=args.distributed_timeout_minutes, + nccl_communicator_config_path=args.nccl_communicator_config_path, + hetero_mode=args.hetero_mode, + order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp', + ) + if args.rank == 0: + print( + f"> initialized tensor model parallel with size " + f"{mpu.get_tensor_model_parallel_world_size()}" + ) + print( + f"> initialized pipeline model parallel with size " + f"{mpu.get_pipeline_model_parallel_world_size()}" + ) \ No newline at end of file diff --git a/flagscale/train/hetero/p2p_communication.py b/flagscale/train/hetero/p2p_communication.py new file mode 100644 index 000000000..875445a9f --- /dev/null +++ b/flagscale/train/hetero/p2p_communication.py @@ -0,0 +1,343 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
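[Editor note, not part of the diff] This file implements point-to-point pipeline communication between stages whose tensor-parallel sizes differ, which is why every send or receive below can fan out to, or gather from, more than one peer rank. With sequence parallelism enabled, each stage holds a seq_length // tp slice of the sequence, so crossing a TP boundary is mostly chunk/concat bookkeeping along the sequence dimension. The following shape-only sketch uses the 7B config from this diff (seq 4096, hidden 4096, micro batch 1) and assumes a TP=4 stage feeding a TP=2 stage; no process groups or real communication are involved.

import torch

seq, micro_batch, hidden = 4096, 1, 4096
tp_send, tp_recv = 4, 2                      # e.g. stage 0 (TP=4) sending forward to stage 1 (TP=2)

# Each TP=4 sender holds one sequence-parallel shard of the activations.
send_shard = torch.empty(seq // tp_send, micro_batch, hidden)

# Each TP=2 receiver collects tp_send // tp_recv such shards and concatenates
# them along the sequence dimension to rebuild its own, larger shard.
recv_input = torch.cat([send_shard, send_shard], dim=0)
assert recv_input.shape == (seq // tp_recv, micro_batch, hidden)

# Backward direction: the TP=2 rank chunks its gradient back into pieces that
# match what each TP=4 peer originally sent.
grad_chunks = torch.chunk(torch.empty_like(recv_input), tp_send // tp_recv, dim=0)
assert all(g.shape == (seq // tp_send, micro_batch, hidden) for g in grad_chunks)

Which peer receives which chunk, and whether the same-TP or cross-TP process group is used for a given pair of ranks, is what _batched_p2p_ops_tp_hetero below resolves per operation.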
+ +import operator +from functools import reduce +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from megatron.core import ModelParallelConfig +#from flagscale.hetero import parallel_state +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_next_rank, + get_pipeline_model_parallel_prev_rank, + get_pipeline_model_parallel_rank, + get_same_tp_pipeline_model_parallel_group, + get_diff_tp_pipeline_model_parallel_group, + get_same_tp_pipeline_model_parallel_next_rank, + get_same_tp_pipeline_model_parallel_prev_rank, + get_diff_tp_pipeline_model_parallel_next_rank, + get_diff_tp_pipeline_model_parallel_prev_rank, + is_pipeline_first_stage, + is_pipeline_last_stage +) + +# Types +Shape = Union[List[int], torch.Size] + +def _batched_p2p_ops_tp_hetero( + *, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + send_group_ids: int, + recv_group_ids: int, +): + reqs = [] + max_group_ids = max(send_group_ids, recv_group_ids) + group_ops = [[] for _ in range(max_group_ids)] + if tensor_send_prev is not None: + if len(tensor_send_prev) == send_group_ids: + for i in range(send_group_ids): + group_ops[i].append(torch.distributed.P2POp( + torch.distributed.isend, + tensor_send_prev[i], + get_pipeline_model_parallel_prev_rank()[i], + get_same_tp_pipeline_model_parallel_group() if get_pipeline_model_parallel_prev_rank()[i] == get_same_tp_pipeline_model_parallel_prev_rank() \ + else get_diff_tp_pipeline_model_parallel_group()[i] + )) + else: + for i in range(send_group_ids): + group_ops[i].append(torch.distributed.P2POp( + torch.distributed.isend, + tensor_send_prev[0], + get_pipeline_model_parallel_prev_rank()[i], + get_same_tp_pipeline_model_parallel_group() if get_pipeline_model_parallel_prev_rank()[i] == get_same_tp_pipeline_model_parallel_prev_rank() \ + else get_diff_tp_pipeline_model_parallel_group()[i] + )) + if tensor_recv_prev is not None: + for i in range(recv_group_ids): + group_ops[i].append(torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv_prev[i], + get_pipeline_model_parallel_prev_rank()[i], + get_same_tp_pipeline_model_parallel_group() if get_pipeline_model_parallel_prev_rank()[i] == get_same_tp_pipeline_model_parallel_prev_rank() \ + else get_diff_tp_pipeline_model_parallel_group()[i] + )) + if tensor_send_next is not None: + if len(tensor_send_next) == send_group_ids: + for i in range(send_group_ids): + group_ops[i].append(torch.distributed.P2POp( + torch.distributed.isend, + tensor_send_next[i], + get_pipeline_model_parallel_next_rank()[i], + get_same_tp_pipeline_model_parallel_group() if get_pipeline_model_parallel_next_rank()[i] == get_same_tp_pipeline_model_parallel_next_rank() \ + else get_diff_tp_pipeline_model_parallel_group()[i] + )) + else: + for i in range(send_group_ids): + group_ops[i].append(torch.distributed.P2POp( + torch.distributed.isend, + tensor_send_next[0], + get_pipeline_model_parallel_next_rank()[i], + get_same_tp_pipeline_model_parallel_group() if get_pipeline_model_parallel_next_rank()[i] == get_same_tp_pipeline_model_parallel_next_rank() \ + else get_diff_tp_pipeline_model_parallel_group()[i] + )) + if tensor_recv_next is not None: + for i in range(recv_group_ids): + group_ops[i].append(torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv_next[i], + get_pipeline_model_parallel_next_rank()[i], + 
get_same_tp_pipeline_model_parallel_group() if get_pipeline_model_parallel_next_rank()[i] == get_same_tp_pipeline_model_parallel_next_rank() \ + else get_diff_tp_pipeline_model_parallel_group()[i] + )) + + for i in range(len(group_ops)): + if len(group_ops[i]) > 0: + reqs.append(torch.distributed.batch_isend_irecv(group_ops[i])) + + return reqs + +def _communicate_tp_hetero( + *, + tensor_send_next: Optional[torch.Tensor], + tensor_send_prev: Optional[torch.Tensor], + recv_prev: bool, + recv_next: bool, + tensor_shapes: list, + group_ids: Union[int, List[int]], + config: ModelParallelConfig, + wait_on_reqs: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + + tensor_recv_prev = None + tensor_recv_next = None + + recv_prev_shape = tensor_shapes + recv_next_shape = tensor_shapes + + ''' + if not config.variable_seq_lengths: + recv_prev_shape = tensor_shape + recv_next_shape = tensor_shape + else: + recv_prev_shape, recv_next_shape = _communicate_shapes( + tensor_send_next, tensor_send_prev, recv_prev, recv_next, config + ) + ''' + + if recv_prev: + if config.pipeline_dtype is None: + raise RuntimeError("pipeline_dtype must be provided if recv_prev is True") + if tensor_shapes is None: + raise RuntimeError( + "tensor_shape must be specified if recv_prev is True. " + "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)" + ) + if tensor_send_prev != None: + assert isinstance(group_ids, list), 'send_backward and recv_forward need 2 group_ids!' + send_group_ids, recv_group_ids = group_ids[0], group_ids[1] + else: + send_group_ids, recv_group_ids = 0, group_ids + tensor_recv_prev = [] + for i in range(recv_group_ids): + tensor_recv_prev.append(torch.empty( + recv_prev_shape[i] if len(recv_prev_shape) == recv_group_ids else recv_prev_shape[0], + requires_grad=True, + device=torch.cuda.current_device(), + dtype=config.pipeline_dtype, + )) + else: + if tensor_send_prev != None: + send_group_ids, recv_group_ids = group_ids, 0 + + if recv_next: + if config.pipeline_dtype is None: + raise RuntimeError("dtype must be provided if recv_next is True") + if tensor_shapes is None: + raise RuntimeError( + "tensor_shape must be specified if recv_next is True. " + "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)" + ) + if tensor_send_next != None: + assert isinstance(group_ids, list), 'send_forward and recv_backward need 2 group_ids!' + send_group_ids, recv_group_ids = group_ids[0], group_ids[1] + else: + send_group_ids, recv_group_ids = 0, group_ids + tensor_recv_next = [] + for i in range(recv_group_ids): + tensor_recv_next.append(torch.empty( + recv_next_shape[i] if len(recv_next_shape) == recv_group_ids else recv_next_shape[0], + requires_grad=True, + device=torch.cuda.current_device(), + dtype=config.pipeline_dtype, + )) + else: + if tensor_send_next != None: + send_group_ids, recv_group_ids = group_ids, 0 + + #Note: normal p2p_ops hang + #p2p_func = _p2p_ops_tp_hetero + assert wait_on_reqs + p2p_func = _batched_p2p_ops_tp_hetero + reqs = p2p_func( + tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + send_group_ids=send_group_ids, + recv_group_ids=recv_group_ids, + ) + + if wait_on_reqs and len(reqs) > 0: + for req in reqs: + if isinstance(req, list): + for op in req: + op.wait() + else: + req.wait() + reqs = None + + if config.batch_p2p_comm and config.batch_p2p_sync: + # To protect against race condition when using batch_isend_irecv(). 
+ # User should assert that we have a modern enough PyTorch to not need this + torch.cuda.synchronize() + + return tensor_recv_prev, tensor_recv_next, reqs + +def tp_hetero_recv_forward(tensor_shapes, group_ids, config): + """ Receive tensor from previous rank in pipeline (forward receive). + + See _communicate for argument details. + """ + + if is_pipeline_first_stage(): + input_tensor = None + else: + if config.timers is not None: + config.timers('forward-recv', log_level=2).start() + input_tensor, _, _ = _communicate_tp_hetero( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False, + tensor_shapes=tensor_shapes, + group_ids=group_ids, + config=config, + ) + if config.timers is not None: + config.timers('forward-recv').stop() + return input_tensor + +def tp_hetero_recv_backward(tensor_shapes, group_ids, config): + """Receive tensor from next rank in pipeline (backward receive). + + See _communicate for argument details. + """ + if is_pipeline_last_stage(): + output_tensor_grad = None + else: + if config.timers is not None: + config.timers('backward-recv', log_level=2).start() + _, output_tensor_grad, _ = _communicate_tp_hetero( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + tensor_shapes=tensor_shapes, + group_ids=group_ids, + config=config, + ) + if config.timers is not None: + config.timers('backward-recv').stop() + return output_tensor_grad + +def tp_hetero_send_forward(output_tensors, group_ids, config): + """Send tensor to next rank in pipeline (forward send). + + See _communicate for argument details. + """ + + if not is_pipeline_last_stage(): + if config.timers is not None: + config.timers('forward-send', log_level=2).start() + _communicate_tp_hetero( + tensor_send_next=output_tensors, + tensor_send_prev=None, + recv_prev=False, + recv_next=False, + tensor_shapes=None, + group_ids=group_ids, + config=config, + ) + if config.timers is not None: + config.timers('forward-send').stop() + +def tp_hetero_send_backward(input_tensor_grads, group_ids, config): + """Send tensor to previous rank in pipeline (backward send). + + See _communicate for argument details. + """ + if not is_pipeline_first_stage(): + if config.timers is not None: + config.timers('backward-send', log_level=2).start() + _communicate_tp_hetero( + tensor_send_next=None, + tensor_send_prev=input_tensor_grads, + recv_prev=False, + recv_next=False, + tensor_shapes=None, + group_ids=group_ids, + config=config, + ) + if config.timers is not None: + config.timers('backward-send').stop() + +def tp_hetero_send_forward_recv_backward(output_tensors, tensor_shapes, send_group_ids, recv_group_ids, config): + """Batched send and recv with next rank in pipeline. + + See _communicate for argument details. + """ + if is_pipeline_last_stage(): + output_tensor_grad = None + else: + if config.timers is not None: + config.timers('forward-send-backward-recv', log_level=2).start() + _, output_tensor_grad, _ = _communicate_tp_hetero( + tensor_send_next=output_tensors, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + tensor_shapes=tensor_shapes, + group_ids=[send_group_ids, recv_group_ids], + config=config, + ) + if config.timers is not None: + config.timers('forward-send-backward-recv').stop() + return output_tensor_grad + +def tp_hetero_send_backward_recv_forward(input_tensor_grads, tensor_shapes, send_group_ids, recv_group_ids, config): + """Batched send and recv with previous rank in pipeline. + + See _communicate for argument details. 
+ """ + if is_pipeline_first_stage(): + input_tensor = None + else: + if config.timers is not None: + config.timers('backward-send-forward-recv', log_level=2).start() + input_tensor, _, _ = _communicate_tp_hetero( + tensor_send_next=None, + tensor_send_prev=input_tensor_grads, + recv_prev=True, + recv_next=False, + tensor_shapes=tensor_shapes, + group_ids=[send_group_ids, recv_group_ids], + config=config, + ) + if config.timers is not None: + config.timers('backward-send-forward-recv').stop() + return input_tensor \ No newline at end of file diff --git a/flagscale/train/hetero/schedules.py b/flagscale/train/hetero/schedules.py new file mode 100644 index 000000000..ca8b4366d --- /dev/null +++ b/flagscale/train/hetero/schedules.py @@ -0,0 +1,611 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import contextlib +from typing import Callable, Iterator, List, Optional, Union + +import torch +from torch.autograd.variable import Variable + +from megatron.core import parallel_state +from megatron.core.enums import ModelType +from flagscale.train.hetero import p2p_communication +from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler +from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type +from megatron.core.pipeline_parallel.schedules import ( + forward_backward_pipelining_with_interleaving, + forward_backward_pipelining_without_interleaving, + forward_backward_no_pipelining, + forward_step, + backward_step, + check_first_val_step, + deallocate_output_tensor + ) + +# Types +Shape = Union[List[int], torch.Size] + +def get_forward_backward_func(): + """Retrieves the appropriate forward_backward function given the + configuration of parallel_state. + + Returns a function that will perform all of the forward and + backward passes of the model given the pipeline model parallel + world size and virtual pipeline model parallel world size in the + global parallel_state. + + Note that if using sequence parallelism, the sequence length component of + the tensor shape is updated to original_sequence_length / + tensor_model_parallel_world_size. + + The function returned takes the following arguments: + + forward_step_func (required): A function that takes a data + iterator and a model as its arguments and return the model's + forward output and the loss function. The loss function should + take one torch.Tensor and return a torch.Tensor of loss and a + dictionary of string -> torch.Tensor. + + A third argument, checkpoint_activations_microbatch, indicates + that the activations for this microbatch should be + checkpointed. A None value for this argument indicates that + the default from the configuration should be used. This is + used when the + num_microbatches_with_partial_activation_checkpoints is used. + + For example: + + def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + def forward_step(data_iterator, model): + data, loss_mask = next(data_iterator) + output = model(data) + return output, partial(loss_func, loss_mask) + + + forward_backward_func(forward_step_func=forward_step, ...) + + + data_iterator (required): an iterator over the data, will be + passed as is to forward_step_func. 
Expected to be a list of + iterators in the case of interleaved pipeline parallelism. + + model (required): the actual model. Expected to be a list of modules in the case of interleaved + pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule. + + num_microbatches (int, required): + The number of microbatches to go through + + seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack + transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths + in the config is True. Otherwise, each microbatch in the current global batch size must use + this sequence length. + + micro_batch_size (int, required): The number of sequences in a microbatch. + + decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack + transformer. This is ignored for a single-stack transformer. + + forward_only (optional, default = False): Perform only the forward step + + collect_non_loss_data (optional, bool, default=False): TODO + + first_val_step (bool, optional): Is the first step of the validation phase. Used by + Transformer Engine modules to only update their fp8 weights only on the first validation step. + + """ + pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() + if pipeline_model_parallel_size > 1: + if isinstance(parallel_state.get_pipeline_model_parallel_group(), list): + assert parallel_state.get_virtual_pipeline_model_parallel_world_size() == None, \ + 'vp not supported for hetero tp mode' + forward_backward_func = forward_backward_pipelining_without_interleaving_hetero + else: + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + forward_backward_func = forward_backward_pipelining_with_interleaving + else: + forward_backward_func = forward_backward_pipelining_without_interleaving + else: + forward_backward_func = forward_backward_no_pipelining + return forward_backward_func + +def get_tp_hetero_tensor_shapes( + *, + send_recv: bool, + model_type: ModelType, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int, + config, +): + # send: True, recv: False + assert model_type == ModelType.encoder_or_decoder, \ + 'Only support encoder or decoder model type for tp hetero mode for now!' 
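# Reviewer annotation (not part of the diff): a worked example of what this function
# returns with the values from train_llama2_7b_tp_hetero.yaml (seq_length=4096,
# micro_batch_size=1, hidden_size=4096, sequence_parallel=True):
#   - sending from a TP=4 stage to a TP=2 stage: tp_scale = 4/2 = 2.0 and a single
#     shape (1024, 1, 4096), i.e. the sender's own sequence-parallel shard;
#   - sending from a TP=2 stage to a TP=4 stage: tp_scale = 2/4 = 0.5 and two shapes
#     of (1024, 1, 4096), one per receiving peer, since the sequence is divided by the
#     larger TP size;
#   - matching TP sizes: tp_scale = 1.0 and one shape (seq_length // tp, 1, 4096).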
+ #TODO: cp support + + tensor_shapes = [] + tp_size_of_each_pipeline_stage = parallel_state.get_tensor_model_parallel_size_of_each_pipeline_stage() + pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() + pipeline_size = parallel_state.get_pipeline_model_parallel_world_size() + tp_size_of_current_pipline_rank = tp_size_of_each_pipeline_stage[pipeline_rank] + + # Send + if send_recv: + tp_size_of_next_pipeline_rank = tp_size_of_each_pipeline_stage[(pipeline_rank + 1) % pipeline_size] + tp_scale = tp_size_of_current_pipline_rank / tp_size_of_next_pipeline_rank + if config.sequence_parallel: + if tp_size_of_current_pipline_rank == tp_size_of_next_pipeline_rank: + seq_length = seq_length // tp_size_of_current_pipline_rank + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + elif tp_size_of_current_pipline_rank > tp_size_of_next_pipeline_rank: + seq_length = seq_length // tp_size_of_current_pipline_rank + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + seq_length = seq_length // tp_size_of_next_pipeline_rank + for i in range(tp_size_of_next_pipeline_rank // tp_size_of_current_pipline_rank): + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + # Recv + else: + tp_size_of_prev_pipeline_rank = tp_size_of_each_pipeline_stage[(pipeline_rank - 1) % pipeline_size] + tp_scale = tp_size_of_prev_pipeline_rank / tp_size_of_current_pipline_rank + if config.sequence_parallel: + if tp_size_of_current_pipline_rank == tp_size_of_prev_pipeline_rank: + seq_length = seq_length // tp_size_of_current_pipline_rank + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + elif tp_size_of_current_pipline_rank > tp_size_of_prev_pipeline_rank: + seq_length = seq_length // tp_size_of_current_pipline_rank + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + seq_length = seq_length // tp_size_of_prev_pipeline_rank + for i in range(tp_size_of_prev_pipeline_rank // tp_size_of_current_pipline_rank): + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + + return tensor_shapes, tp_scale + +def tp_hetero_recv_backward(send_tensor_shapes, tp_scale, config): + output_tensor_grads = [] + if not parallel_state.is_pipeline_last_stage(): + if config.sequence_parallel: + # match + if int(tp_scale) == 1: + output_tensor_grads = p2p_communication.tp_hetero_recv_backward(send_tensor_shapes, 1, config) + # Fwd small tp -> large tp, Bwd small tp <- large tp + elif tp_scale < 1: + tmp_tensors = p2p_communication.tp_hetero_recv_backward(send_tensor_shapes, int(1 / tp_scale), config) + output_tensor_grads.append(torch.cat(tmp_tensors, dim=0)) + else: + output_tensor_grads = p2p_communication.tp_hetero_recv_backward(send_tensor_shapes, 1, config) + else: + output_tensor_grads = p2p_communication.tp_hetero_recv_backward(send_tensor_shapes, 1, config) + else: + output_tensor_grads.append(None) + return output_tensor_grads + +def tp_hetero_send_forward(output_tensors, tensor_shapes, tp_scale, config): + if not parallel_state.is_pipeline_last_stage(): + if isinstance(output_tensors, list): + output_tensor = output_tensors[0] + else: + output_tensor = output_tensors + + if config.sequence_parallel: + # match + if int(tp_scale) == 1: + p2p_communication.tp_hetero_send_forward([output_tensor], 1, config) + # small tp -> 
large tp + elif tp_scale < 1: + split_size = int(1 / tp_scale) + tmp_tensors = list(torch.chunk(output_tensor, split_size, dim=0)) + p2p_communication.tp_hetero_send_forward(tmp_tensors, split_size, config) + # large tp -> small tp + else: + p2p_communication.tp_hetero_send_forward([output_tensor], 1, config) + else: + # match + if int(tp_scale) == 1: + p2p_communication.tp_hetero_send_forward([output_tensor], 1, config) + # small tp -> large tp + elif tp_scale < 1: + p2p_communication.tp_hetero_send_forward([output_tensor], int(1 / tp_scale), config) + # large tp -> small tp + else: + tensor_rank = parallel_state.get_tensor_model_parallel_rank() + if tensor_rank % (int(tp_scale)) == 0: + p2p_communication.tp_hetero_send_forward([output_tensor], 1, config) + + +def tp_hetero_recv_forward(recv_tensor_shapes, tp_scale, config): + input_tensors = [] + if not parallel_state.is_pipeline_first_stage(): + if config.sequence_parallel: + # match + if int(tp_scale) == 1: + input_tensors = p2p_communication.tp_hetero_recv_forward(recv_tensor_shapes, 1, config) + # small tp -> large tp + elif tp_scale < 1: + input_tensors = p2p_communication.tp_hetero_recv_forward(recv_tensor_shapes, 1, config) + # large tp -> small tp + else: + tmp_tensors = p2p_communication.tp_hetero_recv_forward(recv_tensor_shapes, int(tp_scale), config) + input_tensors.append(torch.cat(tmp_tensors, dim=0)) + else: + input_tensors = p2p_communication.tp_hetero_recv_forward(recv_tensor_shapes, 1, config) + else: + input_tensors.append(None) + return input_tensors + +def tp_hetero_send_backward(input_tensor_grads, tensor_shapes, tp_scale, config): + if not parallel_state.is_pipeline_first_stage(): + if isinstance(input_tensor_grads, list): + input_tensor_grad = input_tensor_grads[0] + else: + input_tensor_grad = input_tensor_grads + + if config.sequence_parallel: + # match + if int(tp_scale) == 1: + p2p_communication.tp_hetero_send_backward([input_tensor_grad], 1, config) + # Fwd: small tp -> large tp, Bwd large tp -> small tp + elif tp_scale < 1: + p2p_communication.tp_hetero_send_backward([input_tensor_grad], 1, config) + # Fwd: large tp -> small tp, Bwd small tp -> large tp + else: + split_size = int(tp_scale) + tmp_tensors = list(torch.chunk(input_tensor_grad, split_size, dim=0)) + p2p_communication.tp_hetero_send_backward(tmp_tensors, split_size, config) + else: + # match + if int(tp_scale) == 1: + p2p_communication.tp_hetero_send_backward([input_tensor_grad], 1, config) + # Fwd: small tp -> large tp, Bwd large tp -> small tp + elif tp_scale < 1: + tensor_rank = parallel_state.get_tensor_model_parallel_rank() + if tensor_rank % (int(1 / tp_scale)) == 0: + p2p_communication.tp_hetero_send_backward([input_tensor_grad], 1, config) + # Fwd: large tp -> small tp, Bwd small tp -> large tp + else: + p2p_communication.tp_hetero_send_backward([input_tensor_grad], int(tp_scale), config) + +def tp_hetero_send_forward_recv_backward(output_tensors, tensor_shapes, tp_scale, config): + output_tensor_grads = [] + if not parallel_state.is_pipeline_last_stage(): + if isinstance(output_tensors, list): + output_tensor = output_tensors[0] + else: + output_tensor = output_tensors + + if config.sequence_parallel: + # match + if int(tp_scale) == 1: + output_tensor_grads = p2p_communication.tp_hetero_send_forward_recv_backward([output_tensor], tensor_shapes, 1, 1, config) + # small tp -> large_tp + elif tp_scale < 1: + split_size = int(1 / tp_scale) + tmp_tensors = list(torch.chunk(output_tensor, split_size, dim=0)) + recv_tensors = 
p2p_communication.tp_hetero_send_forward_recv_backward(tmp_tensors, tensor_shapes, split_size, split_size, config) + output_tensor_grads.append(torch.cat(recv_tensors, dim=0)) + else: + output_tensor_grads = p2p_communication.tp_hetero_send_forward_recv_backward([output_tensor], tensor_shapes, 1, 1, config) + else: + # match + if int(tp_scale) == 1: + output_tensor_grads = p2p_communication.tp_hetero_send_forward_recv_backward([output_tensor], tensor_shapes, 1, 1, config) + # small tp -> large_tp + elif tp_scale < 1: + output_tensor_grads = p2p_communication.tp_hetero_send_forward_recv_backward([output_tensor], tensor_shapes, int(1 / tp_scale), 1, config) + else: + tensor_rank = parallel_state.get_tensor_model_parallel_rank() + output_tensor_grads = p2p_communication.tp_hetero_send_forward_recv_backward([output_tensor], tensor_shapes, + 1 if tensor_rank % (int(tp_scale)) == 0 else 0, + 1, config) + else: + output_tensor_grads.append(None) + return output_tensor_grads + +def tp_hetero_send_backward_recv_forward(input_tensor_grads, tensor_shapes, tp_scale, config): + input_tensors = [] + if not parallel_state.is_pipeline_first_stage(): + if isinstance(input_tensor_grads, list): + input_tensor_grad = input_tensor_grads[0] + else: + input_tensor_grad = input_tensor_grads + + if config.sequence_parallel: + # match + if int(tp_scale) == 1: + input_tensors = p2p_communication.tp_hetero_send_backward_recv_forward([input_tensor_grad], tensor_shapes, 1, 1, config) + # small tp -> large tp + elif tp_scale < 1: + input_tensors = p2p_communication.tp_hetero_send_backward_recv_forward([input_tensor_grad], tensor_shapes, 1, 1, config) + # large tp -> small tp + else: + split_size = int(tp_scale) + tmp_tensors = list(torch.chunk(input_tensor_grad, split_size, dim=0)) + recv_tensors = p2p_communication.tp_hetero_send_backward_recv_forward(tmp_tensors, tensor_shapes, split_size, split_size, config) + input_tensors.append(torch.cat(recv_tensors, dim=0)) + else: + # match + if int(tp_scale) == 1: + input_tensors = p2p_communication.tp_hetero_send_backward_recv_forward([input_tensor_grad], tensor_shapes, 1, 1, config) + # small tp -> large tp + elif tp_scale < 1: + tensor_rank = parallel_state.get_tensor_model_parallel_rank() + input_tensors = p2p_communication.tp_hetero_send_backward_recv_forward([input_tensor_grad], tensor_shapes, + 1 if tensor_rank % (int(1 / tp_scale)) == 0 else 0, + 1, config) + # large tp -> small tp + else: + input_tensors = p2p_communication.tp_hetero_send_backward_recv_forward([input_tensor_grad], tensor_shapes, int(tp_scale), 1, config) + else: + input_tensors.append(None) + return input_tensors + +def forward_backward_pipelining_without_interleaving_hetero( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, + first_val_step: bool = None, +): + """Run non-interleaved 1F1B schedule, with communication between pipeline + stages. 
+ + Returns dictionary with losses if the last stage, empty dict otherwise.""" + + if isinstance(model, list): + assert ( + len(model) == 1 + ), "non-interleaved pipeline parallelism does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + if config.overlap_p2p_comm: + raise ValueError( + "Non-interleaved pipeline parallelism does not support overlapping p2p communication" + ) + + if config.timers is not None: + config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + + # Disable async grad reductions + no_sync_func = config.no_sync_func + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + no_sync_context = None + + def disable_grad_sync(): + """Disable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is None: + no_sync_context = no_sync_func() + no_sync_context.__enter__() + + def enable_grad_sync(): + """Enable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is not None: + no_sync_context.__exit__(None, None, None) + no_sync_context = None + + disable_grad_sync() + + # Compute number of warmup microbatches. + num_warmup_microbatches = ( + parallel_state.get_pipeline_model_parallel_world_size() + - parallel_state.get_pipeline_model_parallel_rank() + - 1 + ) + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + # Checkpoint the activations of partial Transformer layers in a number of micro-batches + # within the maximum outstanding micro-batch backpropagations. + # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' + # checkpoint partial Transformer layers (or skip checkpointing) and + # the rest of micro-batches within a window of micro-batches checkpoint + # all Transformer layers. The window of micro-batches is set by the maximum + # outstanding backpropagations and becomes smaller at later pipeline stages. + # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf + max_outstanding_backprops = None + if config.num_microbatches_with_partial_activation_checkpoints is not None: + max_outstanding_backprops = num_warmup_microbatches + 1 + + model_type = get_model_type(model) + + rank = parallel_state.get_pipeline_model_parallel_rank() + + send_tensor_shapes, send_tp_scale = get_tp_hetero_tensor_shapes( + send_recv=True, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + + recv_tensor_shapes, recv_tp_scale = get_tp_hetero_tensor_shapes( + send_recv=False, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + + # Input, output tensors only need to be saved when doing backward passes + input_tensors = None + output_tensors = None + total_num_tokens = torch.tensor(0, dtype=torch.int).cuda() + + if not forward_only: + input_tensors = [] + output_tensors = [] + forward_data_store = [] + + # Run warmup forward passes. 
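# Reviewer annotation (not part of the diff): with the example config,
# pipeline_model_parallel_size is 3, so num_warmup_microbatches above is
# min(3 - pipeline_rank - 1, num_microbatches): stage 0 runs 2 warmup forward
# passes, stage 1 runs 1, and the last stage runs 0 and enters 1F1B immediately.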
+ for i in range(num_warmup_microbatches): + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + i % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + input_tensor = tp_hetero_recv_forward(recv_tensor_shapes, recv_tp_scale, config) + output_tensor, num_tokens = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + check_first_val_step(first_val_step, forward_only, i == 0), + current_microbatch=i, + ) + tp_hetero_send_forward(output_tensor, send_tensor_shapes, send_tp_scale, config) + total_num_tokens += num_tokens.item() + + if not forward_only: + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_tensor = tp_hetero_recv_forward(recv_tensor_shapes, recv_tp_scale, config) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = i == (num_microbatches_remaining - 1) + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + (i + num_warmup_microbatches) % max_outstanding_backprops + ) >= config.num_microbatches_with_partial_activation_checkpoints + else: + checkpoint_activations_microbatch = None + + output_tensor, num_tokens = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + check_first_val_step( + first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0) + ), + current_microbatch=i + num_warmup_microbatches, + ) + total_num_tokens += num_tokens.item() + + if forward_only: + tp_hetero_send_forward(output_tensor, send_tensor_shapes, send_tp_scale, config) + + if not last_iteration: + input_tensor = tp_hetero_recv_forward(recv_tensor_shapes, recv_tp_scale, config) + + else: + output_tensor_grad = tp_hetero_send_forward_recv_backward(output_tensor, send_tensor_shapes, send_tp_scale, config) + + # Add input_tensor and output_tensor to end of list. + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + + # Pop input_tensor and output_tensor from the start of the list for + # the backward pass. + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + # Enable grad sync for the last microbatch in the batch if the full + # backward pass completes in the 1F1B stage. 
+ if num_warmup_microbatches == 0 and last_iteration: + if config.grad_sync_func is None or rank == 0: + enable_grad_sync() + + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + if last_iteration: + input_tensor = None + tp_hetero_send_backward(input_tensor_grad, recv_tensor_shapes, recv_tp_scale, config) + else: + input_tensor = tp_hetero_send_backward_recv_forward(input_tensor_grad, recv_tensor_shapes, recv_tp_scale, config) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + + # Enable async grad reduction in the last backward pass + # Note: If grad sync function is provided, only enable + # async grad reduction in first pipeline stage. Other + # pipeline stages do grad reduction during pipeline + # bubble. + if i == num_warmup_microbatches - 1: + if config.grad_sync_func is None or rank == 0: + enable_grad_sync() + + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + output_tensor_grad = tp_hetero_recv_backward(send_tensor_shapes, send_tp_scale, config) + + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + tp_hetero_send_backward(input_tensor_grad, recv_tensor_shapes, recv_tp_scale, config) + + # Launch any remaining grad reductions. + if no_sync_context is not None: + enable_grad_sync() + if config.grad_sync_func is not None: + config.grad_sync_func(model.parameters()) + + if config.finalize_model_grads_func is not None and not forward_only: + # Finalize model grads (perform full grad all-reduce / reduce-scatter for + # data parallelism, layernorm all-reduce for sequence parallelism, and + # embedding all-reduce for pipeline parallelism). + config.finalize_model_grads_func( + [model], total_num_tokens if config.calculate_per_token_loss else None + ) + + if config.timers is not None: + config.timers('forward-backward').stop() + + return forward_data_store \ No newline at end of file diff --git a/flagscale/train/hetero/train_llama.py b/flagscale/train/hetero/train_llama.py new file mode 100644 index 000000000..cb27f6392 --- /dev/null +++ b/flagscale/train/hetero/train_llama.py @@ -0,0 +1,292 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
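[Editor note, not part of the diff] schedules.py above closely mirrors Megatron's stock non-interleaved 1F1B schedule, with the p2p calls replaced by the tp-hetero variants; the signature of the returned forward-backward function is unchanged. The dispatch keys on the pipeline-parallel group becoming a list of process groups in hetero TP mode. A minimal sketch of that check, assuming the same parallel_state getters used above are available at call time:

from megatron.core import parallel_state

def uses_hetero_1f1b() -> bool:
    # Hetero TP mode is signalled by get_pipeline_model_parallel_group() returning
    # a list (one group per neighboring TP size) instead of a single process group;
    # virtual pipeline (interleaving) is not supported in that mode.
    return (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and isinstance(parallel_state.get_pipeline_model_parallel_group(), list))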
+"""Pretrain GPT.""" + +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + +import torch +from functools import partial + +from typing import Union +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import parallel_state as mpu +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset +import megatron.legacy.model +from megatron.core.models.gpt import GPTModel +from megatron.core.utils import StragglerDetector +from megatron.core.transformer.spec_utils import import_module +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, +) +from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) + +from flagscale.datasets.sft_dataset import SFTDatasetConfig, SFTDataset +from flagscale.train.extra_valid import extra_valid_dataset_provider +from flagscale.train.hetero.training import pretrain + +stimer = StragglerDetector() + +def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: + """Builds the model. + + If you set the use_mcore_models to True, it will return the mcore GPT model and if not the legacy GPT model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. + + + Returns: + Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + """ + args = get_args() + use_te = args.transformer_impl == "transformer_engine" + + print_rank_0('building GPT model ...') + # Experimental loading arguments from yaml + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + config = core_transformer_config_from_args(args) + + if args.use_mcore_models: + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm) + else: + transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + ) + else: + assert ( + args.context_parallel_size == 1 + ), "Context parallelism is only supported with Megatron Core!" 
+ + model = megatron.legacy.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + + +def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + """Loss function. + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + total_tokens = loss_mask.sum() + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + # Check individual rank losses are not NaN prior to DP all-reduce. + if args.check_for_nan_in_loss_and_grad: + global_rank = torch.distributed.get_rank() + assert not loss[0].isnan(), ( + f'Rank {global_rank}: found NaN in local forward loss calculation. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' + ) + + # Reduce loss for logging. + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + return ( + loss[0] * args.context_parallel_size, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def forward_step(data_iterator, model: GPTModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (GPTModel): The GPT Model + """ + args = get_args() + timers = get_timers() + + # Get the batch. 
+ timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + with stimer: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def is_dataset_built_on_rank(): + return ( + mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage() + ) and mpu.get_tensor_model_parallel_rank() == 0 + + +def core_gpt_dataset_config_from_args(args): + tokenizer = get_tokenizer() + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=get_blend_from_list(args.data_path), + blend_per_split=[ + get_blend_from_list(args.train_data_path), + get_blend_from_list(args.valid_data_path), + get_blend_from_list(args.test_data_path) + ], + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + ) + + +def core_sft_dataset_config_from_args(args): + tokenizer = get_tokenizer() + + return SFTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=get_blend_from_list(args.data_path), + blend_per_split=[ + get_blend_from_list(args.train_data_path), + get_blend_from_list(args.valid_data_path), + get_blend_from_list(args.test_data_path) + ], + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + apply_sft_dataset_separated_loss_mask_if_existed=args.apply_sft_dataset_separated_loss_mask_if_existed, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. 
+ """ + args = get_args() + + if args.apply_sft_dataset_separated_loss_mask_if_existed: + config = core_sft_dataset_config_from_args(args) + else: + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + elif args.apply_sft_dataset_separated_loss_mask_if_existed: + dataset_type = SFTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + is_dataset_built_on_rank, + config + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + extra_valid_dataset_provider.is_distributed = True + + pretrain(train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + get_batch_fn=get_batch, + extra_valid_dataset_provider=extra_valid_dataset_provider) diff --git a/flagscale/train/hetero/training.py b/flagscale/train/hetero/training.py new file mode 100644 index 000000000..2b475ce4c --- /dev/null +++ b/flagscale/train/hetero/training.py @@ -0,0 +1,680 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import dataclasses +import time +import os +import sys +import json +import gc +import torch + +# The earliest we can measure the start time. +_TRAIN_START_TIME = time.time() + +from megatron.core import tensor_parallel +from megatron.core.utils import get_model_config +from megatron.core.utils import check_param_hashes_across_dp_replicas +from megatron.training.checkpointing import load_checkpoint +from megatron.training.checkpointing import save_checkpoint +from megatron.training.checkpointing import get_checkpoint_name +from megatron.legacy.model import Float16Module +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.distributed import finalize_model_grads +from megatron.core.enums import ModelType +from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig +from megatron.training.initialize import set_jit_fusion_options +from megatron.training.initialize import write_args_to_tensorboard +from megatron.training.async_utils import maybe_finalize_async_save +from megatron.training.utils import ( + calc_params_l2_norm, + check_adlr_autoresume_termination, + is_last_rank, + print_rank_0, + print_rank_last, + report_memory, + unwrap_model, + append_to_progress_log, +) +from megatron.training.global_vars import ( + get_args, + get_signal_handler, + get_timers, + get_tensorboard_writer, + get_wandb_writer, + get_one_logger, + get_current_global_batch_size, + get_num_microbatches, + update_num_microbatches) + +from megatron.training.training import ( + print_datetime, + search_data, + build_train_valid_test_data_iterators, + get_optimizer_param_scheduler, + evaluate_and_print_results, + setup_model_and_optimizer, + save_checkpoint_and_time, + num_floating_point_operations, + training_log, + ) + +from flagscale.train.extra_valid import build_extra_valid_data_iterators +from flagscale.train.extra_valid import extra_evaluate_and_print_results +from flagscale.train.hetero.initialize import initialize_megatron +from flagscale.train.hetero.schedules import 
get_forward_backward_func +from megatron.core import parallel_state as mpu + +def pretrain(train_valid_test_dataset_provider, + model_provider, + model_type, + forward_step_func, + process_non_loss_data_func=None, + extra_args_provider=None, + args_defaults={}, + get_batch_fn=None, + extra_valid_dataset_provider=None): + """Main training program. + + This function will run the followings in the order provided: + 1) initialize Megatron. + 2) setup model, optimizer and lr schedule using the model_provider. + 3) call train_val_test_data_provider to get train/val/test datasets. + 4) train the modle using the forward_step_func. + + Args: + train_valid_test_dataset_provider: a function that takes the size of + train/valid/test dataset and returns `train, valid, test` datasets. + model_provider: a function that returns a vanilla version of the + model. By vanilla we mean a simple model on cpu with no fp16 or ddp. + model_type: an enum that specifies the type of model being trained. + forward_step_func: a function that takes a `data iterator` and `model`, + and returns a `loss` scalar with a dictionary with key:values being + the info we would like to monitor during training, for example + `lm-loss: value`. We also require that this function add + `batch generator` to the timers class. + process_non_loss_data_func: a function to post process outputs of the + network. It can be used for dumping output tensors (e.g images) to + tensorboard. It takes `collected data`(list of tensors), + `current iteration index` and `tensorboard writer` as arguments. + extra_args_provider: a function that takes a parser and adds arguments + to it. It is used for programs to add their own arguments. + args_defaults: a dictionary from argument-name to argument-value. It + to set already parse arguments. + """ + + # Initalize and get arguments, timers, and Tensorboard writer. + if not torch.cuda.is_available(): + initialize_megatron(extra_args_provider=extra_args_provider, + args_defaults=args_defaults,allow_no_cuda=True) + else: + initialize_megatron(extra_args_provider=extra_args_provider, + args_defaults=args_defaults) + + + args = get_args() + timers = get_timers() + + if args.log_progress: + append_to_progress_log("Starting job") + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + global _TRAIN_START_TIME + start_time_tensor = torch.tensor([_TRAIN_START_TIME], + dtype=torch.double, + device='cuda') + torch.distributed.all_reduce(start_time_tensor, + op=torch.distributed.ReduceOp.MIN) + _TRAIN_START_TIME = start_time_tensor.item() + print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( + time.time() - _TRAIN_START_TIME)) + print_datetime('after megatron is initialized') + + if args.data_searching_save is not None: + search_data(train_valid_test_dataset_provider, get_batch_fn) + return + + one_logger = get_one_logger() + if one_logger: + one_logger.log_metrics({ + 'train_iterations_warmup': 5 + }) + + # Model, optimizer, and learning rate. + timers('model-and-optimizer-setup', log_level=0).start(barrier=True) + model, optimizer, opt_param_scheduler = setup_model_and_optimizer( + model_provider, model_type) + + timers('model-and-optimizer-setup').stop() + print_datetime('after model, optimizer, and learning rate ' + 'scheduler are built') + config = get_model_config(model[0]) + + # Data stuff. 
+ timers('train/valid/test-data-iterators-setup', log_level=0).start( + barrier=True) + if args.virtual_pipeline_model_parallel_size is not None: + train_data_iterator = [] + valid_data_iterator = [] + test_data_iterator = [] + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + iterators = build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) + train_data_iterator.append(iterators[0]) + valid_data_iterator.append(iterators[1]) + test_data_iterator.append(iterators[2]) + else: + train_data_iterator, valid_data_iterator, test_data_iterator \ + = build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) + timers('train/valid/test-data-iterators-setup').stop() + print_datetime('after dataloaders are built') + + # Context used for persisting some state between checkpoint saves. + checkpointing_context = {} + + # Print setup timing. + print_rank_0('done with setup ...') + timers.log(['model-and-optimizer-setup', + 'train/valid/test-data-iterators-setup'], barrier=True) + + if not args.skip_train: + print_rank_0('training ...') + + if args.dataloader_type == 'cyclic' and args.retro_project_dir: + assert args.retro_cyclic_train_iters is not None + args.train_iters = args.retro_cyclic_train_iters + print_rank_0("retro cyclic train iters : %d" % args.train_iters) + + iteration = 0 + if args.do_train and args.train_iters > 0: + iteration, num_floating_point_operations_so_far = train( + forward_step_func, + model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, + process_non_loss_data_func, config, checkpointing_context, + extra_valid_dataset_provider) + + print_datetime('after training is done') + + if args.save and iteration != 0 and iteration % args.save_interval != 0: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler, + num_floating_point_operations_so_far, checkpointing_context) + else: + print_rank_0('skipping training (--skip-train is on) ...') + + iteration = args.iteration + + if args.do_valid: + prefix = f'iteration {iteration} on validation set' + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, config, + verbose=True, write_to_tensorboard=not args.skip_train) + + if args.do_test: + prefix = f'iteration {iteration} on test set' + evaluate_and_print_results(prefix, forward_step_func, + test_data_iterator, model, + iteration, process_non_loss_data_func, config, + verbose=True, write_to_tensorboard=not args.skip_train) + + maybe_finalize_async_save(blocking=True) + +def train(forward_step_func, model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, + process_non_loss_data_func, config, checkpointing_context, extra_valid_dataset_provider=None): + """Train the model function.""" + args = get_args() + timers = get_timers() + + # Write args to tensorboard + write_args_to_tensorboard() + + # Turn on training mode which enables dropout. + for model_module in model: + model_module.train() + + # Tracking loss. + total_loss_dict = {} + + # Iterations. 
+    iteration = args.iteration
+    one_logger = get_one_logger()
+    if one_logger:
+        iteration_start = iteration
+        train_samples_start = args.consumed_train_samples
+        train_samples_target = args.train_samples
+        one_logger.log_metrics({
+            'train_samples_start': args.consumed_train_samples,
+            'train_iterations_start': iteration,
+            'train_samples_target': train_samples_target,
+            'train_iterations_target': args.train_iters,
+        })
+
+    num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
+
+    # Setup some training config params
+    config.grad_scale_func = optimizer.scale_loss
+    config.timers = timers
+    if isinstance(model[0], DDP) and args.overlap_grad_reduce:
+        assert config.no_sync_func is None, \
+            ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
+             'a custom no_sync_func is not supported when overlapping grad-reduce')
+        config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
+        if len(model) == 1:
+            config.no_sync_func = config.no_sync_func[0]
+        if args.delay_grad_reduce:
+            config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
+            if len(model) == 1:
+                config.grad_sync_func = config.grad_sync_func[0]
+    if args.overlap_param_gather and args.delay_param_gather:
+        config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x)
+                                  for model_index in range(len(model))]
+        if len(model) == 1:
+            config.param_sync_func = config.param_sync_func[0]
+    config.finalize_model_grads_func = finalize_model_grads
+
+    timers('interval-time', log_level=0).start(barrier=True)
+    print_datetime('before the start of training step')
+    report_memory_flag = True
+    exit = False
+
+    if args.manual_gc:
+        # Disable the default garbage collector and perform the collection manually.
+        # This is to align the timing of garbage collection across ranks.
+        assert args.manual_gc_interval >= 0, \
+            'Manual garbage collection interval should be larger than or equal to 0.'
+        gc.disable()
+        gc.collect()
+
+    # Singleton Initialization
+    if args.log_straggler:
+        global stimer
+        world = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        mmcnt = args.straggler_minmax_count
+        stimer.configure(world, rank,
+                mmcnt = mmcnt,
+                enabled = not args.disable_straggler_on_startup,
+                port = args.straggler_ctrlr_port)
+    total_flops = 0.0
+
+    num_microbatches = get_num_microbatches()
+
+    wandb_writer = get_wandb_writer()
+    if wandb_writer and args.wandb_log_model:
+        # wandb.watch's log_freq needs to take the accumulated number of microbatches into account
+        log_freq = args.wandb_log_model_interval * num_microbatches
+        wandb_writer.watch(unwrap_model(model), log="all", log_freq=log_freq)
+
+    eval_duration = 0.0
+    eval_iterations = 0
+    def track_e2e_metrics():
+        # Nested function to track a bunch of E2E APP metrics
+        if one_logger:
+            train_duration = timers('interval-time').active_time()  # overall_elapsed
+            train_samples = args.consumed_train_samples - train_samples_start
+            train_iterations = iteration - iteration_start
+            train_iterations_time_msecs_avg = (train_duration * 1000.0) / train_iterations
+            if eval_iterations:
+                validation_iterations_time_msecs_avg = (eval_duration * 1000.0) / eval_iterations
+            else:
+                validation_iterations_time_msecs_avg = None
+
+            one_logger.log_metrics({
+                'train_iterations_end': iteration,
+                'train_samples_end': args.consumed_train_samples,
+                'train_iterations': train_iterations,
+                'train_samples': train_samples,
+                'train_iterations_time_msecs_avg': train_iterations_time_msecs_avg,
+                'validation_iterations_time_msecs_avg': validation_iterations_time_msecs_avg
+            })
+
+    while iteration < args.train_iters:
+        if args.profile and \
+           iteration == args.profile_step_start and \
+           torch.distributed.get_rank() in args.profile_ranks:
+            torch.cuda.cudart().cudaProfilerStart()
+            torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
+
+        maybe_finalize_async_save(False)
+
+        # Update number of microbatches first without consistency check to decide if a
+        # checkpoint should be saved. If the number of microbatches is different
+        # from the previous iteration, save a checkpoint. Then run consistency check
+        # to make sure training configuration is still valid.
+ update_num_microbatches(args.consumed_train_samples, consistency_check=False) + if get_num_microbatches() != num_microbatches and iteration != 0 \ + and args.save_when_num_microbatches_change: + assert get_num_microbatches() > num_microbatches, \ + "number of microbatches should be increasing due to batch size rampup" + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context) + num_microbatches = get_num_microbatches() + update_num_microbatches(args.consumed_train_samples, consistency_check=True) + + args.curr_iteration = iteration + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) + iteration += 1 + if args.hetero_mode != "dp": + batch_size = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + args.consumed_train_samples += batch_size + num_fp_ops = num_floating_point_operations(args, batch_size) + num_floating_point_operations_so_far += num_fp_ops + total_flops += num_fp_ops + else: + micro_batch_for_all_data_parallel = sum(map(lambda x, y: x * y, + args.hetero_micro_batch_sizes, + args.hetero_data_parallel_splits)) + batch_size = get_num_microbatches() * micro_batch_for_all_data_parallel + args.consumed_train_samples += batch_size + num_fp_ops = num_floating_point_operations(args, batch_size) + num_floating_point_operations_so_far += num_fp_ops + total_flops += num_fp_ops + + # Logging. + loss_scale = optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + + if iteration % args.log_interval == 0: + track_e2e_metrics() + + learning_rate = None + decoupled_learning_rate = None + for param_group in optimizer.param_groups: + if param_group['is_decoupled_lr']: + decoupled_learning_rate = param_group['lr'] + else: + learning_rate = param_group['lr'] + report_memory_flag = training_log(loss_dict, total_loss_dict, + learning_rate, + decoupled_learning_rate, + iteration, loss_scale, + report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad) + # StragglerDetector + if iteration % args.log_interval == 0 and args.log_straggler: + stimer.report(total_flops, args.log_interval) + total_flops = 0.0 + + if args.check_weight_hash_across_dp_replicas_interval is not None and \ + iteration % args.check_weight_hash_across_dp_replicas_interval == 0: + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.disable_pre_hook() + assert check_param_hashes_across_dp_replicas(model), \ + "Parameter hashes not matching across DP replicas" + torch.distributed.barrier() + print_rank_0(f">>> Weight hashes match after {iteration} iterations...") + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.enable_pre_hook() + + # Autoresume + if args.adlr_autoresume and \ + (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, optimizer, + opt_param_scheduler) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and \ + args.do_valid: + timers('interval-time').stop() + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.disable_pre_hook() + if args.manual_gc and args.manual_gc_eval: + # Collect all objects. 
+ gc.collect() + prefix = 'iteration {}'.format(iteration) + timers('eval-time', log_level=0).start(barrier=True) + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, + config, False) + eval_duration += timers('eval-time').elapsed() + eval_iterations += args.eval_iters + timers('eval-time').stop() + if args.manual_gc and args.manual_gc_eval: + # Collect only the objects created and used in evaluation. + gc.collect(generation=0) + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.enable_pre_hook() + timers('interval-time', log_level=0).start(barrier=True) + + # Extra Evaluation + if args.extra_valid_interval and iteration % args.extra_valid_interval == 0: + # Need to rebuild the dataloaders for extra validation, + # but we don't need to rebuild the datasets + # TODO: refactor this code and test this in vp - @aoyulong + if args.virtual_pipeline_model_parallel_size is not None: + extra_valid_data_iterators = [] + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + extra_valid_data_iterators.append( + build_extra_valid_data_iterators(extra_valid_dataset_provider)) + else: + extra_valid_data_iterators = build_extra_valid_data_iterators(extra_valid_dataset_provider) + # do_extra_valid flag is used to indicate that we are doing extra validation + # and is set in the build_extra_valid_data_iterators function + if args.do_extra_valid: + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.disable_pre_hook() + if args.manual_gc and args.manual_gc_eval: + # Collect all objects. + gc.collect() + prefix = 'iteration {}'.format(iteration) + for extra_valid_index, extra_valid_data_iterator in enumerate(extra_valid_data_iterators): + extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func, + extra_valid_data_iterator, model, + iteration, process_non_loss_data_func, + config, False) + if args.manual_gc and args.manual_gc_eval: + # Collect only the objects created and used in evaluation. 
+ gc.collect(generation=0) + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.enable_pre_hook() + + # Checkpointing + saved_checkpoint = False + if args.exit_signal_handler: + signal_handler = get_signal_handler() + if any(signal_handler.signals_received()): + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context) + print_datetime('exiting program after receiving SIGTERM.') + exit = True + break + + need_save = False + if args.rampup_batch_size is not None \ + and args.rampup_save_interval is not None: + rampup_samples = int(args.rampup_batch_size[2]) + if args.consumed_train_samples < rampup_samples: + if args.save and args.rampup_save_interval and \ + iteration % args.rampup_save_interval == 0: + need_save = True + else: + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + need_save = True + else: + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + need_save = True + + # if args.save and args.save_interval and \ + # iteration % args.save_interval == 0: + if need_save: + timers('interval-time').stop() + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context) + saved_checkpoint = True + timers('interval-time', log_level=0).start(barrier=True) + + # Exiting based on duration + if args.exit_duration_in_mins: + train_time = (time.time() - _TRAIN_START_TIME) / 60.0 + done_cuda = torch.tensor( + [train_time > args.exit_duration_in_mins], + dtype=torch.int, device='cuda') + torch.distributed.all_reduce( + done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + if not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context) + print_datetime('exiting program after {} minutes'.format(train_time)) + exit = True + break + + # Exiting based on iterations + if args.exit_interval and iteration % args.exit_interval == 0: + if args.save and not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context) + torch.distributed.barrier() + print_datetime('exiting program at iteration {}'.format(iteration)) + exit = True + break + + if args.profile and \ + iteration == args.profile_step_end and \ + torch.distributed.get_rank() in args.profile_ranks: + torch.cuda.cudart().cudaProfilerStop() + + if args.manual_gc: + if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0: + gc.collect() + + track_e2e_metrics() + + # Flush TensorBoard and WandB writers. + writer = get_tensorboard_writer() + if writer: + writer.flush() + wandb_writer = get_wandb_writer() + if wandb_writer: + wandb_writer.finish() + + # Close out pre-hooks if using distributed optimizer and overlapped param gather. + if args.use_distributed_optimizer and args.overlap_param_gather: + optimizer.disable_pre_hook() + + maybe_finalize_async_save(True) + + # If any exit conditions (signal handler, duration, iterations) have been reached, exit. + if exit: + sys.exit() + + return iteration, num_floating_point_operations_so_far + +def train_step(forward_step_func, data_iterator, + model, optimizer, opt_param_scheduler, config): + """Single training step.""" + args = get_args() + timers = get_timers() + + # Set grad to zero. 
+ for model_chunk in model: + model_chunk.zero_grad_buffer() + optimizer.zero_grad() + + # Forward pass. + forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=False) + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Vision gradients. + if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) + + # Update parameters. + timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + timers('optimizer').stop() + + # Vision momentum. + if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0]) + unwrapped_model.update_momentum(args.curr_iteration) + + # Update learning rate. + if update_successful: + if args.hetero_mode != "dp": + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + else: + micro_batch_for_all_data_parallel = sum(map(lambda x, y: x * y, + args.hetero_micro_batch_sizes, + args.hetero_data_parallel_splits)) + increment = get_num_microbatches() * \ + micro_batch_for_all_data_parallel + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in losses_reduced[0].keys(): + numerator = 0 + denominator = 0 + for x in losses_reduced: + val = x[key] + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. + if isinstance(val, tuple) or isinstance(val, list): + numerator += val[0] + denominator += val[1] + else: + # legacy behavior. we average over the number of microbatches, + # and so the denominator is 1. + numerator += val + denominator += 1 + loss_reduced[key] = numerator / denominator + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + return {}, skipped_iter, grad_norm, num_zeros_in_grad \ No newline at end of file diff --git a/megatron/megatron/core/parallel_state.py b/megatron/megatron/core/parallel_state.py index e16aee9b6..07e80a5b0 100644 --- a/megatron/megatron/core/parallel_state.py +++ b/megatron/megatron/core/parallel_state.py @@ -9,12 +9,16 @@ import torch -from .utils import GlobalMemoryBuffer +from megatron.core.utils import GlobalMemoryBuffer # Intra-layer model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Inter-layer model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None +# same tp device parallel group that the current rank belongs to. +_SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP = None +# difference tp device parallel group that the current rank belongs to. +_DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP = None # Model parallel group (both intra- and pipeline) that the current rank belongs to. _MODEL_PARALLEL_GROUP = None # Embedding group. 
@@ -56,6 +60,13 @@ # rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None +# A list of global ranks for each same/diff tp pipeline group +_SAME_TP_PIPELINE_GLOBAL_RANKS = None +_DIFF_TP_PIPELINE_GLOBAL_RANKS = None + +# A List of tensor parallel size of each pipeline stage for tp mixed mode +_TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE = None + # A list of global ranks for each data parallel group to ease calculation of the source # rank when broadcasting weights from src to all other data parallel ranks _DATA_PARALLEL_GLOBAL_RANKS = None @@ -719,7 +730,300 @@ def initialize_model_parallel( # we could stick it there _set_global_memory_buffer() +def initialize_hetero_model_parallel( + args, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_split_rank: Optional[int] = None, + use_sharp: bool = False, + context_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + nccl_communicator_config_path: Optional[str] = None, + distributed_timeout_minutes: int = 30, +) -> None: + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + data_parallel_size = args.data_parallel_size + process_meshes_tp = args.process_meshes_tp + process_meshes_pp = args.process_meshes_pp + process_meshes_dp = args.process_meshes_dp + # len = p + 1, [0, sum(p0), sum(p0-p1), ..., sum(p0-pn-1)] + cumu_num_device_of_all_pipeline_stage = args.cumu_num_device_of_all_pipeline_stage + tp_size_of_each_pipeline_stage = args.tp_size_of_each_pipeline_stage + + assert expert_model_parallel_size == 1 and context_parallel_size == 1 and virtual_pipeline_model_parallel_size == None, \ + 'ep/cp != 1 or vp != None not supported now!' + + global _TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE + _TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE = tp_size_of_each_pipeline_stage + + rank = torch.distributed.get_rank() + + nccl_comm_cfgs = {} + if nccl_communicator_config_path is not None: + try: + import yaml + except ImportError: + raise RuntimeError( + "Cannot import `yaml`. Setting custom nccl communicator configs " + "requires the yaml package." + ) + + with open(nccl_communicator_config_path, "r") as stream: + nccl_comm_cfgs = yaml.safe_load(stream) + + timeout = timedelta(minutes=distributed_timeout_minutes) + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_GLOO + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' + for i in range(pipeline_model_parallel_size): + start_rank = cumu_num_device_of_all_pipeline_stage[i] + end_rank = cumu_num_device_of_all_pipeline_stage[i+1] + for j in range(tp_size_of_each_pipeline_stage[i]): + ranks = [x for x in range( + start_rank + j, end_rank, tp_size_of_each_pipeline_stage[i] + )] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs) + ) + group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo") + if rank in ranks: + #print(f"global_rank: {rank}, dp_parallel_group: {ranks}") + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_GLOO = group_gloo + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Build the data-context-parallel groups. 
+ global _DATA_PARALLEL_GROUP_WITH_CP + global _DATA_PARALLEL_GROUP_WITH_CP_GLOO + global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP + assert args.context_parallel_size == 1, 'cp!=1 not support now!' + _DATA_PARALLEL_GROUP_WITH_CP = _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP_WITH_CP_GLOO = _DATA_PARALLEL_GROUP_GLOO + _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = _DATA_PARALLEL_GLOBAL_RANKS + + # Build the context-parallel groups. + global _CONTEXT_PARALLEL_GROUP + global _CONTEXT_PARALLEL_GLOBAL_RANKS + assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized' + ranks = [rank] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs) + ) + _CONTEXT_PARALLEL_GROUP = group + _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks + + # Build the data_modulo_expert_group. + global _DATA_MODULO_EXPERT_PARALLEL_GROUP + assert ( + _DATA_MODULO_EXPERT_PARALLEL_GROUP is None + ), 'Data modulo expert group is already initialized' + global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO + global _DATA_MODULO_EXPERT_PARALLEL_GLOBAL_RANKS + assert args.expert_model_parallel_size == 1, 'ep!=1 not support now!' + _DATA_MODULO_EXPERT_PARALLEL_GROUP = _DATA_PARALLEL_GROUP + _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = _DATA_PARALLEL_GROUP_GLOO + _DATA_MODULO_EXPERT_PARALLEL_GLOBAL_RANKS = _DATA_PARALLEL_GLOBAL_RANKS + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' + for i in range(data_parallel_size): + ranks = [] + for j in range(pipeline_model_parallel_size): + start_rank = cumu_num_device_of_all_pipeline_stage[j] + tp_size_of_each_pipeline_stage[j] * i + end_rank = cumu_num_device_of_all_pipeline_stage[j] + tp_size_of_each_pipeline_stage[j] * (i+1) + local_ranks = [i for i in range(start_rank, end_rank, 1)] + ranks.extend(local_ranks) + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs) + ) + if rank in ranks: + #print(f"global_rank: {rank}, mp_parallel_group: {ranks}") + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS + assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized' + for i in range(data_parallel_size): + for j in range(pipeline_model_parallel_size): + start_rank = cumu_num_device_of_all_pipeline_stage[j] + tp_size_of_each_pipeline_stage[j] * i + end_rank = cumu_num_device_of_all_pipeline_stage[j] + tp_size_of_each_pipeline_stage[j] * (i+1) + ranks = [i for i in range(start_rank, end_rank, 1)] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs) + ) + if rank in ranks: + #print(f"global_rank: {rank}, tp_parallel_group: {ranks}") + _TENSOR_MODEL_PARALLEL_GROUP = group + _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks + + # Build the tensor + expert parallel groups + global _EXPERT_MODEL_PARALLEL_GROUP + assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized' + global _EXPERT_MODEL_PARALLEL_GLOBAL_RANKS + global _TENSOR_AND_EXPERT_PARALLEL_GROUP + assert ( + _TENSOR_AND_EXPERT_PARALLEL_GROUP is None + ), 'Tensor + expert parallel group is already initialized' + global _TENSOR_AND_EXPERT_PARALLEL_GLOBAL_RANKS + assert args.expert_model_parallel_size == 1, 'ep!=1 not support now!' 
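+    # Note: because expert model parallelism is restricted to 1 here, the tensor+expert
+    # parallel group below simply aliases the tensor-parallel group, and every rank
+    # forms its own single-rank expert-parallel group.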
+ _TENSOR_AND_EXPERT_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_AND_EXPERT_PARALLEL_GLOBAL_RANKS = _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS + ranks = [rank] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('exp', nccl_comm_cfgs) + ) + _EXPERT_MODEL_PARALLEL_GROUP = group + _EXPERT_MODEL_PARALLEL_GLOBAL_RANKS = ranks + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized' + + # EMBEDDING_GROUP: shared embedding and output linear set + # POSITION_EMBEDDING_GROUP: pipeline_model_parallel_split_rank set + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized' + global _LAST_RANK_WHEN_USING_PIPELINE + assert _LAST_RANK_WHEN_USING_PIPELINE is None, 'last rank when using pipeline is already initialized' + max_tp_size = max(tp_size_of_each_pipeline_stage) + for i in range(max_tp_size): + for j in range(data_parallel_size): + ranks = [] + for k in range(pipeline_model_parallel_size): + if tp_size_of_each_pipeline_stage[k] == max_tp_size: + ranks.append(cumu_num_device_of_all_pipeline_stage[k] + tp_size_of_each_pipeline_stage[k] * j + i) + else: + ranks.append(cumu_num_device_of_all_pipeline_stage[k] + tp_size_of_each_pipeline_stage[k] * j + + i // (max_tp_size // tp_size_of_each_pipeline_stage[k]) ) + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs) + ) + if rank in ranks: + if _PIPELINE_MODEL_PARALLEL_GROUP == None: + _PIPELINE_MODEL_PARALLEL_GROUP = [group] + _PIPELINE_GLOBAL_RANKS = [ranks] + else: + _PIPELINE_MODEL_PARALLEL_GROUP.append(group) + _PIPELINE_GLOBAL_RANKS.append(ranks) + + if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = torch.distributed.new_group( + embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs) + ) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = torch.distributed.new_group( + position_embedding_ranks, + timeout=timeout, + pg_options=get_nccl_options('embd', nccl_comm_cfgs), + ) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + _LAST_RANK_WHEN_USING_PIPELINE = _PIPELINE_GLOBAL_RANKS[0][-1] + + # Build the same/diff tp pipeline groups + global _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP + global _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP + global _SAME_TP_PIPELINE_GLOBAL_RANKS + global _DIFF_TP_PIPELINE_GLOBAL_RANKS + assert _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP is None, 'same tp pipeline group is already initialized' + assert _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP is None, 'diff tp pipeline group is already initialized' + num_process_meshes = args.num_process_meshes + num_device_of_each_parallel_group = [tp * dp * pp for tp, dp, pp in + zip(process_meshes_tp, process_meshes_dp, process_meshes_pp)] + 
cumu_num_device_of_all_parallel_group = [sum(num_device_of_each_parallel_group[:i]) for i in range(num_process_meshes + 1)] + num_same_tp_pipeline_model_parallel_groups = [] + num_diff_tp_pipeline_model_parallel_groups = [] + assert num_process_meshes == len(set(process_meshes_tp)), 'all parallel group tp should not be the same' + for i in range(num_process_meshes): + num_same_tp_pipeline_model_parallel_groups.append(process_meshes_tp[i] * process_meshes_dp[i]) + for i in range(num_process_meshes - 1): + num_diff_tp_pipeline_model_parallel_groups.append(process_meshes_dp[i] * process_meshes_tp[i] \ + if process_meshes_tp[i] > process_meshes_tp[i+1] else \ + process_meshes_dp[i] * process_meshes_tp[i+1]) + + for i in range(num_process_meshes): + start_rank = cumu_num_device_of_all_parallel_group[i] + end_rank = cumu_num_device_of_all_parallel_group[i+1] + for j in range(process_meshes_dp[i]): + for k in range(process_meshes_tp[i]): + ranks = [x for x in range( + start_rank + j * process_meshes_tp[i] + k, end_rank, process_meshes_dp[i] * process_meshes_tp[i] + )] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('same_tp_pp', nccl_comm_cfgs) + ) + if rank in ranks: + _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP = group + _SAME_TP_PIPELINE_GLOBAL_RANKS = ranks + + for i in range(num_process_meshes - 1): + prev_start_rank = cumu_num_device_of_all_parallel_group[i+1] - process_meshes_dp[i] * process_meshes_tp[i] + next_start_rank = cumu_num_device_of_all_parallel_group[i+1] + if process_meshes_tp[i] > process_meshes_tp[i+1]: + for j in range(process_meshes_dp[i]): + for k in range(process_meshes_tp[i]): + ranks = [prev_start_rank + j * process_meshes_tp[i] + k, + next_start_rank + j * process_meshes_tp[i+1] + k // (process_meshes_tp[i] // process_meshes_tp[i+1])] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('diff_tp_pp', nccl_comm_cfgs) + ) + if rank in ranks: + if _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP == None: + _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP = [group] + _DIFF_TP_PIPELINE_GLOBAL_RANKS = [ranks] + else: + _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP.append(group) + _DIFF_TP_PIPELINE_GLOBAL_RANKS.append(ranks) + else: + for j in range(process_meshes_dp[i+1]): + for k in range(process_meshes_tp[i+1]): + ranks = [prev_start_rank + j * process_meshes_tp[i] + k // (process_meshes_tp[i+1] // process_meshes_tp[i]), + next_start_rank + j * process_meshes_tp[i+1] + k] + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('diff_tp_pp', nccl_comm_cfgs) + ) + if rank in ranks: + if _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP == None: + _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP = [group] + _DIFF_TP_PIPELINE_GLOBAL_RANKS = [ranks] + else: + _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP.append(group) + _DIFF_TP_PIPELINE_GLOBAL_RANKS.append(ranks) + + # Initialize global memory buffer + # This isn't really "parallel state" but there isn't another good place to + # put this. 
If we end up with a more generic initialization of megatron-core + # we could stick it there + _set_global_memory_buffer() + def is_initialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is not None @@ -770,6 +1074,17 @@ def get_pipeline_model_parallel_group(): ), 'pipeline_model parallel group is not initialized' return _PIPELINE_MODEL_PARALLEL_GROUP +def get_same_tp_pipeline_model_parallel_group(): + assert ( + _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP is not None + ), 'same tp pipeline_model parallel group is not initialized' + return _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP + +def get_diff_tp_pipeline_model_parallel_group(): + assert ( + _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP is not None + ), 'diff tp pipeline_model parallel group is not initialized' + return _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP def get_data_parallel_group(with_context_parallel=False): """Get the data parallel group the caller rank belongs to.""" @@ -901,6 +1216,10 @@ def set_virtual_pipeline_model_parallel_world_size(world_size): global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size +def get_tensor_model_parallel_size_of_each_pipeline_stage(): + """Return tensor parallel size of each pipeline stage for tp mixed mode""" + global _TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE + return _TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" @@ -909,14 +1228,24 @@ def get_tensor_model_parallel_world_size(): return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE return len(_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS) - def get_pipeline_model_parallel_world_size(): """Return world size for the pipeline model parallel group.""" global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return len(_PIPELINE_GLOBAL_RANKS) + if isinstance(get_pipeline_model_parallel_group(), list): + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()[0]) + else: + return len(_PIPELINE_GLOBAL_RANKS) + +def get_same_tp_pipeline_model_parallel_world_size(): + return torch.distributed.get_world_size(group=get_same_tp_pipeline_model_parallel_group()) +def get_diff_tp_pipeline_model_parallel_world_size(): + if isinstance(get_diff_tp_pipeline_model_parallel_group(), list): + return torch.distributed.get_world_size(group=get_diff_tp_pipeline_model_parallel_group()[0]) + else: + return torch.distributed.get_world_size(group=get_diff_tp_pipeline_model_parallel_group()) def set_expert_model_parallel_rank(rank): """Set expert model parallel rank.""" @@ -949,14 +1278,36 @@ def get_tensor_model_parallel_rank(): return _MPU_TENSOR_MODEL_PARALLEL_RANK return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS.index(torch.distributed.get_rank()) - def get_pipeline_model_parallel_rank(): """Return my rank for the pipeline model parallel group.""" global _MPU_PIPELINE_MODEL_PARALLEL_RANK if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: return _MPU_PIPELINE_MODEL_PARALLEL_RANK - return _PIPELINE_GLOBAL_RANKS.index(torch.distributed.get_rank()) + if isinstance(get_pipeline_model_parallel_group(), list): + return _PIPELINE_GLOBAL_RANKS[0].index(torch.distributed.get_rank()) + else: + return _PIPELINE_GLOBAL_RANKS.index(torch.distributed.get_rank()) + +def get_same_tp_pipeline_model_parallel_rank(): + return 
_SAME_TP_PIPELINE_GLOBAL_RANKS.index(torch.distributed.get_rank()) + +def get_diff_tp_pipeline_model_parallel_rank(): + if isinstance(get_diff_tp_pipeline_model_parallel_group(), list): + return _DIFF_TP_PIPELINE_GLOBAL_RANKS[0].index(torch.distributed.get_rank()) + else: + return _DIFF_TP_PIPELINE_GLOBAL_RANKS.index(torch.distributed.get_rank()) +def is_same_tp_pipeline_first_stage(): + return get_same_tp_pipeline_model_parallel_rank() == 0 + +def is_same_tp_pipeline_last_stage(): + return get_same_tp_pipeline_model_parallel_rank() == (get_same_tp_pipeline_model_parallel_world_size() - 1) + +def is_diff_tp_pipeline_first_stage(): + return get_diff_tp_pipeline_model_parallel_rank() == 0 + +def is_diff_tp_pipeline_last_stage(): + return get_diff_tp_pipeline_model_parallel_rank() == (get_diff_tp_pipeline_model_parallel_world_size() - 1) def get_pipeline_model_parallel_split_rank(): """Return pipeline model parallel split rank.""" @@ -1070,11 +1421,12 @@ def get_virtual_pipeline_model_parallel_world_size(): def get_tensor_model_parallel_src_rank(): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" - assert ( - _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None - ), "Tensor model parallel group is not initialized" - return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0] - + if _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS != None: + return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0] + else: + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size def get_data_parallel_src_rank(with_context_parallel=False): """Calculate the global rank corresponding to the first local rank @@ -1093,7 +1445,10 @@ def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - return _PIPELINE_GLOBAL_RANKS[0] + if isinstance(get_pipeline_model_parallel_group(), list): + return _PIPELINE_GLOBAL_RANKS[0][0] + else: + return _PIPELINE_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): @@ -1101,7 +1456,10 @@ def get_pipeline_model_parallel_last_rank(): current tensor parallel group""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] + if isinstance(get_pipeline_model_parallel_group(), list): + return _PIPELINE_GLOBAL_RANKS[0][last_rank_local] + else: + return _PIPELINE_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): @@ -1109,7 +1467,16 @@ def get_pipeline_model_parallel_next_rank(): assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + if isinstance(get_pipeline_model_parallel_group(), list): + next_ranks = [] + for i in range(len(_PIPELINE_GLOBAL_RANKS)): + next_ranks.append(_PIPELINE_GLOBAL_RANKS[i][(rank_in_pipeline + 1) % world_size]) + if all(x == next_ranks[0] for x in next_ranks): + return [next_ranks[0]] + else: + return next_ranks + else: + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): @@ -1117,8 +1484,52 @@ def 
get_pipeline_model_parallel_prev_rank(): assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] - + if isinstance(get_pipeline_model_parallel_group(), list): + prev_ranks = [] + for i in range(len(_PIPELINE_GLOBAL_RANKS)): + prev_ranks.append(_PIPELINE_GLOBAL_RANKS[i][(rank_in_pipeline - 1) % world_size]) + if all(x == prev_ranks[0] for x in prev_ranks): + return [prev_ranks[0]] + else: + return prev_ranks + else: + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + +def get_same_tp_pipeline_model_parallel_next_rank(): + assert _SAME_TP_PIPELINE_GLOBAL_RANKS is not None, 'same tp pipeline group is not initialized' + rank_in_same_tp_pipeline = get_same_tp_pipeline_model_parallel_rank() + world_size = get_same_tp_pipeline_model_parallel_world_size() + return _SAME_TP_PIPELINE_GLOBAL_RANKS[(rank_in_same_tp_pipeline + 1) % world_size] + +def get_same_tp_pipeline_model_parallel_prev_rank(): + assert _SAME_TP_PIPELINE_GLOBAL_RANKS is not None, 'same tp pipeline group is not initialized' + rank_in_same_tp_pipeline = get_same_tp_pipeline_model_parallel_rank() + world_size = get_same_tp_pipeline_model_parallel_world_size() + return _SAME_TP_PIPELINE_GLOBAL_RANKS[(rank_in_same_tp_pipeline - 1) % world_size] + +def get_diff_tp_pipeline_model_parallel_next_rank(): + assert _DIFF_TP_PIPELINE_GLOBAL_RANKS is not None, 'diff tp pipeline group is not initialized' + rank_in_diff_tp_pipeline = get_diff_tp_pipeline_model_parallel_rank() + world_size = get_diff_tp_pipeline_model_parallel_world_size() + if isinstance(get_diff_tp_pipeline_model_parallel_group(), list): + next_ranks = [] + for i in range(len(_DIFF_TP_PIPELINE_GLOBAL_RANKS)): + next_ranks.append(_DIFF_TP_PIPELINE_GLOBAL_RANKS[i][(rank_in_diff_tp_pipeline + 1) % world_size]) + return next_ranks + else: + return _DIFF_TP_PIPELINE_GLOBAL_RANKS[(rank_in_diff_tp_pipeline + 1) % world_size] + +def get_diff_tp_pipeline_model_parallel_prev_rank(): + assert _DIFF_TP_PIPELINE_GLOBAL_RANKS is not None, 'diff tp pipeline group is not initialized' + rank_in_diff_tp_pipeline = get_diff_tp_pipeline_model_parallel_rank() + world_size = get_diff_tp_pipeline_model_parallel_world_size() + if isinstance(get_diff_tp_pipeline_model_parallel_group(), list): + next_ranks = [] + for i in range(len(_DIFF_TP_PIPELINE_GLOBAL_RANKS)): + next_ranks.append(_DIFF_TP_PIPELINE_GLOBAL_RANKS[i][(rank_in_diff_tp_pipeline - 1) % world_size]) + return next_ranks + else: + return _DIFF_TP_PIPELINE_GLOBAL_RANKS[(rank_in_diff_tp_pipeline - 1) % world_size] def get_last_rank_when_using_pipeline(): """Return the global rank of the last process in the pipeline""" @@ -1282,3 +1693,15 @@ def destroy_model_parallel(): _MPU_EXPERT_MODEL_PARALLEL_RANK = None global _LAST_RANK_WHEN_USING_PIPELINE _LAST_RANK_WHEN_USING_PIPELINE = None + global _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP + _SAME_TP_PIPELINE_MODEL_PARALLEL_GROUP = None + global _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP + _DIFF_TP_PIPELINE_MODEL_PARALLEL_GROUP = None + global _SAME_TP_PIPELINE_GLOBAL_RANKS + _SAME_TP_PIPELINE_GLOBAL_RANKS = None + global _DIFF_TP_PIPELINE_GLOBAL_RANKS + _DIFF_TP_PIPELINE_GLOBAL_RANKS = None + global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS + _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None + global _TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE + _TENSOR_PARALLEL_SIZE_OF_EACH_PIPELINE_STAGE = None 
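The following is a minimal, standalone sketch (not part of the patch, no torch.distributed calls) of how the pipeline and same-tp pipeline groups built above decompose for the example configuration in config_hetero.yaml, process_meshes: [4,1,1,2,1,2], i.e. one process mesh with tp=4, dp=1, pp=1 and one with tp=2, dp=1, pp=2 on 8 ranks. It mirrors the rank arithmetic of initialize_hetero_model_parallel; the helper name sketch_hetero_groups is illustrative only.

def sketch_hetero_groups(process_meshes):
    tp, dp, pp = process_meshes[::3], process_meshes[1::3], process_meshes[2::3]
    # Per pipeline stage: tensor-parallel size and device count (tp * dp).
    stage_tp = [t for t, p in zip(tp, pp) for _ in range(p)]
    stage_dev = [t * d for t, d, p in zip(tp, dp, pp) for _ in range(p)]
    cumu = [sum(stage_dev[:i]) for i in range(len(stage_dev) + 1)]

    # Pipeline-model-parallel groups: ranks of a narrower-tp stage are shared by
    # several ranks of the widest-tp stage, as in initialize_hetero_model_parallel.
    max_tp = max(stage_tp)
    pp_groups = []
    for i in range(max_tp):
        for j in range(dp[0]):
            ranks = []
            for k, t in enumerate(stage_tp):
                if t == max_tp:
                    ranks.append(cumu[k] + t * j + i)
                else:
                    ranks.append(cumu[k] + t * j + i // (max_tp // t))
            pp_groups.append(ranks)

    # Same-tp pipeline groups: ranks inside one process mesh, stride tp * dp.
    mesh_dev = [t * d * p for t, d, p in zip(tp, dp, pp)]
    mesh_cumu = [sum(mesh_dev[:i]) for i in range(len(mesh_dev) + 1)]
    same_tp_groups = []
    for m in range(len(tp)):
        for j in range(dp[m]):
            for k in range(tp[m]):
                same_tp_groups.append(list(range(
                    mesh_cumu[m] + j * tp[m] + k, mesh_cumu[m + 1], dp[m] * tp[m])))
    return pp_groups, same_tp_groups


pp_groups, same_tp_groups = sketch_hetero_groups([4, 1, 1, 2, 1, 2])
print(pp_groups)       # [[0, 4, 6], [1, 4, 6], [2, 5, 7], [3, 5, 7]]
print(same_tp_groups)  # [[0], [1], [2], [3], [4, 6], [5, 7]]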
diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py index 4435cd34d..d9a37f7e2 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -158,47 +158,125 @@ def validate_args(args, defaults={}): # Load saved args from Retro (if applicable). load_retro_args(args) - # Tensor model parallel size. - args.tensor_model_parallel_size = min( - args.tensor_model_parallel_size, args.world_size) - assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ - ' ({}) is not divisible by tensor model parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size) - - # Pipeline model parallel size. - args.pipeline_model_parallel_size = min( - args.pipeline_model_parallel_size, - (args.world_size // args.tensor_model_parallel_size)) - args.transformer_pipeline_model_parallel_size = ( - args.pipeline_model_parallel_size - 1 - if args.standalone_embedding_stage else - args.pipeline_model_parallel_size - ) + args.num_process_meshes = None + if args.process_meshes != None: + assert args.hetero_mode == "pp", \ + 'hetero_mode should be set to pp with process_meshes not None!' + + process_meshes_tp = args.process_meshes[::3] + process_meshes_dp = args.process_meshes[1::3] + process_meshes_pp = args.process_meshes[2::3] + + assert args.untie_embeddings_and_output_weights, \ + 'not support shared embeddings and output weights' - # Checks. - model_parallel_size = args.pipeline_model_parallel_size * \ - args.tensor_model_parallel_size - assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \ - 'world size ({}) is not divisible by tensor parallel size ({}) times ' \ - 'pipeline parallel size ({}) times context parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, args.context_parallel_size) - args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size) - if args.rank == 0: - print('using world size: {}, data-parallel size: {}, ' - 'context-parallel size: {} ' - 'tensor-model-parallel size: {}, ' - 'pipeline-model-parallel size: {} '.format( - args.world_size, args.data_parallel_size, - args.context_parallel_size, - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size), flush=True) - if args.pipeline_model_parallel_size > 1: - if args.pipeline_model_parallel_split_rank is not None: - assert args.pipeline_model_parallel_split_rank < \ - args.pipeline_model_parallel_size, 'split rank needs'\ - ' to be less than pipeline model parallel size ({})'.format( - args.pipeline_model_parallel_size) + args.num_process_meshes = len(process_meshes_tp) + assert args.num_process_meshes == 2, \ + 'only support 2 process_meshes for now!' + + #Data parallel size. + assert all(x == process_meshes_dp[0] for x in process_meshes_dp), \ + 'all parallel group dp should be the same!' + args.data_parallel_size = process_meshes_dp[0] + + #Pipeline model paralle size. + assert args.pipeline_model_parallel_size == sum(process_meshes_pp), \ + 'pipeline_model_parallel_size should match sum of process_meshes_pp!' + assert args.standalone_embedding_stage == False, \ + 'standalone not supported with process_meshes set!' + args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size + assert args.pipeline_model_parallel_split_rank == None, \ + 'pipeline_model_parallel_split_rank not supported with process_meshes set!' + + #Context parallel size. 
+ assert args.context_parallel_size == 1, \ + 'cp!=1 not support now!' + + #Virtual parallel size. + assert args.num_layers_per_virtual_pipeline_stage == None, \ + 'virtual pipeline not support now!' + + #Expert parallel size. + assert args.expert_model_parallel_size == 1, \ + 'ep!=1 not support now!' + + #Tensor model parallel size + num_device_of_each_pipeline_stage = [] + tp_size_of_each_pipeline_stage = [] + for i in range(len(process_meshes_pp)): + for j in range(process_meshes_pp[i]): + tp_size_of_each_pipeline_stage.append(process_meshes_tp[i]) + num_device_of_each_pipeline_stage.append(process_meshes_tp[i] * args.data_parallel_size) + + # len = p + 1, [0, sum(p0), sum(p0-p1), ..., sum(p0-pn-1)] + cumu_num_device_of_all_pipeline_stage = [sum(num_device_of_each_pipeline_stage[:i]) for i in range(args.pipeline_model_parallel_size + 1)] + + for i in range(args.pipeline_model_parallel_size): + if cumu_num_device_of_all_pipeline_stage[i] <= args.rank < cumu_num_device_of_all_pipeline_stage[i+1]: + args.tensor_model_parallel_size = tp_size_of_each_pipeline_stage[i] + + assert args.world_size == sum(tp * dp * pp for tp, dp, pp in + zip(process_meshes_tp, process_meshes_dp, process_meshes_pp)), \ + 'total world size should match sum of all tp x dp x pp!' + + args.process_meshes_tp = process_meshes_tp + args.process_meshes_dp = process_meshes_dp + args.process_meshes_pp = process_meshes_pp + args.cumu_num_device_of_all_pipeline_stage = cumu_num_device_of_all_pipeline_stage + args.tp_size_of_each_pipeline_stage = tp_size_of_each_pipeline_stage + + if args.rank == 0: + print('using world size: {}, data-parallel size: {}, ' + 'context-parallel size: {} ' + 'tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {} '.format( + args.world_size, args.data_parallel_size, + args.context_parallel_size, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size), flush=True) + + else: + # Tensor model parallel size. + args.tensor_model_parallel_size = min( + args.tensor_model_parallel_size, args.world_size) + assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ + ' ({}) is not divisible by tensor model parallel size ({})'.format( + args.world_size, args.tensor_model_parallel_size) + + # Pipeline model parallel size. + args.pipeline_model_parallel_size = min( + args.pipeline_model_parallel_size, + (args.world_size // args.tensor_model_parallel_size)) + args.transformer_pipeline_model_parallel_size = ( + args.pipeline_model_parallel_size - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_size + ) + + # Checks. 
+ model_parallel_size = args.pipeline_model_parallel_size * \ + args.tensor_model_parallel_size + assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \ + 'world size ({}) is not divisible by tensor parallel size ({}) times ' \ + 'pipeline parallel size ({}) times context parallel size ({})'.format( + args.world_size, args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, args.context_parallel_size) + args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size) + if args.rank == 0: + print('using world size: {}, data-parallel size: {}, ' + 'context-parallel size: {} ' + 'tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {} '.format( + args.world_size, args.data_parallel_size, + args.context_parallel_size, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size), flush=True) + if args.pipeline_model_parallel_size > 1: + if args.pipeline_model_parallel_split_rank is not None: + assert args.pipeline_model_parallel_split_rank < \ + args.pipeline_model_parallel_size, 'split rank needs'\ + ' to be less than pipeline model parallel size ({})'.format( + args.pipeline_model_parallel_size) if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' @@ -240,8 +318,8 @@ def validate_args(args, defaults={}): setattr(args, key, defaults[key]) # Heterogeneous Training - assert args.hetero_mode is None, \ - "Hetero mode is not supported in this version. Please use the v0.3." + #assert args.hetero_mode is None, \ + # "Hetero mode is not supported in this version. Please use the v0.3." if args.hetero_mode: assert args.global_batch_size is not None, "global_batch_size should be specified when hetero_mode is not None" assert args.hetero_current_device_type, "hetero_current_device_type should be specified when hetero_mode is not None" @@ -538,8 +616,14 @@ def validate_args(args, defaults={}): # disable sequence parallelism when tp=1 # to avoid change in numerics when # sequence_parallelism is enabled. - if args.tensor_model_parallel_size == 1: - args.sequence_parallel = False + if args.num_process_meshes != None: + if 1 in args.tp_size_of_each_pipeline_stage: + if args.rank == 0: + print("Set sequence_parallel false for some parallel group's tp size match 1") + args.sequence_parallel = False + else: + if args.tensor_model_parallel_size == 1: + args.sequence_parallel = False # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled @@ -1934,6 +2018,11 @@ def _add_hetero_args(parser): 'hetero-pipeline-stages must be in the form:' 'n0 layers_0_0 layers_0_1 ... n1 nlayers_1_0 nlayers_1_1 ...' 'The order should be consistent with --hetero-device-types.') + group.add_argument('--process-meshes', nargs='*', type=int, default=None, + help='Use this arg to set TP/DP/PP of each process mesh group.' + 'This argument must be in the form: TP0, DP0, PP0, TP1, DP1, PP1' + '...TPN, DPN, PPN. TP size can be different, sum of PP should match ' + 'pipeline-model-parallel-size, DP size should be the same.') return parser
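As a worked example of the --process-meshes help text above, the short standalone sketch below (illustrative only, not part of the patch) shows how validate_args() interprets the value used in train_llama2_7b_tp_hetero.yaml; the intermediate variable names follow the ones introduced in the patch.

process_meshes = [4, 1, 1, 2, 1, 2]        # TP0, DP0, PP0, TP1, DP1, PP1
tp = process_meshes[::3]                   # [4, 2]
dp = process_meshes[1::3]                  # [1, 1]
pp = process_meshes[2::3]                  # [1, 2]

assert all(x == dp[0] for x in dp)         # every mesh must use the same DP size
data_parallel_size = dp[0]                                   # -> 1
pipeline_model_parallel_size = sum(pp)                       # -> 3 (matches the YAML)
world_size = sum(t * d * p for t, d, p in zip(tp, dp, pp))   # -> 4*1*1 + 2*1*2 = 8

# Per pipeline stage, as derived in validate_args():
tp_size_of_each_pipeline_stage = [t for t, p in zip(tp, pp) for _ in range(p)]  # [4, 2, 2]
num_device_of_each_pipeline_stage = [t * data_parallel_size
                                     for t in tp_size_of_each_pipeline_stage]   # [4, 2, 2]
# Ranks 0-3 therefore run with tensor_model_parallel_size = 4, ranks 4-7 with 2.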