[Iluvatar] Support tensor parallel heterogeneous training #134

Merged · 8 commits · Jun 7, 2024
Changes from 6 commits
183 changes: 183 additions & 0 deletions flagscale/train/hetero/initialize.py
@@ -0,0 +1,183 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from datetime import timedelta
from megatron.training import get_args
from megatron.core import parallel_state as mpu
from megatron.training.global_vars import set_hetero_context, get_hetero_context
from megatron.training.arguments import parse_args, validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables, set_global_writers
from megatron.training.initialize import _set_random_seed, _init_autoresume, _compile_dependencies, _initialize_tp_communicators
from megatron.training.utils import save_checkpoint_info

def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
    skip_mpu_initialization=False,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoint-args requires --load argument"
        load_args_from_checkpoint(args)

    if args.yaml_cfg is not None:
        args = validate_yaml(args, args_defaults)
    else:
        validate_args(args, args_defaults)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

        # Set tensorboard writer and wandb writer.
        set_global_writers(args)

    if skip_mpu_initialization:
        return None

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        save_checkpoint_info(args.save)

        if args.tp_comm_overlap:
            _initialize_tp_communicators()

        # No continuation function
        return None

def _initialize_distributed():
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert (
                    args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            timeout=timedelta(minutes=args.distributed_timeout_minutes),
        )

    if args.num_process_meshes is None:
        if args.hetero_mode is not None:
            # Build the heterogeneous context after torch.distributed is initialized and
            # before model parallel is initialized.
            set_hetero_context(args)
            if torch.distributed.get_rank() == 0:
                print(get_hetero_context(), flush=True)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            if args.num_process_meshes is not None:
                mpu.initialize_hetero_model_parallel(
                    args,
                    args.tensor_model_parallel_size,
                    args.pipeline_model_parallel_size,
                    args.virtual_pipeline_model_parallel_size,
                    args.pipeline_model_parallel_split_rank,
                    context_parallel_size=args.context_parallel_size,
                    expert_model_parallel_size=args.expert_model_parallel_size,
                    distributed_timeout_minutes=args.distributed_timeout_minutes,
                    nccl_communicator_config_path=args.nccl_communicator_config_path,
                )
            else:
                mpu.initialize_model_parallel(
                    args.tensor_model_parallel_size,
                    args.pipeline_model_parallel_size,
                    args.virtual_pipeline_model_parallel_size,
                    args.pipeline_model_parallel_split_rank,
                    context_parallel_size=args.context_parallel_size,
                    expert_model_parallel_size=args.expert_model_parallel_size,
                    distributed_timeout_minutes=args.distributed_timeout_minutes,
                    nccl_communicator_config_path=args.nccl_communicator_config_path,
                    hetero_mode=args.hetero_mode,
                    order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
                )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )
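
For readers following the diff, a minimal driver sketch may help show how this entry point is meant to be used. The script name, launch command, and printed message below are assumptions for illustration, not part of this PR. A launcher such as torchrun supplies RANK/WORLD_SIZE/LOCAL_RANK; initialize_megatron() then parses arguments, sets global state, and either completes model-parallel initialization immediately or, when args.lazy_mpu_init is set, returns finish_mpu_init for an external DDP manager to call later.

# Hypothetical usage sketch (not part of this PR); assumes a pretrain.py
# launched with something like:
#   torchrun --nproc_per_node=8 pretrain.py --tensor-model-parallel-size 2 ...
from flagscale.train.hetero.initialize import initialize_megatron
from megatron.training import get_args


def main():
    # Normal path: torch.distributed and the model-parallel groups are
    # fully initialized before this call returns (it returns None).
    maybe_finish = initialize_megatron(
        extra_args_provider=None,  # hook for registering extra CLI flags
        args_defaults={},          # defaults applied before validate_args
    )

    # Lazy path: with args.lazy_mpu_init, only basic TP globals are set and
    # the returned finish_mpu_init must be called once DDP is ready.
    if maybe_finish is not None:
        maybe_finish()

    args = get_args()
    if args.rank == 0:
        print(f"initialized world of size {args.world_size}")


if __name__ == "__main__":
    main()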