diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index cc6ecdf56..5bbfff261 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -176,12 +176,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -202,12 +200,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -227,12 +223,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -334,12 +328,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=False,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -426,7 +418,6 @@ def get_trainer_kwargs(
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(fsdp=-1),
             mesh_rules=(
@@ -474,6 +465,7 @@ def get_trainer_kwargs(
         max_step=trainer_kwargs["max_step"],
         **trainer_kwargs.pop("learner_kwargs"),
     )
+    trainer_kwargs.setdefault("input_partition_type", None if backend != "neuron" else DataPartitionType.BATCH)
 
     # pylint: enable=use-dict-literal
     return trainer_kwargs
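
For context, a minimal sketch of the pattern this patch applies: instead of repeating the Neuron-specific `input_partition_type` conditional in every per-model-size `trainer_kwargs` branch, the default is applied once via `dict.setdefault` after all branches. This is a simplified illustration, not the real `get_trainer_kwargs`; the branch bodies are stubbed down to a couple of keys, and `DataPartitionType` is stubbed as a plain enum standing in for axlearn's type.

```python
from enum import Enum


class DataPartitionType(Enum):
    # Stub standing in for axlearn's DataPartitionType; only the member used here.
    BATCH = "batch"


def get_trainer_kwargs(model_size: str, *, backend: str) -> dict:
    """Simplified sketch: per-size branches no longer set input_partition_type."""
    if model_size == "7B":
        trainer_kwargs = dict(peak_lr=3e-4, weight_decay=0.1)
    else:
        trainer_kwargs = dict(peak_lr=1.5e-4, weight_decay=0.1)
    # One shared default replaces the five per-branch conditionals removed above.
    # setdefault only fills the key if a branch did not already set it.
    trainer_kwargs.setdefault(
        "input_partition_type", None if backend != "neuron" else DataPartitionType.BATCH
    )
    return trainer_kwargs


assert get_trainer_kwargs("7B", backend="neuron")["input_partition_type"] is DataPartitionType.BATCH
assert get_trainer_kwargs("7B", backend="tpu")["input_partition_type"] is None
```

Using `setdefault` rather than plain assignment keeps the door open for a future branch to set `input_partition_type` explicitly and still win over the shared default.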