
Commit

boilerplate
apoorvtintin committed Jan 9, 2025
1 parent c40b39a commit ef4d198
Showing 3 changed files with 212 additions and 3 deletions.
58 changes: 58 additions & 0 deletions axlearn/common/trainer_config_modifier.py
@@ -11,6 +11,7 @@
ConfigModifier,
ConfigOr,
Required,
ConfigBase,
config_class,
maybe_instantiate,
)
@@ -146,6 +147,63 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
return cfg


class ModelConfigModifier(ConfigModifier):
"""Update the model config for the trainer config."""

@config_class
class Config(ConfigModifier.Config):
"""Configure ModelConfigModifier.
Attributes:
model_cfg_modifications: A mapping from module path
(e.g. `model.decoder.transformer.layer`) to a Config.
"""

model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._model_cfg_modifications = cfg.model_cfg_modifications

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the mesh shape.
Args:
cfg: The trainer config to be modified.
Raises:
ValueError: The target module is not found.
Returns:
The modified trainer config.
"""

for module_name, model_cfg in self._model_cfg_modifications.items():
if not model_cfg:
continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
curr_module = cfg
parent_module = None

for target_module in target_modules:
if not hasattr(curr_module, target_module):
raise ValueError(f"{target_module} is not found in {curr_module}.")
parent_module = curr_module
curr_module = getattr(curr_module, target_module)

# Copy configurations from the config being replaced, on a best-effort basis.
for key in model_cfg.keys():
if key == "klass":
continue
elif hasattr(curr_module, key):
setattr(model_cfg, key, getattr(curr_module, key))
# Replace in the parent config
setattr(parent_module, target_module, model_cfg)
return cfg

class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

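For reference, a minimal usage sketch of the new ModelConfigModifier (it mirrors the unit test added below and uses only modules imported in this commit); each dotted key is a path into the trainer config, and the config found at that path is replaced:

from axlearn.common import causal_lm
from axlearn.common.attention import RepeatedTransformerLayer
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import ModelConfigModifier

# Start from a trainer config whose model is a causal LM.
cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())

# Swap the decoder's transformer stack; matching attributes of the old config are
# copied onto the replacement config on a best-effort basis.
cfg_modifier = ModelConfigModifier.default_config().set(
    model_cfg_modifications={
        "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
    }
).instantiate()
cfg = cfg_modifier(cfg)  # cfg.model.decoder.transformer is now a RepeatedTransformerLayer config.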
25 changes: 24 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,15 +5,17 @@
import jax
from absl.testing import absltest

from axlearn.common import test_utils
from axlearn.common import causal_lm, test_utils
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
ChainConfigModifier,
GradientAccumulationModifier,
MeshShapeModifier,
ModelConfigModifier,
RematSpecModifier,
)
from axlearn.common.attention import RepeatedTransformerLayer
from axlearn.common.trainer_test import DummyModel


@@ -65,6 +67,27 @@ def test_remat_policy_override(self):
_ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
def test_model_config_override(self):
cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
self.assertRegex(str(cfg.model.decoder), ".*StackedTransformerLayer")

cfg_modifier = (
ModelConfigModifier.default_config()
.set(
model_cfg_modifications={
"model.decoder.transformer": RepeatedTransformerLayer.default_config(),
}
)
.instantiate()
)

cfg = cfg_modifier(cfg)
# The default StackedTransformerLayer should have been replaced by RepeatedTransformerLayer.
self.assertRegex(str(cfg.model.decoder), ".*RepeatedTransformerLayer")


class MeshShapeModifierTest(test_utils.TestCase):
def test_mesh_shape_update(self):
cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
132 changes: 130 additions & 2 deletions axlearn/experiments/text/gpt/fuji.py
@@ -22,9 +22,11 @@
BaseStackedTransformerLayer,
FusedGroupedQKVLinear,
FusedQKVLinear,
GroupedQKVLinear,
GroupedQueryAttention,
MultiheadAttention,
RepeatedTransformerLayer,
RoFormerQKVLinear,
StackedTransformerLayer,
)
from axlearn.common.base_layer import RematSpec
@@ -38,6 +40,7 @@
GradientAccumulationModifier,
MeshShapeModifier,
ModelConfigModifier,
RematSpecModifier,
)
from axlearn.common.utils import extended_checkpoint_policies
from axlearn.experiments.text.gpt.common import (
@@ -174,6 +177,28 @@ def get_trainer_kwargs(
train_batch_size=train_batch_size,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(),
"model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
None
if version == Version.V1
else GroupedQKVLinear.default_config()
),
}
),
],
),
),
),
)
elif model_size == "3B":
trainer_kwargs = dict(
@@ -192,6 +217,28 @@
train_batch_size=train_batch_size,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(),
"model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
None
if version == Version.V1
else GroupedQKVLinear.default_config()
),
}
),
],
),
),
),
)
elif model_size == "7B":
trainer_kwargs = dict(
@@ -287,6 +334,46 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(),
"model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
None
if version == Version.V1
else GroupedQKVLinear.default_config()
),
}
),
],
),
),
(
"neuron-(trn1|trn1n).32xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=8)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(),
"model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
None
if version == Version.V1
else GroupedQKVLinear.default_config()
),
}
),
],
),
),
),
)
elif model_size == "8B":
Expand Down Expand Up @@ -367,6 +454,26 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(),
"model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
None
if version == Version.V1
else GroupedQKVLinear.default_config()
),
}
),
],
),
),
),
)
elif model_size == "70B":
Expand All @@ -385,7 +492,7 @@ def get_trainer_kwargs(
),
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
mesh_rules=(
@@ -417,6 +524,26 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(),
"model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
None
if version == Version.V1
else GroupedQKVLinear.default_config()
),
}
),
],
),
),
),
)
else:
@@ -473,7 +600,8 @@ def model_config(
ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256)
if num_kv_heads:
atten_cfg = GroupedQueryAttention.default_config()
atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads)
qkv_linear = FusedGroupedQKVLinear
atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads)
else:
atten_cfg = MultiheadAttention.default_config()
atten_input_linear = FusedQKVLinear.default_config()
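The same trn2 mesh rule recurs for each model size above. As a hedged sketch that is not part of this commit, here is how one of those chained rules would rewrite a trainer config if applied by hand; `trainer_cfg` is assumed to be an SpmdTrainer.Config for a Fuji V2/V3 model, and the mesh_shape_from_axes import path is assumed to match the one fuji.py uses:

from axlearn.common.attention import GroupedQKVLinear, StackedTransformerLayer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    MeshShapeModifier,
    ModelConfigModifier,
)
from axlearn.experiments.text.gpt.common import mesh_shape_from_axes

trainer_cfg = ...  # Assumed: an SpmdTrainer.Config for a Fuji V2/V3 model.

chained = ChainConfigModifier.default_config().set(
    config_modifiers=[
        # 64-chip trn2 slice: put all remaining chips on fsdp, with 4-way model parallelism.
        MeshShapeModifier.default_config().set(mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)),
        # Use a plain StackedTransformerLayer and unfused grouped QKV projections.
        ModelConfigModifier.default_config().set(
            model_cfg_modifications={
                "model.decoder.transformer": StackedTransformerLayer.default_config(),
                "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
                    GroupedQKVLinear.default_config()
                ),
            }
        ),
    ],
).instantiate()
trainer_cfg = chained(trainer_cfg)  # Modifiers are applied in order, same call pattern as above.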

