neuron changes for 1B,3B,8B models
aws-mengchiy committed Jan 2, 2025
1 parent f4a68f9 commit 0f8a6ab
Showing 1 changed file with 10 additions and 0 deletions.
axlearn/experiments/text/gpt/fuji.py: 10 additions & 0 deletions
@@ -99,7 +99,10 @@ class Version(enum.Enum):
},
Version.V2: {
"test": 2 * (1024**4), # 2T tokens
"1B": 2 * (1024**4), # 2T tokens
"3B": 2 * (1024**4), # 2T tokens
"7B": 2 * (1024**4), # 2T tokens
"8B": 2 * (1024**4), # 2T tokens
"70B": 2 * (1024**4), # 2T tokens
},
Version.V3: {
@@ -173,10 +176,12 @@ def get_trainer_kwargs(
rope_theta=rope_theta,
shared_lm_head=True,
flash_attention=flash_attention,
+ stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
+ input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
@@ -197,10 +202,12 @@ def get_trainer_kwargs(
rope_theta=rope_theta,
shared_lm_head=True,
flash_attention=flash_attention,
+ stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
+ input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
@@ -220,6 +227,7 @@ def get_trainer_kwargs(
rope_theta=rope_theta,
shared_lm_head=True,
flash_attention=flash_attention,
+ stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
@@ -326,10 +334,12 @@ def get_trainer_kwargs(
rope_theta=rope_theta,
shared_lm_head=False,
flash_attention=flash_attention,
+ stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
+ input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
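Note: every hunk in get_trainer_kwargs adds the same two backend-conditional keyword arguments, once per model size. The sketch below only isolates that pattern for readability; it is not part of the commit, the helper name neuron_trainer_overrides is hypothetical, and the import paths are an assumption (inside fuji.py these names are already in scope).

# Sketch only, not axlearn source: the backend-conditional pattern repeated in each hunk above.
# Import paths are assumed here so the snippet is self-contained.
from axlearn.common.attention import StackedTransformerLayer
from axlearn.common.utils import DataPartitionType


def neuron_trainer_overrides(backend: str) -> dict:
    """Hypothetical helper collecting the two kwargs this commit threads into each model size."""
    return dict(
        # On the "neuron" backend, pin the transformer stack to an explicit
        # StackedTransformerLayer config; every other backend keeps the default (None).
        stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
        # Likewise, switch trainer input partitioning to batch-wise sharding only on "neuron".
        input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
    )

In the commit itself the two lines are written inline in each get_trainer_kwargs branch rather than factored out, and the new 1B/3B/8B entries in the Version.V2 token table reuse the same 2 * (1024**4) budget (about 2.2e12, i.e. ~2T tokens) as the existing 7B and 70B configs.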
