diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py
index d647e1a06..9d37ca256 100644
--- a/axlearn/common/trainer_config_modifier.py
+++ b/axlearn/common/trainer_config_modifier.py
@@ -11,6 +11,7 @@
+    ConfigBase,
     ConfigModifier,
     ConfigOr,
     Required,
     config_class,
     maybe_instantiate,
 )
@@ -146,6 +147,64 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         return cfg
 
 
+class ModelConfigModifier(ConfigModifier):
+    """Update the model config for the trainer config."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ModelConfigModifier.
+
+        Attributes:
+            model_cfg_modifications: A mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to the config used to replace it.
+        """
+
+        model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        cfg = self.config
+        self._model_cfg_modifications = cfg.model_cfg_modifications
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Overwrite the model config of the trainer config.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        for module_name, model_cfg in self._model_cfg_modifications.items():
+            if not model_cfg:
+                continue
+            # Here we assume x.y.z format.
+            # One example would be model.decoder.transformer.layer.
+            target_modules = module_name.split(".")
+            curr_module = cfg
+            parent_module = None
+
+            for target_module in target_modules:
+                if not hasattr(curr_module, target_module):
+                    raise ValueError(f"{target_module} is not found in {curr_module}.")
+                parent_module = curr_module
+                curr_module = getattr(curr_module, target_module)
+
+            # Copy configurations from the config being replaced on a best-effort basis.
+            for key in model_cfg.keys():
+                if key == "klass":
+                    continue
+                elif hasattr(curr_module, key) and hasattr(model_cfg, key):
+                    setattr(model_cfg, key, getattr(curr_module, key))
+            # Replace in the parent config.
+            setattr(parent_module, target_module, model_cfg)
+        return cfg
+
+
 class ChainConfigModifier(ConfigModifier):
     """Chain multiple config modifiers together."""
 
diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py
index ccfe00823..369f97c1b 100644
--- a/axlearn/common/trainer_config_modifier_test.py
+++ b/axlearn/common/trainer_config_modifier_test.py
@@ -5,7 +5,7 @@
 import jax
 from absl.testing import absltest
 
-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.trainer_config_modifier import (
@@ -13,7 +13,9 @@
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
     RematSpecModifier,
 )
+from axlearn.common.attention import RepeatedTransformerLayer
 from axlearn.common.trainer_test import DummyModel
 
 
@@ -65,6 +67,26 @@ def test_remat_policy_override(self):
         _ = cfg_modifier(cfg)
 
 
+class ModelConfigModifierTest(test_utils.TestCase):
+    def test_model_config_override(self):
+        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
+        self.assertRegex(str(cfg.model.decoder), ".*StackedTransformerLayer")
+
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                model_cfg_modifications={
+                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
+                }
+            )
+            .instantiate()
+        )
+
+        cfg = cfg_modifier(cfg)
+        # The default StackedTransformerLayer should have been replaced by RepeatedTransformerLayer.
+        self.assertRegex(str(cfg.model.decoder), ".*RepeatedTransformerLayer")
+
+
 class MeshShapeModifierTest(test_utils.TestCase):
     def test_mesh_shape_update(self):
         cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b1102..569a09c64 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -22,9 +22,11 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    StackedTransformerLayer,
 )
 from axlearn.common.base_layer import RematSpec
@@ -38,6 +40,7 @@
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
     RematSpecModifier,
 )
 from axlearn.common.utils import extended_checkpoint_policies
 from axlearn.experiments.text.gpt.common import (
@@ -174,6 +177,28 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications={
+                                    "model.decoder.transformer": StackedTransformerLayer.default_config(),
+                                    "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
+                                        None
+                                        if version == Version.V1
+                                        else GroupedQKVLinear.default_config()
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -192,6 +217,28 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications={
+                                    "model.decoder.transformer": StackedTransformerLayer.default_config(),
+                                    "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
+                                        None
+                                        if version == Version.V1
+                                        else GroupedQKVLinear.default_config()
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -287,6 +334,46 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications={
+                                    "model.decoder.transformer": StackedTransformerLayer.default_config(),
+                                    "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
+                                        None
+                                        if version == Version.V1
+                                        else GroupedQKVLinear.default_config()
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
+                (
+                    "neuron-(trn1|trn1n).32xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=8)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications={
"model.decoder.transformer": StackedTransformerLayer.default_config(), + "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": ( + None + if version == Version.V1 + else GroupedQKVLinear.default_config() + ), + } + ), + ], + ), + ), ), ) elif model_size == "8B": @@ -367,6 +454,26 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + ModelConfigModifier.default_config().set( + model_cfg_modifications={ + "model.decoder.transformer": StackedTransformerLayer.default_config(), + "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": ( + None + if version == Version.V1 + else GroupedQKVLinear.default_config() + ), + } + ), + ], + ), + ), ), ) elif model_size == "70B": @@ -385,7 +492,7 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, + train_batch_size=8, max_step=max_step, mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( @@ -417,6 +524,26 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)", mesh_shape_from_axes(data=-1, fsdp=128), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + ModelConfigModifier.default_config().set( + model_cfg_modifications={ + "model.decoder.transformer": StackedTransformerLayer.default_config(), + "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": ( + None + if version == Version.V1 + else GroupedQKVLinear.default_config() + ), + } + ), + ], + ), + ), ), ) else: @@ -473,7 +600,8 @@ def model_config( ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256) if num_kv_heads: atten_cfg = GroupedQueryAttention.default_config() - atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads) + qkv_linear = FusedGroupedQKVLinear + atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads) else: atten_cfg = MultiheadAttention.default_config() atten_input_linear = FusedQKVLinear.default_config()