
Commit

update config
lzy-dev committed Dec 3, 2024
1 parent 6653514 commit 8797120
Showing 8 changed files with 61 additions and 49 deletions.
@@ -12,11 +12,12 @@ system:
accumulate_allreduce_grads_in_fp32: True
logging:
log_interval: 1
log_throughput: True
tensorboard_log_interval: 1
wandb_project: "train-aquila-1B"
wandb_exp_name: "train-test-1B"
checkpoint:
# load: outputs_llama3/checkpoint_mc
load: outputs_llama3/checkpoint_mc
ckpt_format: torch
save_interval: 2000

@@ -25,7 +26,7 @@ system:
hetero_use_cpu_communication: False
# mesh format [tp1,cp1,ep1,dp1,pp1,(tp2,cp2...)]

hetero_pipeline_layer_split: [12, 12]
hetero_pipeline_layer_split: [18, 18]
hetero_process_meshes: [1, 1, 1, 4, 2]
hetero_device_types: ["A800"]

@@ -51,31 +52,32 @@ system:
model:
# use_mcore_models: True # deprecated
transformer_impl: transformer_engine
num_layers: 24
num_layers: 36
hidden_size: 2048
num_attention_heads: 16
group_query_attention: True
num_query_groups: 2
seq_length: 4096
max_position_embeddings: 4096 # only for adding position embeddings
norm_epsilon: 1e-5
norm_epsilon: 1e-6
use_rotary_position_embeddings: true
no_position_embedding: true
rotary_base: 100000 # To be determined
rotary_base: 1000000
swiglu: true
multiple_of: 256
hidden_dim_multiplier: 2 # ffn_hidden_size 11008
normalization: RMSNorm
qk_layernorm: True
qk_layernorm_hidden_dim: True
position_embedding_type: rope
untie_embeddings_and_output_weights: true
untie_embeddings_and_output_weights: False
init_method_std: 0.02
attention_dropout: 0.0
hidden_dropout: 0.0
weight_decay: 0.1
clip_grad: 1.0
train_samples: 160
eval_iters: 0
micro_batch_size: 2
global_batch_size: 16
micro_batch_size: 1
global_batch_size: 1024
seed: 1234

optimizer:
@@ -85,15 +87,14 @@ model:
lr_scheduler:
lr: 2.0e-5
min_lr: 2.0e-6
lr_warmup_samples: 10
lr_warmup_samples: 2000
lr_decay_style: cosine

data:
data_path: ${data_path:??}
# data_path: ./build/data/pile_wikipedia_demo
split: 1
tokenizer:
tokenizer_type: QwenTokenizerFS
tokenizer_path: ${tokenizer_path:??}
vocab_size: 151851
tokenizer_path: examples/aquila/qwentokenizer
vocab_size: 151936
make_vocab_size_divisible_by: 64
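
For reference, the updated 1B values above remain mutually consistent: the heterogeneous pipeline split has to cover the enlarged layer count, and the new batch sizes still divide evenly across the data-parallel ranks implied by the process mesh. A minimal sanity-check sketch (values copied from the diff; the constraints themselves are assumptions about how FlagScale interprets these fields, not something stated in this commit):

num_layers = 36
hetero_pipeline_layer_split = [18, 18]
tp, cp, ep, dp, pp = [1, 1, 1, 4, 2]          # hetero_process_meshes, per the [tp, cp, ep, dp, pp] comment
micro_batch_size, global_batch_size = 1, 1024

# Assumed constraints: one split entry per pipeline stage, covering every layer,
# and a global batch that divides evenly over the data-parallel ranks.
assert sum(hetero_pipeline_layer_split) == num_layers        # 18 + 18 == 36
assert len(hetero_pipeline_layer_split) == pp
assert global_batch_size % (micro_batch_size * dp) == 0
print(global_batch_size // (micro_batch_size * dp))          # 256 micro-batches accumulated per optimizer step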
29 changes: 14 additions & 15 deletions examples/aquila/conf/train/train_aquila_7b.yaml
@@ -1,11 +1,11 @@
system:
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_model_parallel_size: 2
disable_bias_linear: True
use_flash_attn: True
use_distributed_optimizer: True
precision:
fp16: True
bf16: True
initial_loss_scale: 522893
min_loss_scale: 1.0
attention_softmax_in_fp32: True
@@ -19,11 +19,11 @@ system:
save_interval: 2000

model:
num_layers: 32
hidden_size: 4096
num_layers: 12
hidden_size: 2048
num_attention_heads: 32
seq_length: 2048
max_position_embeddings: 2048
seq_length: 4096
max_position_embeddings: 4096
norm_epsilon: 1e-5
use_rotary_position_embeddings: true
no_position_embedding: true
@@ -37,10 +37,10 @@ model:
hidden_dropout: 0.0
weight_decay: 0.1
clip_grad: 1.0
train_samples: 1002539063
train_samples: 1000
eval_iters: 0
micro_batch_size: 2
global_batch_size: 1728
global_batch_size: 16
seed: 1234

optimizer:
@@ -50,15 +50,14 @@ model:
lr_scheduler:
lr: 2.0e-5
min_lr: 2.0e-6
lr_warmup_samples: 3076172
lr_warmup_samples: 10
lr_decay_style: cosine

data:
data_path: ${data_path:??}
data_path: ./build/data/pile_wikipedia_demo
split: 1
tokenizer:
tokenizer_type: AquilaTokenizerFS
vocab_file: ./examples/aquila/tokenizer/vocab.json
merge_file: ./examples/aquila/tokenizer/merges.txt
special_tokens_file: ./examples/aquila/tokenizer/special_tokens.txt
vocab_size: 100008
tokenizer_type: QwenTokenizerFS
tokenizer_path: examples/aquila/qwentokenizer
vocab_size: 151851
make_vocab_size_divisible_by: 64
6 changes: 1 addition & 5 deletions flagscale/train/hetero/parallel_context.py
@@ -866,11 +866,7 @@ def get_embedding_group(self):
"""Get the embedding group the caller rank belongs to."""
groups = self._process_groups.get("embd", None)
assert groups is not None, 'embedding group is not initialized'
for group in groups:
if self._rank in self._process_group_to_ranks[group]:
embd_group = group
break
return embd_group
return groups

def get_position_embedding_group(self):
"""Get the position embedding group the caller rank belongs to."""
13 changes: 11 additions & 2 deletions megatron/megatron/core/distributed/finalize_model_grads.py
@@ -18,9 +18,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
sync.
"""

embed_group = parallel_state.get_embedding_group()
if not isinstance(embed_group, list):
embed_group = [embed_group]
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
and torch.distributed.get_world_size(embed_group[0]) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
@@ -33,7 +36,13 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad = weight.main_grad
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
if len(embed_group) == 1:
torch.distributed.all_reduce(grad, group=embed_group[0])
else:
origin_grad = grad.data.clone()
for group in embed_group:
grad.data = origin_grad.clone()
torch.distributed.all_reduce(grad, group=group)


def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
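The word-embedding gradient all-reduce above shows the pattern this commit applies wherever get_embedding_group (and, in later files, get_model_parallel_group) may now return a list: clone the tensor once, restore it before every all-reduce, so the reductions over successive groups do not compound. A self-contained sketch of that pattern as a hypothetical helper, run here on a single-process gloo group so it works without a launcher (illustrative only, not code from this repository):

import os
import torch
import torch.distributed as dist

def allreduce_over_groups(tensor, groups):
    # Restore the pre-reduction values before each group's all-reduce so
    # earlier reductions do not feed into later ones.
    if not isinstance(groups, list):
        groups = [groups]
    original = tensor.data.clone()
    for group in groups:
        tensor.data = original.clone()
        dist.all_reduce(tensor, group=group)
    return tensor

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=0, world_size=1)
    grad = torch.ones(4)
    allreduce_over_groups(grad, [dist.group.WORLD])
    print(grad)  # unchanged here, since world_size == 1
    dist.destroy_process_group()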
@@ -98,9 +98,16 @@ def setup_embeddings_and_output_layer(self) -> None:
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
embedding_group = parallel_state.get_embedding_group()
if not isinstance(embedding_group, list):
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
else:
original_weight = weight.data.clone()
for group in embedding_group:
weight.data = original_weight.clone()
torch.distributed.all_reduce(weight.data, group=group)

elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
8 changes: 4 additions & 4 deletions megatron/megatron/core/optimizer/clip_grads.py
@@ -125,9 +125,9 @@ def get_grad_norm_fp32(
# For cpu comminication
tensor_device = get_device_type_for_comm(model_parallel_group)
if isinstance(model_parallel_group, list):
original_total_norm = total_norm
original_total_norm = total_norm.clone().detach()
for mp_group in model_parallel_group:
total_norm = original_total_norm
total_norm.data = original_total_norm.data.clone()
total_norm = total_norm.to(tensor_device)
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=mp_group
@@ -209,9 +209,9 @@ def count_zeros_fp32(

# Sum across all model-parallel GPUs.
if isinstance(model_parallel_group, list):
original_total_num_zeros = total_num_zeros
original_total_num_zeros = total_num_zeros.clone().detach()
for mp_group in model_parallel_group:
total_num_zeros = original_total_num_zeros
total_num_zeros.data = original_total_num_zeros.data.clone()
torch.distributed.all_reduce(
total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=mp_group
)
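The clip_grads changes above, like the ones in global_vars.py and utils.py below, replace a bare assignment with clone().detach() before looping over groups. The motivation, sketched with plain tensors (illustrative only, not code from this repository): assignment merely aliases the storage, so an in-place all-reduce into the working tensor would also overwrite the saved value that is meant to be restored for the next group.

import torch

norm = torch.tensor(3.0)
backup_alias = norm                  # old code: just another name for the same storage
norm += 5.0                          # stand-in for an in-place all-reduce
print(backup_alias)                  # tensor(8.) -- the "backup" was clobbered too

norm = torch.tensor(3.0)
backup_copy = norm.clone().detach()  # new code: independent storage, detached from autograd
norm += 5.0
print(backup_copy)                   # tensor(3.) -- safe to restore before the next group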
6 changes: 3 additions & 3 deletions megatron/megatron/training/global_vars.py
@@ -129,11 +129,11 @@ def set_global_writers(args):
ranks_tensor = torch.tensor([0 for _ in range(size)], dtype=torch.int, device=comm_device)
orig_ranks = torch.tensor([i for i in range(size)], dtype=torch.int, device=comm_device)
if is_last_rank():
ranks_tensor = orig_ranks
ranks_list = torch.distributed.get_process_group_ranks(mp_groups[-1])
ranks_tensor = torch.tensor(ranks_list, dtype=torch.int, device=comm_device)
ranks_tensor = torch.tensor(ranks_list, dtype=torch.int, device=comm_device)
orig_ranks = ranks_tensor.clone().detach()
for group in mp_groups:
ranks_tensor = orig_ranks
ranks_tensor = orig_ranks.clone()
torch.distributed.all_reduce(ranks_tensor, group=group)
if torch.distributed.get_rank() in ranks_tensor.tolist():
_set_wandb_writer(args)
4 changes: 2 additions & 2 deletions megatron/megatron/training/utils.py
@@ -92,9 +92,9 @@ def calc_params_l2_norm(model):
# Sum across all model-parallel GPUs(tensor + pipeline).
mp_groups = mpu.get_model_parallel_group()
if isinstance(mp_groups, list):
original_norm_2 = norm_2
original_norm_2 = norm_2.clone().detach()
for mp_group in mp_groups:
norm_2 = original_norm_2
norm_2 = original_norm_2.clone()
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mp_group)
