From 053fa0a3e670dc5b59136c70c67fb81540593c99 Mon Sep 17 00:00:00 2001
From: Apoorv Gupta
Date: Mon, 16 Dec 2024 04:05:01 +0000
Subject: [PATCH] Introduce regex-based checkpoint policy

Replace the hard-coded, backend-specific checkpoint name lists in
build_remat_spec with a generic save_only_these_regex_patterns policy in
axlearn.common.utils, and move the Neuron-specific remat configuration for
the 7B fuji model into its mesh rule via RematSpecModifier.
---
 axlearn/common/attention.py          | 36 +++++---------------
 axlearn/common/utils.py              | 18 +++++++++++++
 axlearn/experiments/text/gpt/fuji.py | 38 +++++++++++++++++++++++-----
 3 files changed, 56 insertions(+), 36 deletions(-)

diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 466453c32..159a5bb85 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -1259,8 +1259,8 @@ def apply_rotary_position_embeddings(
         Rotary position affined value embeddings with shape [batch_size, seq_len, num_heads, dim]
             if rotary_value == True, else original value embeddings
     """
-    # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
-    # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
+    # # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
+    # # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
     # sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
     # # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
     # sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
@@ -4131,21 +4131,6 @@ def forward(
 
     # TODO(sneha): extend_step
 
-def save_only_these(*names_to_save):
-    # Save all values, including unnamed ones, excluding the specified names.
-    names_to_save = frozenset(names_to_save)
-    def policy(prim, *_, **params):
-        if 'name' in params and params['name'] in names_to_save:
-            print(f"[WIP] Saving {params['name']}")
-            return True
-        elif 'name' in params:
-            print(f"[WIP] Not saving tensor: {params['name']}")
-            return False
-        else:
-            print("[WIP] Not saving unnamed tensor")
-            return False
-    return policy
-
 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
@@ -4180,24 +4165,15 @@ def build_remat_spec(
     if stack_cfg.klass is PipelinedTransformerLayer:
         return None
 
-    backend = jax.default_backend()
     checkpoints = []
     if self_attention:
         attention_name = stack_cfg.layer.self_attention.attention.klass.__name__
-        if backend != "neuron":
-            checkpoints.extend(
-                [f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
-            )
-        else:
-            checkpoints.extend(
-                [f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ["TransformerAttentionLayer.residual_add", "TransformerFeedForwardLayer.mlp_residual"]
-            )
+        checkpoints.extend(
+            [f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
+        )
 
     if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
         ffn_name = stack_cfg.layer.feed_forward.klass.__name__
-        if backend != "neuron":
-            checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])
-        else:
-            checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0", "linear1_1"]])
+        checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])
     policy = config_for_function(jax_remat_policies.save_only_these_names).set(
         names_which_can_be_saved=checkpoints
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 1e141318e..bf330ad5d 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -23,6 +23,7 @@
 import types
 from collections.abc import Mapping, Sequence
 from enum import Enum
+import re
 from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
 
 import jax
@@ -1464,3 +1465,20 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
                 f"Input is expected to contain '{path}'; "
                 f"instead, it contains: '{jax.tree_structure(x)}'."
             ) from e
+
+
+def save_only_these_regex_patterns(*names_to_save):
+    """Returns a checkpoint policy that saves only values whose checkpoint names match
+    at least one of the given regex patterns. Unnamed values are never saved.
+    """
+    patterns = tuple(re.compile(p) for p in names_to_save)
+
+    def policy(prim, *_, **params):
+        del prim
+        name = params.get("name")
+        if name is None:
+            return False
+        # Save a named value iff its name matches at least one of the given patterns.
+        return any(p.search(name) for p in patterns)
+
+    return policy
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 09e1ef8ea..eed9964ea 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -43,7 +43,7 @@
     MeshShapeModifier,
     RematSpecModifier,
 )
-from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies
+from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies, save_only_these_regex_patterns
 from axlearn.experiments.text.gpt.common import (
     STEP_DTYPE,
     SourceBuilder,
@@ -184,7 +184,7 @@ def get_trainer_kwargs(
                     "neuron-(trn2|trn2n).48xlarge-64",
                     mesh_shape_from_axes(fsdp=-1, model=4),
                 ),
-            )
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -208,7 +208,7 @@ def get_trainer_kwargs(
                     "neuron-(trn2|trn2n).48xlarge-64",
                     mesh_shape_from_axes(fsdp=-1, model=4),
                 ),
-            )
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -450,7 +450,31 @@ def get_trainer_kwargs(
                 ),
                 (
                     "neuron-(trn2|trn2n).48xlarge-64",
-                    mesh_shape_from_axes(fsdp=-1, model=4),
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(save_only_these_regex_patterns).set(
+                                            names_to_save=[
+                                                r"\.q_proj$",
+                                                r"\.k_proj$",
+                                                r"\.v_proj$",
+                                                r"TransformerAttentionLayer\.residual_add$",
+                                                r"TransformerFeedForwardLayer\.mlp_residual$",
+                                                r"\.linear1_0$",
+                                                r"\.linear1_1$",
+                                            ]
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
                 ),
             ),
         )
@@ -458,7 +482,9 @@ def get_trainer_kwargs(
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
-    model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
+    model_kwargs.setdefault(
+        "stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config()
+    )
     trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
     trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
         max_step=trainer_kwargs["max_step"],
@@ -511,7 +537,7 @@ def model_config(
         atten_cfg = GroupedQueryAttention.default_config()
         backend = jax.default_backend()
 
-        qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear 
+        qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
         atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads)
     else:
         atten_cfg = MultiheadAttention.default_config()
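-- 
Illustrative sketch (not part of the patch): one way to sanity-check
save_only_these_regex_patterns with jax.checkpoint, assuming the patch is applied so the
helper is importable from axlearn.common.utils. The tensor names below are hypothetical
examples of the "{ModuleClass}.{output}" naming convention used by build_remat_spec.

    import functools

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    from axlearn.common.utils import save_only_these_regex_patterns

    # Save every module's q_proj output; all other values are rematerialized.
    policy = save_only_these_regex_patterns(r"\.q_proj$")

    @functools.partial(jax.checkpoint, policy=policy)
    def f(x):
        q = checkpoint_name(x * 2.0, name="GroupedQueryAttention.q_proj")
        ctx = checkpoint_name(jnp.sin(q), name="GroupedQueryAttention.context")
        return ctx.sum()

    # q_proj is kept for the backward pass; context is recomputed.
    print(jax.grad(f)(jnp.ones((4,), dtype=jnp.float32)))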