Commit
Merge branch 'main' into main
Bowen12992 authored Nov 18, 2024
2 parents 7533789 + dd4309c commit c4a809b
Showing 4 changed files with 116 additions and 75 deletions.
67 changes: 0 additions & 67 deletions examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml

This file was deleted.

100 changes: 100 additions & 0 deletions examples/llama/conf/train/train_llama3_8b_hetero.yaml
@@ -0,0 +1,100 @@
system:
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 2
disable_bias_linear: True
use_flash_attn: True
sequence_parallel: True
use_distributed_optimizer: True
precision:
bf16: True
attention_softmax_in_fp32: true
accumulate_allreduce_grads_in_fp32: true
logging:
log_interval: 1
tensorboard_log_interval: 1
wandb_project: "train-llama3-8B"
wandb_exp_name: "train-test-8B"
checkpoint:
load: outputs_llama3/checkpoint_mc
save_interval: 10
finetune: True
ckpt_format: "torch"

model:
use_mcore_models: True
transformer_impl: transformer_engine
num_layers: 32
hidden_size: 4096
ffn_hidden_size: 14336
num_attention_heads: 32
seq_length: 4096
group_query_attention: True
num_query_groups: 8
max_position_embeddings: 8192
norm_epsilon: 1e-5
use_rotary_position_embeddings: True
no_position_embedding: True
swiglu: True
normalization: RMSNorm
rotary_interleaved_patch: False
position_embedding_type: rope
rotary_base: 500000
untie_embeddings_and_output_weights: True
init_method_std: 0.02
attention_dropout: 0.0
hidden_dropout: 0.0
clip_grad: 1.0
train_samples: 200000
eval_iters: 100
eval_interval: 1000
micro_batch_size: 1
global_batch_size: 16

hetero:
enable_hetero: True
hetero_use_cpu_communication: True
# mesh format [tp1,cp1,ep1,dp1,pp1,(tp2,cp2...)]

# 2 meshes with different tp/dp/pp
hetero_pipeline_layer_split: [18, 14]
hetero_process_meshes: [2, 1, 1, 4, 1, 4, 1, 1, 2, 1]
hetero_device_types: ["A800", "A100"]
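# Decoding the flat list above with the mesh format noted earlier ([tp, cp, ep, dp, pp] per mesh):
#   mesh 1: [2, 1, 1, 4, 1] -> tp=2, cp=1, ep=1, dp=4, pp=1 (2*4 = 8 GPUs)
#   mesh 2: [4, 1, 1, 2, 1] -> tp=4, cp=1, ep=1, dp=2, pp=1 (4*2 = 8 GPUs)
# Read together with hetero_device_types, mesh 1 presumably runs on the "A800" nodes and
# mesh 2 on the "A100" nodes; the two meshes form the 2 pipeline stages, and
# hetero_pipeline_layer_split [18, 14] assigns 18 of the 32 layers to the first stage
# and 14 to the second.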

standalone_embedding_stage: False
hetero_current_device_type: "A800"

# recompute:
# recompute_granularity: "full"
# recompute_method: "uniform"
# recompute_num_layers: 1

# ## pp 2 stages and num_micro_batches 4
# recompute_granularity_per_stage_micro_batch:
# - [1, 3, 0, 1, 0]
# - [1, 3, 1, 1, 1]
# recompute_method_per_stage_micro_batch:
# - [1, 3, 0, 1, 0]
# - [1, 3, 0, 1, 0]
# recompute_num_layers_per_stage_micro_batch:
# - [1, 3, 2, 1, 2]
# - [1, 3, 1, 1, 1]


optimizer:
weight_decay: 1e-2
adam_beta1: 0.9
adam_beta2: 0.95
lr_scheduler:
lr: 1.0e-5
min_lr: 1.0e-6
lr_warmup_fraction: .1
lr_decay_style: cosine

data:
data_path: examples/llama/pile-openwebtext_text_document/pile-openwebtext_text_document
split: 1
tokenizer:
tokenizer_type: Llama3TokenizerFS
tokenizer_path: meta-llama3/Meta-Llama-3-8B
vocab_size: 128256
make_vocab_size_divisible_by: 64
21 changes: 14 additions & 7 deletions flagscale/runner/runner_train.py
@@ -291,6 +291,7 @@ def _prepare(self):
         self.user_args = _get_args_megatron(self.config)
         self.rdzv_id = datetime.now().strftime("%Y%m%d_%H%M%S.%f")
         self.user_envs = self.config.experiment.get("envs", {})
+        self.cur_envs = None # current node envs
         self.user_script = self.config.experiment.task.entrypoint
         self.resources = parse_hostfile(
             self.config.experiment.runner.get("hostfile", None)
@@ -311,9 +312,8 @@ def _run_each(
         dryrun=False,
     ):
         export_cmd = []
-        cur_envs = add_decive_extra_config(self.user_envs, device_type)

-        for k, v in cur_envs.items():
+        for k, v in self.cur_envs.items():
             export_cmd += [f"{k}={v}"]

         runner_cmd = _get_runner_cmd_train(
@@ -366,11 +366,6 @@ def _run_each(
     def run(self, with_test=False, dryrun=False, monitor=False, interval=10):

         num_visible_devices = None
-        visible_devices = self.user_envs.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is not None and isinstance(visible_devices, str):
-            visible_devices = visible_devices.split(",")
-            num_visible_devices = len(visible_devices)
-
         runner_config = self.config.experiment.runner

         # If hostfile is provided, use the resources from the hostfile
@@ -383,6 +378,13 @@ def run(self, with_test=False, dryrun=False, monitor=False, interval=10):
             for node_rank, (host, resource_info) in enumerate(self.resources.items()):
                 if node_rank >= nnodes:
                     break
+                self.cur_envs = add_decive_extra_config(
+                    self.user_envs, resource_info["type"]
+                )
+                visible_devices = self.cur_envs.get("CUDA_VISIBLE_DEVICES", None)
+                if visible_devices is not None and isinstance(visible_devices, str):
+                    visible_devices = visible_devices.split(",")
+                    num_visible_devices = len(visible_devices)
                 nproc_from_hostfile = resource_info["slots"]
                 nproc_from_args = runner_config.get("nproc_per_node", None)
                 nproc_per_node = get_nproc_per_node(
@@ -403,6 +405,11 @@ def run(self, with_test=False, dryrun=False, monitor=False, interval=10):
                 )
         else:
             # If hostfile is not provided, run the job on localhost
+            self.cur_envs = self.user_envs
+            visible_devices = self.cur_envs.get("CUDA_VISIBLE_DEVICES", None)
+            if visible_devices is not None and isinstance(visible_devices, str):
+                visible_devices = visible_devices.split(",")
+                num_visible_devices = len(visible_devices)
             nproc_from_args = runner_config.get("nproc_per_node", None)
             nproc_per_node = get_nproc_per_node(
                 None, nproc_from_args, num_visible_devices
3 changes: 2 additions & 1 deletion flagscale/train/hetero/parallel_context.py
@@ -130,6 +130,7 @@ def __init__(
     ):
         assert torch.distributed.is_initialized()
         self._args = args
+        self._distributed_backend = args.distributed_backend
         self._rank = torch.distributed.get_rank()
         self._world_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * data_parallel_size
         self._offset = offset
@@ -205,7 +206,7 @@ def build_process_group(
         ranks = self._rank_mapper.to_physical_ranks(logical_ranks)
         group = torch.distributed.new_group(
             ranks,
-            backend="nccl",
+            backend=self._distributed_backend,
             timeout=self._timeout,
             pg_options=pg_options,
         )
