From b7c0c5e9bbf9928227bedb2d1e9adb64480c015b Mon Sep 17 00:00:00 2001
From: Apoorv Gupta
Date: Mon, 6 Jan 2025 22:50:31 +0000
Subject: [PATCH] Enable special remat for Neuron

---
 axlearn/common/attention.py            | 12 +++-
 axlearn/common/attention_test.py       | 72 +++++++++++++++++--
 .../fuji-70B-v1-flash.txt              | 16 +++++
 .../fuji-70B-v1.txt                    | 16 +++++
 .../fuji-70B-v2-flash.txt              | 16 +++++
 .../fuji-70B-v2.txt                    | 16 +++++
 .../fuji-70B-v3-flash.txt              | 16 +++++
 .../fuji-70B-v3.txt                    | 16 +++++
 axlearn/experiments/text/gpt/common.py | 11 +--
 axlearn/experiments/text/gpt/fuji.py   | 40 +++++++++--
 10 files changed, 211 insertions(+), 20 deletions(-)

diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 4c61e4712..e02edb4b5 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -4001,15 +4001,21 @@ def _save_and_offload_only_these_names_regex(
     )
 
 
-SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
-FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
+# Regex patterns for matching remat names.
+class RematRegexSavePatterns(enum.Enum):
+    QKV_PROJ = r".*[kqv]_proj"
+    O_PROJ = r".*o_proj"
+    CONTEXT = r".*context"
+    LINEAR1_X = r".*linear1_[01]"
+    SELF_ATTENTION = r".*([qkvo]_proj|context)"
+    FEED_FORWARD = r".*linear[12]_.*"
 
 
 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: SavePattern = SELF_ATTENTION_SAVE_PATTERN,
+    save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
     offload_pattern: SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py
index a8555ec4c..591558528 100644
--- a/axlearn/common/attention_test.py
+++ b/axlearn/common/attention_test.py
@@ -41,7 +41,6 @@
 
 from axlearn.common import attention, attention_bias, test_utils, utils
 from axlearn.common.attention import (
-    FEED_FORWARD_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     BaseTransformerLayer,
     BottleNeckAdapterTransformerLayer,
@@ -58,6 +57,7 @@
     PipelinedTransformerLayer,
     QKVLinear,
     QLinear,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
     StackedTransformerLayer,
@@ -65,7 +65,6 @@
     TransformerFeedForwardLayer,
     TransformerLayer,
     _next_power_of_two,
-    _save_and_offload_only_these_names_regex,
     apply_attention_logit_biases,
     apply_rotary_position_embeddings,
     build_remat_spec,
@@ -125,6 +124,7 @@
     as_tensor,
     flatten_items,
+    save_and_offload_only_these_names_regex,
     shapes,
 )
 
 
@@ -3445,8 +3445,8 @@ def f(x, layer_params):
         _, save_name_backward = jax.linearize(
             jax.remat(
                 f,
-                policy=_save_and_offload_only_these_names_regex(
-                    names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
+                policy=save_and_offload_only_these_names_regex(
+                    names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
                     names_which_can_be_offloaded=None,
                     offload_src="device",
                     offload_dst="pinned_host",
@@ -3901,6 +3901,70 @@ def f(x, layer_params):
             5,
         )
 
+    def test_build_remat_spec_neuron(self):
+        model_dim, num_heads = 6, 2
+        cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
+        cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
+        cfg.feed_forward.hidden_dim = model_dim * 4
+        cfg.vlog = 5
+
+        layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+
+        batch_size, tgt_len = 2, 5
+        rng = np.random.default_rng(seed=123)
+        target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)
+
+        def f(x, layer_params):
+            forward_outputs, _ = F(
+                layer,
+                inputs=dict(
+                    data=x,
+                ),
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(0),
+            )
+            return forward_outputs
+
+        # Ignore type errors.
+        spec: Any = build_remat_spec(mock.MagicMock())
+
+        policy = (
+            config_for_function(save_and_offload_only_these_names_regex)
+            .set(
+                names_which_can_be_saved="|".join(
+                    [
+                        RematRegexSavePatterns.QKV_PROJ.value,
+                        RematRegexSavePatterns.LINEAR1_X.value,
+                    ]
+                ),
+                names_which_can_be_offloaded=None,
+                offload_src=None,
+                offload_dst=None,
+            )
+            .instantiate()
+        )
+
+        _, default_policy_backward = jax.linearize(
+            jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
+            jnp.asarray(target),
+            layer_params,
+        )
+        _, full_remat_backward = jax.linearize(
+            jax.remat(f),
+            jnp.asarray(target),
+            layer_params,
+        )
+
+        # Saving q/k/v_proj and linear1_0 instead of rematerializing them eliminates
+        # 4 dot_generals from the backward pass. This assumes FlashAttention is not enabled.
+        self.assertEqual(
+            str(full_remat_backward).count(" dot_general")
+            - str(default_policy_backward).count(" dot_general"),
+            4,
+        )
+
 
 class TestStackModel(BaseLayer):
     """A dummy transformer stack."""
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
index 278d72e61..3052dab86 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
index 6ca6030fe..d57c35aea 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
index 8db672f15..605d5f326 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
index 422e651cf..ccf575c40 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
index e2fd015c3..e3f269bfa 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
index 829ba2d34..b7457c951 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
@@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
 mesh_rules[3][1][3]: 128
 mesh_rules[3][1][4]: 1
 mesh_rules[3][1][5]: 1
+mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb9234..c01789a39 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -34,6 +34,7 @@
     BaseQKVLinear,
     MultiheadAttention,
     RepeatedTransformerLayer,
+    StackedTransformerLayer,
     TransformerLayer,
     build_remat_spec,
     set_double_shard_weights_config,
@@ -190,20 +191,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.
 
-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.
 
-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )
     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)
@@ -277,7 +270,7 @@ def model_config(
     layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
+    if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
         update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 6cd498143..001a873e3 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -19,12 +19,12 @@
 
 from axlearn.common import causal_lm, config
 from axlearn.common.attention import (
-    SELF_ATTENTION_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
 )
@@ -40,7 +40,10 @@
     MeshShapeModifier,
     RematSpecModifier,
 )
-from axlearn.common.utils import extended_checkpoint_policies
+from axlearn.common.utils import (
+    extended_checkpoint_policies,
+    save_and_offload_only_these_names_regex,
+)
 from axlearn.experiments.text.gpt.common import (
     STEP_DTYPE,
     SourceBuilder,
@@ -86,7 +89,6 @@ class Version(enum.Enum):
     Version.V3: 5e5,
 }
 
-
 # Mapping from Fuji versions to total number of tokens used in training.
 TOTAL_TOKENS = {
     Version.V1: {
@@ -147,7 +149,7 @@ def get_trainer_kwargs(
         extended_checkpoint_policies.save_and_offload_only_these_names_regex
     ).set(
         names_which_can_be_saved=None,
-        names_which_can_be_offloaded=SELF_ATTENTION_SAVE_PATTERN,
+        names_which_can_be_offloaded=RematRegexSavePatterns.SELF_ATTENTION.value,
         offload_src="device",
         offload_dst="pinned_host",
     )
@@ -492,6 +494,36 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved="|".join(
+                                                [
+                                                    RematRegexSavePatterns.QKV_PROJ.value,
+                                                    RematRegexSavePatterns.LINEAR1_X.value,
+                                                ]
+                                            ),
+                                            names_which_can_be_offloaded=None,
+                                            offload_src=None,
+                                            offload_dst=None,
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
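
Note for reviewers unfamiliar with name-based remat: the policy added above keys off activations tagged with jax.ad_checkpoint.checkpoint_name(). The sketch below illustrates the mechanism with stock JAX only; it substitutes the exact-name policy jax.checkpoint_policies.save_only_these_names for axlearn's regex-based save_and_offload_only_these_names_regex, and the toy layer, shapes, and tag names are illustrative rather than part of this patch.

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def layer(x, wq, w1):
        # Tag activations the way axlearn tags the q/k/v projections and linear1_0.
        q = checkpoint_name(x @ wq, "q_proj")
        h = checkpoint_name(jax.nn.relu(q @ w1), "linear1_0")
        return jnp.sum(h**2)

    # Save only the tagged activations; everything else is recomputed in the
    # backward pass. axlearn's policy generalizes this to regex patterns such as
    # '.*[kqv]_proj|.*linear1_[01]'.
    policy = jax.checkpoint_policies.save_only_these_names("q_proj", "linear1_0")
    grad_fn = jax.grad(jax.checkpoint(layer, policy=policy), argnums=(1, 2))

    dwq, dw1 = grad_fn(jnp.ones((2, 8)), jnp.ones((8, 8)), jnp.ones((8, 16)))

Because the q_proj and linear1_0 outputs are saved rather than rematerialized, their matmuls are not replayed when computing gradients; test_build_remat_spec_neuron asserts the same effect on the real layer as a difference of 4 dot_general ops in the backward pass.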