Delete duplicate stack_cfg config and setdefault input_partition_type
aws-mengchiy committed Jan 2, 2025
1 parent 0f8a6ab commit 2ce84ed
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions axlearn/experiments/text/gpt/fuji.py
@@ -176,12 +176,10 @@ def get_trainer_kwargs(
 rope_theta=rope_theta,
 shared_lm_head=True,
 flash_attention=flash_attention,
-stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
 ),
 learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
 max_sequence_length=max_sequence_length,
 train_batch_size=train_batch_size,
-input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
 max_step=max_step,
 mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
 mesh_rules=(
@@ -202,12 +200,10 @@ def get_trainer_kwargs(
 rope_theta=rope_theta,
 shared_lm_head=True,
 flash_attention=flash_attention,
-stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
 ),
 learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
 max_sequence_length=max_sequence_length,
 train_batch_size=train_batch_size,
-input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
 max_step=max_step,
 mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
 mesh_rules=(
@@ -227,12 +223,10 @@ def get_trainer_kwargs(
 rope_theta=rope_theta,
 shared_lm_head=True,
 flash_attention=flash_attention,
-stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
 ),
 learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
 max_sequence_length=max_sequence_length,
 train_batch_size=train_batch_size,
-input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
 max_step=max_step,
 mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
 mesh_rules=(
@@ -334,12 +328,10 @@ def get_trainer_kwargs(
 rope_theta=rope_theta,
 shared_lm_head=False,
 flash_attention=flash_attention,
-stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
 ),
 learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
 max_sequence_length=max_sequence_length,
 train_batch_size=train_batch_size,
-input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
 max_step=max_step,
 mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
 mesh_rules=(
@@ -426,7 +418,6 @@ def get_trainer_kwargs(
 learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
 max_sequence_length=max_sequence_length,
 train_batch_size=train_batch_size,
-input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
 max_step=max_step,
 mesh_shape=mesh_shape_from_axes(fsdp=-1),
 mesh_rules=(
@@ -474,6 +465,7 @@ def get_trainer_kwargs(
 max_step=trainer_kwargs["max_step"],
 **trainer_kwargs.pop("learner_kwargs"),
 )
+trainer_kwargs.setdefault("input_partition_type", None if backend != "neuron" else DataPartitionType.BATCH)
 # pylint: enable=use-dict-literal
 return trainer_kwargs

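To make the intent of the change concrete, below is a minimal, self-contained sketch of the pattern the commit applies: instead of repeating the same backend-conditional `input_partition_type` line in every per-model-size branch, a single `dict.setdefault` call applies the default once after the branches. The function body, model sizes, and the `DataPartitionType` stand-in here are simplified, hypothetical illustrations and are not the actual fuji.py code.

from enum import Enum


class DataPartitionType(Enum):
    # Hypothetical stand-in for axlearn's DataPartitionType.
    FULL = "full"
    BATCH = "batch"


def get_trainer_kwargs(model_size: str, backend: str) -> dict:
    # Each branch builds its own kwargs; none of them needs to mention
    # input_partition_type anymore.
    if model_size == "7B":
        trainer_kwargs = dict(train_batch_size=1024, max_step=250_000)
    else:
        trainer_kwargs = dict(train_batch_size=256, max_step=100_000)

    # Applied once for every model size: only the neuron backend gets BATCH
    # partitioning, and any value a branch set explicitly is left untouched.
    trainer_kwargs.setdefault(
        "input_partition_type",
        None if backend != "neuron" else DataPartitionType.BATCH,
    )
    return trainer_kwargs


print(get_trainer_kwargs("7B", "neuron")["input_partition_type"])  # DataPartitionType.BATCH
print(get_trainer_kwargs("7B", "tpu")["input_partition_type"])     # None

Using `setdefault` rather than direct assignment keeps the consolidation non-destructive: a future branch can still override the value explicitly and the shared default will not clobber it.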