[Iluvatar] Support tensor parallel heterogeneous training (#134)
This PR adds support for training LLMs on heterogeneous devices with
different tensor parallel sizes.

---------

Co-authored-by: yu.song <[email protected]>
njuerect and yu.song authored Jun 7, 2024
1 parent ac373cb commit 9838ede
Showing 11 changed files with 2,773 additions and 60 deletions.
2 changes: 1 addition & 1 deletion examples/llama/conf/config.yaml
@@ -49,4 +49,4 @@ action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
    dir: ${experiment.exp_dir}/hydra
25 changes: 25 additions & 0 deletions examples/llama/conf/config_hetero.yaml
@@ -0,0 +1,25 @@
defaults:
  - train: train_llama2_7b_tp_hetero
  - _self_

experiment:
  exp_name: llama2_tp_hetero
  exp_dir: ./outputs_llama2_tp_hetero
  task:
    type: train
    backend: megatron
    entrypoint: ./flagscale/train/hetero/train_llama.py
  runner:
    backend: torchrun
    nnodes: 1
    nproc_per_node: 8
    hostfile: hostfile
  envs:
    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
    CUDA_DEVICE_MAX_CONNECTIONS: 1

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
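
The composed config above amounts to a single-node torchrun launch with 8 ranks and 8 visible GPUs. As a quick illustration (not part of the PR), the sketch below loads the file and checks that the runner and envs sections agree; it assumes the file path shown in this diff and uses a plain YAML parser, so the Hydra defaults list is simply ignored:

import yaml

with open("examples/llama/conf/config_hetero.yaml") as f:
    cfg = yaml.safe_load(f)

runner = cfg["experiment"]["runner"]
envs = cfg["experiment"]["envs"]

world_size = runner["nnodes"] * runner["nproc_per_node"]        # 1 x 8
visible_devices = str(envs["CUDA_VISIBLE_DEVICES"]).split(",")

assert runner["backend"] == "torchrun"
assert world_size == 8
assert len(visible_devices) == runner["nproc_per_node"]
print(f"launching {world_size} ranks on devices {visible_devices}")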
67 changes: 67 additions & 0 deletions examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml
@@ -0,0 +1,67 @@
system:
  tensor_model_parallel_size: 4
  pipeline_model_parallel_size: 3
  disable_bias_linear: True
  use_flash_attn: True
  sequence_parallel: True
  use_distributed_optimizer: True
  hetero_mode: pp
  hetero_device_types: A100
  hetero_current_device_type: A100
  hetero_pipeline_stages: [3,16,8,8]
  process_meshes: [4,1,1,2,1,2]
  precision:
    bf16: True
    initial_loss_scale: 16384
    min_loss_scale: 1.0
  logging:
    log_interval: 1
  checkpoint:
    save_interval: 100

model:
  use_mcore_models: True
  transformer_impl: transformer_engine
  num_layers: 32
  hidden_size: 4096
  ffn_hidden_size: 11008
  num_attention_heads: 32
  seq_length: 4096
  group_query_attention: False
  num_query_groups: 8
  max_position_embeddings: 4096
  norm_epsilon: 1e-5
  use_rotary_position_embeddings: True
  no_position_embedding: True
  swiglu: True
  normalization: RMSNorm
  untie_embeddings_and_output_weights: True
  init_method_std: 0.02
  attention_dropout: 0.0
  hidden_dropout: 0.0
  weight_decay: 0.1
  clip_grad: 1.0
  train_iters: 30
  eval_iters: 0
  eval_interval: 2000
  micro_batch_size: 1
  global_batch_size: 32

  optimizer:
    weight_decay: 1e-2
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 0.00015
      min_lr: 1.0e-5
      lr_warmup_fraction: .01
      lr_decay_iters: 1
      lr_decay_style: cosine

data:
  data_path: ${data_path:??}
  split: 1
  tokenizer:
    tokenizer_type: Llama2Tokenizer
    tokenizer_model: examples/llama/tokenizer.model
    vocab_size: 32000
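
The two new knobs carry the heterogeneous layout. A plausible reading, hedged because the diff itself does not document the format: hetero_pipeline_stages: [3,16,8,8] declares 3 pipeline stages holding 16, 8, and 8 of the 32 layers, and process_meshes: [4,1,1,2,1,2], read as (tp, dp, pp) triples, declares one mesh with tensor parallel size 4 covering 1 stage and one mesh with tensor parallel size 2 covering the remaining 2 stages, for 4*1*1 + 2*1*2 = 8 GPUs, matching nproc_per_node in config_hetero.yaml. The consistency check below only illustrates that assumed layout:

NUM_LAYERS = 32                      # model.num_layers above
NPROC_PER_NODE = 8                   # from config_hetero.yaml

hetero_pipeline_stages = [3, 16, 8, 8]
process_meshes = [4, 1, 1, 2, 1, 2]

num_stages, layer_split = hetero_pipeline_stages[0], hetero_pipeline_stages[1:]
assert len(layer_split) == num_stages
assert sum(layer_split) == NUM_LAYERS                 # 16 + 8 + 8 == 32

# Assumed grouping: (tensor_parallel, data_parallel, pipeline_stages) per mesh.
meshes = [tuple(process_meshes[i:i + 3]) for i in range(0, len(process_meshes), 3)]
assert sum(pp for _, _, pp in meshes) == num_stages   # 1 + 2 == 3 stages
assert sum(tp * dp * pp for tp, dp, pp in meshes) == NPROC_PER_NODE  # 4 + 4 == 8
print(meshes)                                         # [(4, 1, 1), (2, 1, 2)]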
Empty file.
183 changes: 183 additions & 0 deletions flagscale/train/hetero/initialize.py
@@ -0,0 +1,183 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from datetime import timedelta
from megatron.training import get_args
from megatron.core import parallel_state as mpu
from megatron.training.global_vars import set_hetero_context, get_hetero_context
from megatron.training.arguments import parse_args, validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables, set_global_writers
from megatron.training.initialize import _set_random_seed, _init_autoresume, _compile_dependencies, _initialize_tp_communicators
from megatron.training.utils import save_checkpoint_info

def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
    skip_mpu_initialization=False,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoint-args requires --load argument"
        load_args_from_checkpoint(args)

    if args.yaml_cfg is not None:
        args = validate_yaml(args, args_defaults)
    else:
        validate_args(args, args_defaults)

    # Set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

        # Set tensorboard writer and wandb writer.
        set_global_writers(args)

    if skip_mpu_initialization:
        return None

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related stuff.
        # We only set basic DDP globals.
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # Return a function for the external DDP manager
        # to call when it has DDP initialized.
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        save_checkpoint_info(args.save)

        if args.tp_comm_overlap:
            _initialize_tp_communicators()

    # No continuation function
    return None

def _initialize_distributed():
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert (
                    args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            timeout=timedelta(minutes=args.distributed_timeout_minutes),
        )

    if args.num_process_meshes is None:
        if args.hetero_mode is not None:
            # Build the heterogeneous context after torch.distributed is initialized and
            # before model parallel is initialized.
            set_hetero_context(args)
            if torch.distributed.get_rank() == 0:
                print(get_hetero_context(), flush=True)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            if args.num_process_meshes is not None:
                mpu.initialize_hetero_model_parallel(
                    args,
                    args.tensor_model_parallel_size,
                    args.pipeline_model_parallel_size,
                    args.virtual_pipeline_model_parallel_size,
                    args.pipeline_model_parallel_split_rank,
                    context_parallel_size=args.context_parallel_size,
                    expert_model_parallel_size=args.expert_model_parallel_size,
                    distributed_timeout_minutes=args.distributed_timeout_minutes,
                    nccl_communicator_config_path=args.nccl_communicator_config_path,
                )
            else:
                mpu.initialize_model_parallel(
                    args.tensor_model_parallel_size,
                    args.pipeline_model_parallel_size,
                    args.virtual_pipeline_model_parallel_size,
                    args.pipeline_model_parallel_split_rank,
                    context_parallel_size=args.context_parallel_size,
                    expert_model_parallel_size=args.expert_model_parallel_size,
                    distributed_timeout_minutes=args.distributed_timeout_minutes,
                    nccl_communicator_config_path=args.nccl_communicator_config_path,
                    hetero_mode=args.hetero_mode,
                    order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
                )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )
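
The new _initialize_distributed branches on args.num_process_meshes: when process meshes are supplied it calls mpu.initialize_hetero_model_parallel (presumably added elsewhere in this PR) so that different meshes can use different tensor parallel sizes; otherwise it falls back to the stock initialize_model_parallel path, optionally with hetero_mode. The following is a minimal sketch, not the PR's implementation, of how per-mesh tensor-parallel groups could be carved out of a flat rank space with torch.distributed.new_group, again assuming the (tp, dp, pp) triple reading of process_meshes described above:

import torch.distributed as dist

def build_hetero_tp_groups(process_meshes):
    """Return one process group per tensor-parallel replica across all meshes."""
    meshes = [tuple(process_meshes[i:i + 3]) for i in range(0, len(process_meshes), 3)]
    groups, start = [], 0
    for tp, dp, pp in meshes:
        mesh_size = tp * dp * pp
        # Within a mesh, each consecutive block of `tp` ranks forms one TP group.
        for offset in range(0, mesh_size, tp):
            ranks = list(range(start + offset, start + offset + tp))
            # new_group must be called by every rank with the same arguments.
            groups.append(dist.new_group(ranks=ranks))
        start += mesh_size
    return groups

With process_meshes = [4,1,1,2,1,2], this sketch yields one 4-way group over ranks 0-3 and two 2-way groups over ranks 4-5 and 6-7.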