
Commit

Introduce regex-based checkpoint policy
apoorvtintin committed Dec 16, 2024
1 parent 9e130a8 commit 053fa0a
Showing 3 changed files with 56 additions and 36 deletions.
36 changes: 6 additions & 30 deletions axlearn/common/attention.py
@@ -1259,8 +1259,8 @@ def apply_rotary_position_embeddings(
Rotary position affined value embeddings with shape [batch_size, seq_len, num_heads, dim]
if rotary_value == True, else original value embeddings
"""
# sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
@@ -4131,21 +4131,6 @@ def forward(

# TODO(sneha): extend_step

def save_only_these(*names_to_save):
# Save all values, including unnamed ones, excluding the specified names.
names_to_save = frozenset(names_to_save)
def policy(prim, *_, **params):
if 'name' in params and params['name'] in names_to_save:
print(f"[WIP] Saving {params['name']}")
return True
elif 'name' in params:
print(f"[WIP] Not saving tensor: {params['name']}")
return False
else:
print("[WIP] Not saving unnamed tensor")
return False
return policy

def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
@@ -4180,24 +4165,15 @@ def build_remat_spec(
if stack_cfg.klass is PipelinedTransformerLayer:
return None

backend = jax.default_backend()
checkpoints = []
if self_attention:
attention_name = stack_cfg.layer.self_attention.attention.klass.__name__
if backend != "neuron":
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)
else:
checkpoints.extend(
[f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ["TransformerAttentionLayer.residual_add", "TransformerFeedForwardLayer.mlp_residual"]
)
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)
if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
if backend != "neuron":
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])
else:
checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0", "linear1_1"]])
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])

policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
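For context on what this hunk feeds into jax_remat_policies.save_only_these_names: the checkpoint names are exact, fully qualified strings. A minimal illustration (not part of the diff), assuming the stack uses GroupedQueryAttention for attention and TransformerFeedForwardLayer for the MLP, the resulting checkpoints list would be:

checkpoints = [
    "GroupedQueryAttention.q_proj",
    "GroupedQueryAttention.k_proj",
    "GroupedQueryAttention.v_proj",
    "GroupedQueryAttention.context",
    "GroupedQueryAttention.o_proj",
    "TransformerFeedForwardLayer.activation",
    "TransformerFeedForwardLayer.linear2",
]

save_only_these_names matches these strings exactly; that restriction is what the regex-based policy introduced below relaxes.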
18 changes: 18 additions & 0 deletions axlearn/common/utils.py
@@ -23,6 +23,7 @@
import types
from collections.abc import Mapping, Sequence
from enum import Enum
import re
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union

import jax
@@ -1464,3 +1465,20 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
f"Input is expected to contain '{path}'; "
f"instead, it contains: '{jax.tree_structure(x)}'."
) from e

def save_only_these_regex_patterns(*names_to_save):
    """Returns a checkpoint policy that saves only named values matching any of the given regexes.

    All other values, including unnamed ones, are rematerialized.
    """
    names_to_save = frozenset(names_to_save)

    def policy(prim, *_, **params):
        del prim  # Unused; the decision is based purely on the checkpoint name.
        if "name" in params:
            # Save the value iff its name matches any of the specified regex patterns.
            return any(re.search(pattern, params["name"]) for pattern in names_to_save)
        # Unnamed values are not saved.
        return False

    return policy
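A minimal usage sketch for the new policy (not part of this commit; the layer names, shapes, and tagging below are illustrative assumptions rather than axlearn's actual wiring): values are tagged with jax.ad_checkpoint.checkpoint_name inside a checkpointed function, and only names matching one of the regex patterns are saved during remat.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical tensor names; in axlearn the tags come from the layers themselves.
policy = save_only_these_regex_patterns(r".*\.q_proj$", r".*\.k_proj$", r".*\.v_proj$")

def layer(x, w):
    q = checkpoint_name(x @ w, name="GroupedQueryAttention.q_proj")  # matches -> saved
    h = checkpoint_name(jnp.tanh(q), name="TransformerFeedForwardLayer.activation")  # no match -> rematerialized
    return h.sum()

grads = jax.grad(jax.checkpoint(layer, policy=policy))(jnp.ones((8, 8)), jnp.ones((8, 8)))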
38 changes: 32 additions & 6 deletions axlearn/experiments/text/gpt/fuji.py
@@ -43,7 +43,7 @@
MeshShapeModifier,
RematSpecModifier,
)
from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies
from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies, save_only_these_regex_patterns
from axlearn.experiments.text.gpt.common import (
STEP_DTYPE,
SourceBuilder,
@@ -184,7 +184,7 @@ def get_trainer_kwargs(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
),
)
),
)
elif model_size == "3B":
trainer_kwargs = dict(
@@ -208,7 +208,7 @@ def get_trainer_kwargs(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
),
)
),
)
elif model_size == "7B":
trainer_kwargs = dict(
@@ -450,15 +450,41 @@ def get_trainer_kwargs(
),
(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(save_only_these_regex_patterns).set(
names_to_save=[
"^\.?(q_proj)?.*$",
"^\.?(k_pro)?.*$j",
"^\.?(v_proj)?.*$",
"^\.?(TransformerAttentionLayer.residual_add)?.*$",
"^\.?(TransformerFeedForwardLayer.mlp_residual)?.*$",
"^\.?(linear1_0)?.*$",
"^\.?(linear1_1)?.*$",
]
),
),
}
),
],
),
),
),
)
else:
raise NotImplementedError(f"Unknown model size {model_size}.")
model_kwargs = trainer_kwargs.pop("model_kwargs")
model_kwargs.setdefault("vocab_size", vocab_size)
model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
model_kwargs.setdefault(
"stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config()
)
trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
max_step=trainer_kwargs["max_step"],
@@ -511,7 +537,7 @@ def model_config(
atten_cfg = GroupedQueryAttention.default_config()
backend = jax.default_backend()

qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads)
else:
atten_cfg = MultiheadAttention.default_config()
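As a quick sanity check on the Neuron remat policy configured above (a sketch, not part of the commit; the tensor names are taken from the hard-coded Neuron branch this commit removes from build_remat_spec), the patterns are meant to save the QKV projections, the attention and MLP residuals, and the linear1_0/linear1_1 projections, while everything else is rematerialized:

import re

patterns = [
    r".*\.q_proj$", r".*\.k_proj$", r".*\.v_proj$",
    r"TransformerAttentionLayer\.residual_add",
    r"TransformerFeedForwardLayer\.mlp_residual",
    r".*\.linear1_0$", r".*\.linear1_1$",
]
for name in [
    "GroupedQueryAttention.q_proj",           # save
    "TransformerFeedForwardLayer.linear1_0",  # save
    "TransformerFeedForwardLayer.linear2",    # remat
    "GroupedQueryAttention.context",          # remat
]:
    print(name, "->", "save" if any(re.search(p, name) for p in patterns) else "remat")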
