From 2ce84ed163f5633f8036c27c22c8a4ae1deb6357 Mon Sep 17 00:00:00 2001
From: Eric Yu
Date: Thu, 2 Jan 2025 22:10:39 +0000
Subject: [PATCH] Delete duplicate stack_cfg config and setdefault
 input_partition_type

---
 axlearn/experiments/text/gpt/fuji.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index cc6ecdf5..5bbfff26 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -176,12 +176,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -202,12 +200,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -227,12 +223,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -334,12 +328,10 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=False,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -426,7 +418,6 @@ def get_trainer_kwargs(
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(fsdp=-1),
             mesh_rules=(
@@ -474,6 +465,7 @@ def get_trainer_kwargs(
     trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
         max_step=trainer_kwargs["max_step"],
         **trainer_kwargs.pop("learner_kwargs"),
     )
+    trainer_kwargs.setdefault("input_partition_type", None if backend != "neuron" else DataPartitionType.BATCH)
     # pylint: enable=use-dict-literal
     return trainer_kwargs
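
Note on the change: each model-size branch previously set stack_cfg and
input_partition_type with the same backend conditional; the patch drops the
duplicated per-branch settings and keeps a single dict.setdefault call after
the branches instead. Below is a minimal standalone sketch of that setdefault
behavior, assuming nothing from axlearn: DataPartitionType here is a stand-in
enum and finalize is a hypothetical helper, not the library's definitions.

    import enum

    class DataPartitionType(enum.Enum):
        FULL = "full"
        BATCH = "batch"

    def finalize(trainer_kwargs: dict, backend: str) -> dict:
        # setdefault only fills the key when it is absent, so any branch that
        # still sets input_partition_type explicitly wins over this default.
        trainer_kwargs.setdefault(
            "input_partition_type",
            None if backend != "neuron" else DataPartitionType.BATCH,
        )
        return trainer_kwargs

    print(finalize({}, "neuron"))
    # {'input_partition_type': <DataPartitionType.BATCH: 'batch'>}
    print(finalize({}, "tpu"))
    # {'input_partition_type': None}

This is why setdefault, rather than plain assignment, is the right call site
for a cross-branch default: it centralizes the backend conditional without
clobbering a value a specific configuration chose to set.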