[Iluvatar] Support tensor parallel heterogeneous training (#134)
This PR adds support for training LLMs on heterogeneous devices with different tensor parallel sizes.

Co-authored-by: yu.song <[email protected]>
Showing 11 changed files with 2,773 additions and 60 deletions.
@@ -49,4 +49,4 @@ action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
@@ -0,0 +1,25 @@
defaults:
  - train: train_llama2_7b_tp_hetero
  - _self_

experiment:
  exp_name: llama2_tp_hetero
  exp_dir: ./outputs_llama2_tp_hetero
  task:
    type: train
    backend: megatron
    entrypoint: ./flagscale/train/hetero/train_llama.py
  runner:
    backend: torchrun
    nnodes: 1
    nproc_per_node: 8
    hostfile: hostfile
  envs:
    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
    CUDA_DEVICE_MAX_CONNECTIONS: 1

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
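The runner block above starts torchrun with 8 local processes on a single node, and the defaults list pulls in the train_llama2_7b_tp_hetero settings shown next. The launcher invocation itself is not part of this page; assuming FlagScale's usual hydra-based entrypoint, something along the lines of "python run.py --config-path <conf dir> --config-name config_llama2_tp_hetero" would select this experiment (script name and config path are assumptions, not taken from the PR).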
@@ -0,0 +1,67 @@
system:
  tensor_model_parallel_size: 4
  pipeline_model_parallel_size: 3
  disable_bias_linear: True
  use_flash_attn: True
  sequence_parallel: True
  use_distributed_optimizer: True
  hetero_mode: pp
  hetero_device_types: A100
  hetero_current_device_type: A100
  hetero_pipeline_stages: [3,16,8,8]
  process_meshes: [4,1,1,2,1,2]
  precision:
    bf16: True
    initial_loss_scale: 16384
    min_loss_scale: 1.0
  logging:
    log_interval: 1
  checkpoint:
    save_interval: 100

model:
  use_mcore_models: True
  transformer_impl: transformer_engine
  num_layers: 32
  hidden_size: 4096
  ffn_hidden_size: 11008
  num_attention_heads: 32
  seq_length: 4096
  group_query_attention: False
  num_query_groups: 8
  max_position_embeddings: 4096
  norm_epsilon: 1e-5
  use_rotary_position_embeddings: True
  no_position_embedding: True
  swiglu: True
  normalization: RMSNorm
  untie_embeddings_and_output_weights: True
  init_method_std: 0.02
  attention_dropout: 0.0
  hidden_dropout: 0.0
  weight_decay: 0.1
  clip_grad: 1.0
  train_iters: 30
  eval_iters: 0
  eval_interval: 2000
  micro_batch_size: 1
  global_batch_size: 32

  optimizer:
    weight_decay: 1e-2
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 0.00015
      min_lr: 1.0e-5
      lr_warmup_fraction: .01
      lr_decay_iters: 1
      lr_decay_style: cosine

data:
  data_path: ${data_path:??}
  split: 1
  tokenizer:
    tokenizer_type: Llama2Tokenizer
    tokenizer_model: examples/llama/tokenizer.model
    vocab_size: 32000
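hetero_pipeline_stages and process_meshes are the settings specific to this PR. Their precise semantics live in the code changes rather than in this file, but one reading that is consistent with the rest of the config (an assumption, stated here only to make the numbers concrete) is: process_meshes lists one (tp, dp, pp) triple per device group, and hetero_pipeline_stages gives the number of pipeline stages followed by the layer count of each stage. A quick arithmetic check of that assumed reading against the other values in this file:

# Sanity check of an ASSUMED reading of the hetero settings above; the field
# semantics are not documented on this page, only the arithmetic is verified.
process_meshes = [4, 1, 1, 2, 1, 2]   # assumed layout: (tp, dp, pp) per mesh
meshes = [tuple(process_meshes[i:i + 3]) for i in range(0, len(process_meshes), 3)]
# Total ranks across meshes should match nnodes * nproc_per_node = 8.
assert sum(tp * dp * pp for tp, dp, pp in meshes) == 8
# Pipeline stages across meshes should match pipeline_model_parallel_size = 3.
assert sum(pp for _, _, pp in meshes) == 3
# Assumed layout: [num_stages, layers_per_stage, ...]; layers should sum to num_layers = 32.
hetero_pipeline_stages = [3, 16, 8, 8]
assert hetero_pipeline_stages[0] == 3 and sum(hetero_pipeline_stages[1:]) == 32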
Empty file.
@@ -0,0 +1,183 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from datetime import timedelta
from megatron.training import get_args
from megatron.core import parallel_state as mpu
from megatron.training.global_vars import set_hetero_context, get_hetero_context
from megatron.training.arguments import parse_args, validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables, set_global_writers
from megatron.training.initialize import _set_random_seed, _init_autoresume, _compile_dependencies, _initialize_tp_communicators
from megatron.training.utils import save_checkpoint_info


def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
    skip_mpu_initialization=False,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoints-args requires --load argument"
        load_args_from_checkpoint(args)

    if args.yaml_cfg is not None:
        args = validate_yaml(args, args_defaults)
    else:
        validate_args(args, args_defaults)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

        # Set tensorboard writer and wandb writer.
        set_global_writers(args)

    if skip_mpu_initialization:
        return None

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        save_checkpoint_info(args.save)

        if args.tp_comm_overlap:
            _initialize_tp_communicators()

        # No continuation function
        return None


def _initialize_distributed():
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert (
                    args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            timeout=timedelta(minutes=args.distributed_timeout_minutes),
        )

    if args.num_process_meshes == None:
        if args.hetero_mode is not None:
            # Build the heterogenous context after torch.distributed is initialized and
            # before model parallel is initialized.
            set_hetero_context(args)
            if torch.distributed.get_rank() == 0:
                print(get_hetero_context(), flush=True)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
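            # Select the model-parallel initializer: when process_meshes is provided,
            # build the process groups with the heterogeneous initializer; otherwise
            # fall back to the standard initialize_model_parallel, passing hetero_mode
            # through for the pipeline-heterogeneous case.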
            if args.num_process_meshes != None:
                mpu.initialize_hetero_model_parallel(
                    args,
                    args.tensor_model_parallel_size,
                    args.pipeline_model_parallel_size,
                    args.virtual_pipeline_model_parallel_size,
                    args.pipeline_model_parallel_split_rank,
                    context_parallel_size=args.context_parallel_size,
                    expert_model_parallel_size=args.expert_model_parallel_size,
                    distributed_timeout_minutes=args.distributed_timeout_minutes,
                    nccl_communicator_config_path=args.nccl_communicator_config_path,
                )
            else:
                mpu.initialize_model_parallel(
                    args.tensor_model_parallel_size,
                    args.pipeline_model_parallel_size,
                    args.virtual_pipeline_model_parallel_size,
                    args.pipeline_model_parallel_split_rank,
                    context_parallel_size=args.context_parallel_size,
                    expert_model_parallel_size=args.expert_model_parallel_size,
                    distributed_timeout_minutes=args.distributed_timeout_minutes,
                    nccl_communicator_config_path=args.nccl_communicator_config_path,
                    hetero_mode=args.hetero_mode,
                    order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
                )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )