From 0f8a6ab2c7dfab610607eebae1a4907ae9fe21c6 Mon Sep 17 00:00:00 2001
From: Eric Yu
Date: Thu, 2 Jan 2025 19:19:38 +0000
Subject: [PATCH] neuron changes for 1B,3B,8B models

---
 axlearn/experiments/text/gpt/fuji.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 09e1ef8ea..cc6ecdf56 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -99,7 +99,10 @@ class Version(enum.Enum):
     },
     Version.V2: {
         "test": 2 * (1024**4),  # 2T tokens
+        "1B": 2 * (1024**4),  # 2T tokens
+        "3B": 2 * (1024**4),  # 2T tokens
         "7B": 2 * (1024**4),  # 2T tokens
+        "8B": 2 * (1024**4),  # 2T tokens
         "70B": 2 * (1024**4),  # 2T tokens
     },
     Version.V3: {
@@ -173,10 +176,12 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
+                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
+            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -197,10 +202,12 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
+                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
+            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -220,6 +227,7 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
+                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
@@ -326,10 +334,12 @@ def get_trainer_kwargs(
                 rope_theta=rope_theta,
                 shared_lm_head=False,
                 flash_attention=flash_attention,
+                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
+            input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
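
Note on the pattern used above: every touched model size applies the same backend-conditional override, setting stack_cfg to StackedTransformerLayer.default_config() and input_partition_type to DataPartitionType.BATCH only when the backend is "neuron", and leaving both as None (existing defaults) otherwise. The sketch below illustrates that pattern in isolation; the helper name neuron_overrides and the plain-dict stand-ins for the axlearn config objects are hypothetical, kept free of axlearn imports so the example is self-contained.

    from typing import Any, Dict, Optional

    # Hypothetical stand-ins for StackedTransformerLayer.default_config() and
    # DataPartitionType.BATCH; the patch itself passes the real axlearn objects.
    STACKED_TRANSFORMER_CFG = {"klass": "StackedTransformerLayer"}
    BATCH_PARTITION = "batch"


    def neuron_overrides(backend: str) -> Dict[str, Optional[Any]]:
        """Returns backend-conditional trainer overrides.

        On the "neuron" backend, select an explicit StackedTransformerLayer
        stack and batch-wise input partitioning; on every other backend,
        return None for both so the existing defaults apply unchanged.
        """
        is_neuron = backend == "neuron"
        return {
            "stack_cfg": STACKED_TRANSFORMER_CFG if is_neuron else None,
            "input_partition_type": BATCH_PARTITION if is_neuron else None,
        }


    if __name__ == "__main__":
        # GPU/TPU backends are untouched; neuron gets both overrides.
        assert neuron_overrides("tpu") == {"stack_cfg": None, "input_partition_type": None}
        assert neuron_overrides("neuron")["input_partition_type"] == BATCH_PARTITION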