From 1b0004bdf10df757dbe1bb1252d481c1ccc73e40 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 6 Nov 2024 23:54:11 +0000 Subject: [PATCH] Low_pt axlearn changes Jun23-Nov20 Changes: - Remove learner.AccumulatedLearner --- axlearn/common/attention.py | 84 +++++++++++-- axlearn/common/causal_lm.py | 3 + axlearn/common/decoder.py | 1 + axlearn/common/evaler.py | 9 +- axlearn/common/flash_attention/layer.py | 13 +- .../flash_attention/neuron_attention.py | 119 ++++++++++++++++++ axlearn/common/flash_attention/utils.py | 15 ++- axlearn/common/input_tf_data.py | 3 +- axlearn/common/layers.py | 13 +- axlearn/common/learner.py | 29 ++++- axlearn/common/trainer.py | 18 ++- axlearn/common/utils.py | 46 ++++--- axlearn/experiments/text/gpt/c4_trainer.py | 3 + axlearn/experiments/text/gpt/common.py | 34 +++-- axlearn/experiments/text/gpt/fuji.py | 106 +++++++++++++--- 15 files changed, 422 insertions(+), 74 deletions(-) create mode 100644 axlearn/common/flash_attention/neuron_attention.py diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index b1a9a7e3b..788bd6c46 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -52,6 +52,7 @@ import enum import functools import math +import os from collections.abc import Sequence from enum import Enum, unique from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union @@ -59,6 +60,7 @@ import einops import jax from jax import numpy as jnp +from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies from axlearn.common import ops, param_init @@ -111,6 +113,7 @@ get_or_none, shapes, split_prng_key, + with_sharding_constraint, ) NEG_INF = -1e15 @@ -1082,6 +1085,7 @@ def forward( # N.B. this branch (with just the query inputs) is required in # order to get the best step time on TPU for self-attention. inputs = query # [batch, target_length, target_dim]. 
+ inputs = checkpoint_name(inputs, name='input_to_qkv') proj = self.qkv_proj.einsum_maybe_quantized( "btd,pdnh->pbtnh", activation=inputs, kernel=params["weight"] ) @@ -1260,24 +1264,26 @@ def apply_rotary_position_embeddings( """ # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] + + def _rotate_half(x: jnp.ndarray) -> jnp.ndarray: + halves = jnp.split(x, 2, axis=-1) + return jnp.concatenate((-halves[1], halves[0]), axis=-1) + sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape) # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape) # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] - rotate_half_query = jnp.reshape( - jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape - ) + rotate_half_query = _rotate_half(query) + query = query * cos_pos + rotate_half_query * sin_pos # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape) + rotate_half_key = _rotate_half(key) key = key * cos_pos + rotate_half_key * sin_pos if rotary_value: # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] - rotate_half_value = jnp.reshape( - jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape - ) + rotate_half_value = _rotate_half(value) value = value * cos_pos + rotate_half_value * sin_pos return query, key, value @@ -1330,6 +1336,7 @@ def forward( time_step: Optional[Tensor] = None, ) -> BaseQKVLinear.Output: cfg = self.config + query = self._remat_name(query, "input_qkv_ag") # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim]. query, key, value = self.i_proj(query, key=key, value=value) query_pos = jnp.arange(query.shape[1])[None] # [batch_size=1, seq_len]. @@ -1776,6 +1783,7 @@ def _forward_for_mode( ValueError: If key & value are an invalid combination. ValueError: If `mode` is unsupported. """ + query = self._remat_name(query, "input_qkv_ag") # Validate key & value combination. if (key is None) != (value is None): raise ValueError( @@ -1969,6 +1977,7 @@ def forward( segment_ids=segment_ids, return_aux=return_aux, ) + output = with_sharding_constraint(output, PartitionSpec('data', None, None)) return output def _cap_logits(self, logits: Tensor) -> Tensor: @@ -2648,10 +2657,17 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: return dict(attention=atten_state), atten_output if cfg.structure == "prenorm": + target = with_sharding_constraint(target, PartitionSpec('data','model',None)) skip_input = target # pre-norm: where normalization happens within the residual part. 
+ skip_input = self._remat_name(skip_input, 'residual_skip') norm_target = self.norm(target) + norm_target = with_sharding_constraint(norm_target, PartitionSpec('data',None,None)) + norm_target = checkpoint_name(norm_target, name='before_thunk') + #norm_target = self._remat_name(norm_target, 'attention_norm') atten_state, atten_output = attention_thunk(norm_target) + atten_output = with_sharding_constraint(atten_output, PartitionSpec('data','model',None)) data = skip_input + self.stochastic_depth(self.dropout(atten_output.data)) + data = self._remat_name(data, 'residual_add') elif cfg.structure == "postnorm": # This is the structure used by the original Transformer, BERT, and RoBERTa. atten_state, atten_output = attention_thunk(target) @@ -2941,18 +2957,24 @@ def _linear2(x): remat_pt1 = "activation" remat_pt2 = "linear2" + inputs = self._remat_name(inputs, 'residual_input') if cfg.structure == "prenorm": + x = with_sharding_constraint(inputs, PartitionSpec('data','model',None)) x = self.norm(inputs) + x = self._remat_name(x, 'mlp_norm') + x = with_sharding_constraint(x, PartitionSpec('data',None,None)) x = self._linear1_activation(x) x = self._remat_name(x, remat_pt1) x = self.dropout1(x) x = _linear2(x) x = self._remat_name(x, remat_pt2) + x = with_sharding_constraint(x, PartitionSpec('data','model',None)) x = self.dropout2(x) x = self.stochastic_depth(x) if cfg.residual_weight != 1: x *= cfg.residual_weight x += inputs + x=self._remat_name(x, 'mlp_residual') elif cfg.structure == "postnorm": x = self._linear1_activation(inputs) x = self._remat_name(x, remat_pt1) @@ -3289,6 +3311,7 @@ def forward( """ inputs = data data = self.norm(data) + data = checkpoint_name(data, name='before_attention') self_atten_outputs = self.self_attention( query=data, key=data, @@ -3476,8 +3499,8 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names) ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) # Encourage the right activation sharding. - ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) - ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + ff_layer.linear1.output_partition_spec = (batch_axis_names, None, tp_axis_names) + ff_layer.linear2.output_partition_spec = (batch_axis_names, None, None) if not isinstance(cfg, Sequence): cfg = [cfg] @@ -4071,6 +4094,14 @@ def forward( # TODO(sneha): extend_step +def save_all_names_but_these(*names_not_to_save): + # Save all values, including unnamed ones, excluding the specified names. + names_not_to_save = frozenset(names_not_to_save) + def policy(prim, *_, **params): + if 'name' in params and params['name'] in names_not_to_save: + return False + return True + return policy def build_remat_spec( stack_cfg: Union[ @@ -4105,7 +4136,35 @@ def build_remat_spec( # TODO(markblee): Switch to using isinstance everywhere. 
if stack_cfg.klass is PipelinedTransformerLayer: return None - + print(f'Stack_cfg {stack_cfg}') + if jax.default_backend() == 'neuron': + remat_style = os.getenv('REMAT_STYLE', 'default') + if remat_style == 'none': + # new remat 3 + return RematSpec( + prevent_cse=True, + policy=config_for_function(save_all_names_but_these).set( + names_not_to_save=(["noname"] + ) + ), + ) + else: + fused_qkv_name = stack_cfg.layer.self_attention.attention.input_linear.klass.__name__ + ffn_name = stack_cfg.layer.feed_forward.klass.__name__ + attention_name = stack_cfg.layer.self_attention.attention.klass.__name__ + print(stack_cfg.layer.self_attention.attention) + return RematSpec( + prevent_cse=stack_cfg.klass is StackedTransformerLayer, + # If we are running inside a jax.lax.scan (Repeated/Pipelined transformers + # or Repeated Conformers) we can enable common subexpression elimination optimizations. + policy=config_for_function(jax.checkpoint_policies.save_any_names_but_these).set( + names_not_to_save=(["all_gather","before_attention", "before_thunk", "input_to_qkv"] + + [f"{attention_name}.{el}" + for el in ['input_qkv_ag', 'o_proj']] + + [f"{ffn_name}.{el}" for el in ["mlp_norm", "linear2"]] + ) + ), + ) checkpoints = [] if self_attention: attention_name = stack_cfg.layer.self_attention.attention.klass.__name__ @@ -4113,7 +4172,7 @@ def build_remat_spec( [f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]] ) - if feed_forward and hasattr(stack_cfg.layer, "feed_forward"): + if False and feed_forward and hasattr(stack_cfg.layer, "feed_forward"): ffn_name = stack_cfg.layer.feed_forward.klass.__name__ checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]]) @@ -4184,7 +4243,8 @@ class CausalAttentionLogitBiasLayer(AttentionLogitBiasLayer): def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor: """Refer to AttentionLogitBiasLayer.forward for docstring.""" # Note: padding tokens are not explicitly masked. 
- causal_bias = (positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF + segment_ids = jnp.asarray(segment_ids, dtype=jnp.bfloat16) + causal_bias = jnp.asarray((positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF, dtype=jnp.bfloat16) return apply_attention_logit_biases( causal_bias, make_segment_mask(source_segments=segment_ids, target_segments=segment_ids) ) diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py index 0513d8718..d3ccf28fd 100644 --- a/axlearn/common/causal_lm.py +++ b/axlearn/common/causal_lm.py @@ -298,6 +298,9 @@ def _metrics( target_labels: Tensor, target_num_bytes: Optional[Tensor], ) -> dict[str, Tensor]: + if logits.dtype in (jnp.bfloat16, jnp.float16): + logits = logits.astype(jnp.float32) + live_targets = (target_labels != self.decoder.config.pad_token_id) & (target_labels >= 0) num_targets = live_targets.sum() accuracy = ( diff --git a/axlearn/common/decoder.py b/axlearn/common/decoder.py index eaa6ef6d7..36516197d 100644 --- a/axlearn/common/decoder.py +++ b/axlearn/common/decoder.py @@ -544,6 +544,7 @@ def _forward_for_mode( if "output_norm" in self.children: x = self.output_norm(x) + x = self._remat_name(x, 'output_norm') self._add_tensor_stats("norm_outputs", x) x = self.output_dropout(x) if "lm_head" in self.children: diff --git a/axlearn/common/evaler.py b/axlearn/common/evaler.py index 391946f75..1afddb53e 100644 --- a/axlearn/common/evaler.py +++ b/axlearn/common/evaler.py @@ -37,6 +37,7 @@ input_partition_spec, replicate_to_local_data, with_sharding_constraint, + DataPartitionType, ) @@ -188,11 +189,11 @@ def _pjit(self, fn: Callable) -> Callable: in_shardings=( self._model_param_partition_specs, # model_params. None, # replicated_inputs (e.g., prng_key). - utils.input_partition_spec(), # per_example_inputs. + utils.input_partition_spec(DataPartitionType.DATA), # per_example_inputs. 
), out_shardings=dict( replicated=None, - per_example=utils.input_partition_spec(), + per_example=utils.input_partition_spec(DataPartitionType.DATA), ), ) @@ -691,7 +692,7 @@ def eval_step( with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step): with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"): - global_input_batch = utils.host_to_global_device_array(input_batch) + global_input_batch = utils.host_to_global_device_array(input_batch, partition=DataPartitionType.DATA) forward_outputs = self.metric_calculator.forward( global_input_batch, model_params=model_params, @@ -864,7 +865,7 @@ def _compute_metrics_in_pjit( concatenated_outputs = jax.tree.map( lambda xs: ( with_sharding_constraint( - jnp.reshape(xs, (-1, *xs.shape[2:])), input_partition_spec() + jnp.reshape(xs, (-1, *xs.shape[2:])), input_partition_spec(DataPartitionType.DATA) ) if all(dim > 0 for dim in xs.shape[2:]) else None diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py index 36da8d595..922228a04 100644 --- a/axlearn/common/flash_attention/layer.py +++ b/axlearn/common/flash_attention/layer.py @@ -252,12 +252,13 @@ def _compute_attention( else PartitionSpec(None) ) - attention_logit_biases_spec = cfg.mha_dim_to_partition_spec["bnts"] - if attention_logit_biases is not None: - attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases) - attention_logit_biases = with_sharding_constraint( - attention_logit_biases, attention_logit_biases_spec - ) + # attention_logit_biases_spec = cfg.mha_dim_to_partition_spec["bnts"] + attention_logit_biases_spec = PartitionSpec(None, None, None, None) + # if attention_logit_biases is not None: + # attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases) + # attention_logit_biases = with_sharding_constraint( + # attention_logit_biases, attention_logit_biases_spec + # ) # Scale query and key. 
q_proj = self.scale_query(q_proj) diff --git a/axlearn/common/flash_attention/neuron_attention.py b/axlearn/common/flash_attention/neuron_attention.py new file mode 100644 index 000000000..81dad4613 --- /dev/null +++ b/axlearn/common/flash_attention/neuron_attention.py @@ -0,0 +1,119 @@ +import os +from absl import logging +import jax +import jax.numpy as jnp +import functools +from functools import partial +import jax.numpy as jnp +import neuronxcc.nki.language as nl +import numpy as np +from jax_neuronx import nki_call +from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd, flash_fwd + +if 'LNC' not in os.environ: + raise ValueError("LNC environment variable is not set") + +cores_per_lnc = os.environ['LNC'] +if cores_per_lnc == '2': + use_lnc = True +elif cores_per_lnc == '1': + use_lnc = False +else: + raise ValueError("LNC environment variable must be set to '1' or '2'") + +disable_sharded_attn_kernel = os.environ.get('DISABLE_SHARDED_ATTN_KERNEL') +if disable_sharded_attn_kernel is not None: + use_lnc = False + +if use_lnc: + from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable, flash_attn_bwd_shardable + from neuronxcc.starfish.penguin.targets.nki.private_api import vnc + +from jax import custom_vjp + +@partial(custom_vjp, nondiff_argnums=(3,4)) +def flash_attention(query, key, value, causal, softmax_scale): + out, _ = _mha_forward(query, key, value, causal, softmax_scale) + return out + +def _mha_forward(query, key, value, causal, softmax_scale): + # Get the batch size, sequence lengths, number of heads, and hidden dimension + batch_size, q_seq_len, num_heads, d_model = query.shape + _, kv_seq_len, _, _ = key.shape + + # Transpose the query, key, and value tensors + q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len] + k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len] + v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model] + + # Create the output buffer + attn_output_shape = jax.ShapeDtypeStruct((batch_size, num_heads, q_seq_len, d_model), dtype=query.dtype) + lse_shape = jax.ShapeDtypeStruct((batch_size, num_heads, nl.tile_size.pmax, q_seq_len // nl.tile_size.pmax), dtype=jnp.float32) + seed = jnp.array([1]) + + # Call the NKI kernel + if os.environ.get('ENABLE_NEW_UNSHARDED_ATTN_KERNEL'): + from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd + import neuronxcc.nki.language as nl + + assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpect num_heads: {num_heads}' + attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, seed, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0) + else: + from neuronxcc.nki._private_kernels.legacy.attention import flash_fwd + from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable + from neuronxcc.starfish.penguin.targets.nki.private_api import vnc + attn_output, lse = nki_call( + partial(flash_fwd_shardable if use_lnc else flash_fwd, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0), + q, k, v, seed, + out_shape=(attn_output_shape, lse_shape), + grid=(batch_size, num_heads, vnc(2)) if use_lnc else (batch_size, num_heads) + ) + # Transpose the output back to the original shape + attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model] + + return attn_output, (lse, attn_output, q, k, v) + +def _mha_backward(causal, softmax_scale, res, d_attn_output): + lse, 
o, q, k, v = res + batch_size, num_heads, d_model, seq_len = q.shape + _, kv_seq_len, _, _ = k.shape + + # Transpose the input tensors + o = o.transpose(0, 2, 3, 1) + dy = d_attn_output.transpose(0, 2, 3, 1) + + # Transpose v tensor + v = jnp.transpose(v, axes=(0, 1, 3, 2)) + # Create the output buffer shapes + d_query_shape = jax.ShapeDtypeStruct(q.shape, q.dtype) + d_key_shape = jax.ShapeDtypeStruct(k.shape, k.dtype) + d_value_shape = jax.ShapeDtypeStruct(v.shape, v.dtype) + seed = jnp.array([1]) + + # Call the NKI kernel + if os.environ.get('ENABLE_NEW_UNSHARDED_ATTN_KERNEL'): + from neuronxcc.nki.kernels.attention import flash_attn_bwd + import neuronxcc.nki.language as nl + assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpected num_heads: {num_heads}' + d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, o, dy, lse, seed, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale) + else: + from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd + from neuronxcc.nki._private_kernels.attention import flash_attn_bwd_shardable + from neuronxcc.starfish.penguin.targets.nki.private_api import vnc + d_query, d_key, d_value = nki_call( + partial(flash_attn_bwd_shardable if use_lnc else flash_attn_bwd, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale), + q, k, v, o, dy, lse, seed, + out_shape=[d_query_shape, d_key_shape, d_value_shape], + grid=(batch_size, num_heads, vnc(2)) if use_lnc else (batch_size, num_heads) + ) + + # Batch seq_len heads, head_dim + # Transpose the gradients back to the original shape + d_query = d_query.transpose(0, 3, 1, 2) + d_key = d_key.transpose(0, 3, 1, 2) + d_value = d_value.transpose(0, 3, 1, 2) + + return d_query, d_key, d_value + +flash_attention.defvjp(_mha_forward, _mha_backward) + diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 7da0543f1..a86b0768f 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -75,7 +75,7 @@ def mha_reference( def flash_attention_implementation( - backend: Literal["cpu", "tpu", "gpu", "xla"], + backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"], *, mask: Optional[MaskFn] = None, softmax_scale: float, @@ -159,6 +159,19 @@ def jit_attn(query, key, value, bias, segment_ids): return jit_attn + elif backend == "neuron": + from axlearn.common.flash_attention.neuron_attention import ( + flash_attention as neuron_flash_attention, + ) + + # shard_map-decorated function needs to be jitted. + @jax.jit + def jit_attn(query, key, value, bias): + return neuron_flash_attention( + query, key, value, causal, softmax_scale) + + return jit_attn + elif backend in ("cpu", "xla"): if backend == "cpu": logging.warning("Flash attention CPU backend is for testing only.") diff --git a/axlearn/common/input_tf_data.py b/axlearn/common/input_tf_data.py index 1d44ad1bb..432e91460 100644 --- a/axlearn/common/input_tf_data.py +++ b/axlearn/common/input_tf_data.py @@ -819,7 +819,8 @@ def batch( f"global_batch_size ({global_batch_size}) must be divisible by " f"number of JAX processes (data feeds) ({num_data_feeds})." 
) - per_feed_batch_size = global_batch_size // num_data_feeds + + per_feed_batch_size = global_batch_size if repeat is not None and (not isinstance(repeat, int) or repeat <= 0): raise ValueError(f"Invalid repeat (must be a positive integer): {repeat}") diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index 6736fb580..42fd17b72 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -28,7 +28,7 @@ from jax import nn from jax import numpy as jnp from jax.sharding import PartitionSpec - +from jax.ad_checkpoint import checkpoint_name from axlearn.common.base_layer import BaseLayer, FactorizationSpec, ParameterNoise, ParameterSpec from axlearn.common.config import ( REQUIRED, @@ -334,17 +334,20 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: return { "scale": ParameterSpec(shape=[cfg.input_dim], mesh_axes=(None,)), } - + def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor: del paddings # paddings do not affect LayerNorm results cfg = self.config x_dtype = x.dtype if cfg.forward_dtype is not None: x = x.astype(cfg.forward_dtype) + x = with_sharding_constraint(x, PartitionSpec('data','model', None)) moment2 = (x * x).mean(axis=-1, keepdims=True) x = x * jax.lax.rsqrt(moment2 + cfg.eps) x = x.astype(x_dtype) + x = with_sharding_constraint(x, PartitionSpec('data','model', None)) x = x * self.parameters["scale"] + x = with_sharding_constraint(x, PartitionSpec('data','model', None)) return x @@ -2498,8 +2501,12 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: ) def forward(self, x: Tensor) -> Tensor: + x = with_sharding_constraint(x, PartitionSpec('data', None)) emb = self.parameters["weight"] - return emb[x] + emb = with_sharding_constraint(emb, PartitionSpec('model', None)) + activation = emb[x] + activation = with_sharding_constraint(activation, PartitionSpec('data', None, None)) + return activation def attend(self, x: Tensor) -> Tensor: """Apply query array 'x' to the embedding weight array. diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py index 3e4c68cd5..9dcafd2e9 100644 --- a/axlearn/common/learner.py +++ b/axlearn/common/learner.py @@ -8,10 +8,11 @@ """ from __future__ import annotations +import os import dataclasses import enum from collections.abc import Mapping, Sequence -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, NamedTuple, Optional, Tuple, cast import jax import optax @@ -377,6 +378,32 @@ def _apply_updates(base: Nested[Tensor], updates: Nested[Tensor]) -> Nested[Tens base[k] = _apply_updates(base[k], v) return base +class MetricsAccumulationOp(NamedTuple): + microbatches: int + + def aggregrate(self, x, buffer): + raise NotImplementedError(self) + def normalize(self, buffer): + raise NotImplementedError(self) + +class ArithmeticMeanStrategy(MetricsAccumulationOp): + def aggregrate(self, x, buffer): + return buffer + x + def normalize(self, buffer): + return buffer / self.microbatches + +class GeometricMeanStrategy(MetricsAccumulationOp): + def aggregrate(self, x, buffer): + return buffer * x + def normalize(self, buffer): + return buffer ** (-self.microbatches) + +class AddStrategy(MetricsAccumulationOp): + def aggregrate(self, x, buffer): + return buffer + x + def normalize(self, buffer): + return buffer + class CompositeLearner(BaseLearner): """The composite learner supports different sub learners on different subset of parameters. 
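For reference, a minimal usage sketch of the MetricsAccumulationOp strategies added in the learner.py hunk above. This is not part of the patch; it assumes ArithmeticMeanStrategy is importable from axlearn.common.learner, and the per-microbatch loss values are made up.

    import jax.numpy as jnp

    from axlearn.common.learner import ArithmeticMeanStrategy

    strategy = ArithmeticMeanStrategy(microbatches=4)
    # Hypothetical per-microbatch losses from one gradient-accumulation step.
    losses = [jnp.asarray(v, dtype=jnp.float32) for v in (2.0, 2.5, 3.0, 2.5)]

    buffer = jnp.zeros(())
    for x in losses:
        # ArithmeticMeanStrategy accumulates by summing into the running buffer.
        buffer = strategy.aggregrate(x, buffer)
    # normalize() divides by the number of microbatches: (2.0 + 2.5 + 3.0 + 2.5) / 4 == 2.5.
    mean_loss = strategy.normalize(buffer)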
diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index a60560769..78e63890a 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -18,6 +18,7 @@ from jax import numpy as jnp from jax.experimental import multihost_utils from jax.experimental.pjit import pjit +from jax.sharding import NamedSharding from axlearn.common import file_system as fs from axlearn.common import measurement, utils @@ -53,11 +54,13 @@ NestedPartitionSpec, NestedTensor, PartitionSpec, + DataPartitionType, Tensor, count_model_params, flatten_items, match_regex_rules, thread_stack_traces, + TensorSpec, ) @@ -195,6 +198,9 @@ class Config(Module.Config): # An optional recorder for measuring common metrics like step time. recorder: Optional[InstantiableConfig[measurement.Recorder]] = None + # The input partition: + # Options: FULL (default), DATA, REPLICATED + input_partition_type: Required[DataPartitionType] = DataPartitionType.DATA # An additional context manager to run the training loop and initialization inside of. # The provided config should instantiate to a thunk that returns the context manager. @@ -343,7 +349,7 @@ def trainer_state_partition_specs(self): def _train_step_input_partition_specs(self): # By default, each input tensor is fully partitioned along the batch axis. - return utils.input_partition_spec() + return utils.input_partition_spec(self.config.input_partition_type) def model_params_for_eval(self): state = self.trainer_state @@ -568,14 +574,14 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) output = self._run_step( - utils.host_to_global_device_array(input_batch), + utils.host_to_global_device_array(input_batch, partition=cfg.input_partition_type), force_run_evals=( force_run_eval_sets_at_max_step if self.step >= cfg.max_step else None ), ) self.vlog(3, "Done step %s", self.step) num_steps += 1 - if num_steps % 100 == 0: + if num_steps % 1 == 0: now = time.perf_counter() average_step_time = (now - start_time) / num_steps self._step_log("Average step time: %s seconds", average_step_time) @@ -585,6 +591,12 @@ def run( if self.step >= cfg.max_step: self._step_log("Reached max_step=%s. Stopping", cfg.max_step) break + max_train_steps = os.environ.get('MAX_TRAIN_STEPS') + if max_train_steps is not None: + max_train_steps = int(max_train_steps) + if self.step >= max_train_steps: + self._step_log(f"Stopping test run after {max_train_steps} steps") + break if self.step < cfg.max_step: self._step_log("Reached end of inputs. Stopping") self._step_log("Checkpointer flushed.") diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 1401337f8..58f5bd14b 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -7,7 +7,6 @@ # Licensed under the Apache License, Version 2.0 (the "License"). """Common utilities.""" - import collections import contextlib import copy @@ -25,9 +24,11 @@ from enum import Enum from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union +import flax.struct import jax import numpy as np from absl import logging +from flax import serialization from jax import numpy as jnp from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal @@ -257,7 +258,7 @@ def tree_unflatten(cls, keys, values): return cls(zip(keys, values)) -# Register VDict as a dict for serialization. +# Register VDict as a dict for Flax serialization. 
serialization.register_serialization_state( VDict, # pylint: disable-next=protected-access @@ -529,7 +530,16 @@ def add_leaves(i, x): return jax.tree_util.tree_unflatten(treedef, axes) -def input_partition_spec() -> PartitionSpec: +class DataPartitionType(Enum): + # Data are fully partitioned across all devices. + FULL = "full" + # Data are fully replicated across all devices. + REPLICATED = "replicated" + # Data are partially partitioned across rank of data + DATA = "data" + + +def input_partition_spec(partition: DataPartitionType = DataPartitionType.FULL) -> PartitionSpec: """Returns partition spec for the input batch. We partition the inputs along all axes. For example, if the mesh has shape (64, 4) and axis @@ -538,10 +548,13 @@ def input_partition_spec() -> PartitionSpec: Must be called within the context of a Mesh. """ - mesh = thread_resources.env.physical_mesh - return PartitionSpec( - mesh.axis_names, - ) + if partition == DataPartitionType.FULL: + mesh = thread_resources.env.physical_mesh + return PartitionSpec( + mesh.axis_names, + ) + elif partition == DataPartitionType.DATA: + return PartitionSpec('data') # Key associated with per-example dataset dispatch index tensor, indicating which logical @@ -586,19 +599,14 @@ def traverse_and_dispatch(data: NestedTensor) -> NestedTensor: return traverse_and_dispatch(input_batch) -class DataPartitionType(Enum): - # Data are fully partitioned across all devices. - FULL = "full" - # Data are fully replicated across all devices. - REPLICATED = "replicated" - - -def data_partition_type_to_spec(partition: DataPartitionType) -> PartitionSpec: +def data_partition_type_to_spec(partition: DataPartitionType = DataPartitionType.FULL) -> PartitionSpec: """Returns a PartitionSpec for the given partition type.""" if partition == DataPartitionType.FULL: - return input_partition_spec() + return input_partition_spec(partition) elif partition == DataPartitionType.REPLICATED: return None + elif partition == DataPartitionType.DATA: + return input_partition_spec(partition) else: raise NotImplementedError(f"Unsupported partition: {partition}") @@ -636,6 +644,8 @@ def make_gda(x, partition_spec): global_shape = (x.shape[0] * process_count, *x.shape[1:]) elif partition == DataPartitionType.REPLICATED: global_shape = (x.shape[0], *x.shape[1:]) + elif partition == DataPartitionType.DATA: + global_shape = (x.shape[0], *x.shape[1:]) else: raise NotImplementedError(f"Unsupported partition: {partition}") return jax.make_array_from_process_local_data( @@ -1223,6 +1233,10 @@ def create_device_mesh( logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.") return build_standard_mesh(mesh_shape, devices=devices) + # Neuron also only uses standard mesh + if device_platform == "neuron": + return build_standard_mesh(mesh_shape, devices=devices) + # Canonicalize to HybridMeshShape. If DCN mesh is not specified, break the first non-singleton # device axis (the least communication intensive) over the number of slices/granules. 
If all # axes are singletons, this is effectively a no-op, since this implies a single-granule diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py index 8c70422e3..e1f182a63 100644 --- a/axlearn/experiments/text/gpt/c4_trainer.py +++ b/axlearn/experiments/text/gpt/c4_trainer.py @@ -49,6 +49,9 @@ from axlearn.experiments.text.gpt.common import mixture_train_input_source, tfds_input from axlearn.experiments.trainer_config_utils import TrainerConfigFn +import jax +import jax_neuronx + # Sentencepiece vocabs generated from c4/en:3.0.1. # See bpe_{32k,128k}.json for the sentencepiece settings. _SENTENCEPIECE_MODEL_NAME = { diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index d32120d25..f89a6ca50 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -10,6 +10,7 @@ See c4_trainer.py for how they are used. """ +import os import math from collections.abc import Sequence from typing import Literal, Optional, Protocol, Union @@ -34,6 +35,7 @@ BaseQKVLinear, MultiheadAttention, RepeatedTransformerLayer, + StackedTransformerLayer, TransformerLayer, build_remat_spec, set_double_shard_weights_config, @@ -47,7 +49,7 @@ maybe_instantiate, maybe_set_config, ) -from axlearn.common.decoder import Decoder +from axlearn.common.decoder import Decoder, LmHead from axlearn.common.embedding import TransformerTextEmbeddings from axlearn.common.evaler import BaseMetricCalculator, ModelSummaryAccumulator, SpmdEvaler from axlearn.common.evaler import every_n_steps_policy as eval_every_n_steps_policy @@ -57,7 +59,7 @@ from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer from axlearn.common.summary_writer import BaseWriter from axlearn.common.trainer import MeshShape, SpmdTrainer -from axlearn.common.utils import HybridMeshShape, Nested, get_data_dir +from axlearn.common.utils import DataPartitionType, HybridMeshShape, Nested, get_data_dir from axlearn.experiments.text.common import DataMixtureComponent, tfds_text_source from axlearn.experiments.trainer_config_utils import TrainerConfigFn @@ -67,7 +69,8 @@ # We typically use bfloat16 as the step dtype, # (but usually keep parameters and optimizer state in float32). -STEP_DTYPE = jnp.bfloat16 +STEP_DTYPE = jnp.float32 if os.environ.get('USE_FP32_COMPUTE') == '1' else jnp.bfloat16 +print(f"STEP_DTYPE: {STEP_DTYPE}") # The default mesh-axis names for LM training, from least to most communication intensive. @@ -199,9 +202,13 @@ def update_model_remat_config( offload_dst: Destination of remat checkptoing offloading. Raises: - NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer. + NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer + or StackedTransformerLayer. 
""" - if stack_cfg.klass is not RepeatedTransformerLayer: + if ( + stack_cfg.klass is not RepeatedTransformerLayer + and stack_cfg.klass is not StackedTransformerLayer + ): raise NotImplementedError( f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}" ) @@ -288,7 +295,7 @@ def model_config( layer_cfg.self_attention.attention.input_linear = attention_qkv_linear layer_cfg.self_attention.structure = atten_structure layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap - if stack_cfg.klass is RepeatedTransformerLayer: + if stack_cfg.klass is RepeatedTransformerLayer or stack_cfg.klass is StackedTransformerLayer: update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) # Stack. transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg) @@ -321,11 +328,15 @@ def model_config( set_double_shard_weights_config( cfg.decoder.transformer.layer, batch_axis_names=batch_axis_names, - fsdp_axis_names=("expert", "fsdp", "seq"), + fsdp_axis_names=("data"), tp_axis_names="model", seq_axis_names="seq", ) - cfg.decoder.logits_partition_spec = (batch_axis_names, "seq", "model") + + tp_axis_names='model' + fsdp_axis_names='data' + cfg.decoder.emb.token_emb.param_partition_spec = (tp_axis_names, fsdp_axis_names) # shard vocab + set_bias_recursively(cfg, False) set_norm_recursively(cfg, normalization) cfg.z_loss_scale = z_loss_scale @@ -375,6 +386,8 @@ def adamw_decoupled_learner_config( b2: float = 0.95, eps: float = 1e-8, adam_update_transformation: Optional[ConfigOr[PartitionedGradientTransformation]] = None, + gradient_accumulation_microbatches: int = 1, + metrics_accumulation_key_ops: dict = {} ) -> learner.Learner.Config: """Build learner using the AdamW optimizer and a cosine lr schedule with linear warmup.""" update_schedule = config_for_function(schedule.cosine_with_linear_warmup).set( @@ -397,6 +410,7 @@ def adamw_decoupled_learner_config( weight_decay=weight_decay, weight_decay_per_param_scale=None, adam_update_transformation=adam_update_transformation, + mu_dtype=jnp.float32, ), ] ) @@ -636,6 +650,7 @@ def get_trainer_config_fn( max_step: int, train_batch_size: int, train_input_source: InstantiableConfig[input_tf_data.BuildDatasetFn], + input_partition_type: DataPartitionType, evalers: dict[str, SpmdEvaler.Config], mesh_shape: Union[MeshShape, HybridMeshShape], mesh_axis_names: Sequence[str] = MESH_AXIS_NAMES, @@ -689,6 +704,7 @@ def config_fn() -> InstantiableConfig: pad_example_fn=input_tf_data.default_pad_example_fn, ), ) + cfg.input_partition_type = input_partition_type cfg.evalers = {} for name, evaler_cfg in evalers.items(): evaler_cfg.input.batcher.set(global_batch_size=eval_batch_size or train_batch_size) @@ -706,7 +722,7 @@ def config_fn() -> InstantiableConfig: ) cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) cfg.checkpointer.keep_last_n = 3 - cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) + cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 10) cfg.summary_writer.max_queue = 1000 if len(mesh_axis_names) != len(mesh_shape): raise ValueError( diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 69f6b1102..a8a8ac689 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -13,6 +13,8 @@ import enum import functools import itertools +import jax +import os from typing import Any, Optional, Union from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies @@ -21,10 +23,12 @@ 
from axlearn.common.attention import ( BaseStackedTransformerLayer, FusedGroupedQKVLinear, + GroupedQKVLinear, FusedQKVLinear, GroupedQueryAttention, MultiheadAttention, RepeatedTransformerLayer, + StackedTransformerLayer, RoFormerQKVLinear, ) from axlearn.common.base_layer import RematSpec @@ -32,6 +36,7 @@ from axlearn.common.decoder import LmHead from axlearn.common.embedding import TransformerTextEmbeddings from axlearn.common.layers import RMSNorm +from axlearn.common.utils import DataPartitionType from axlearn.common.trainer import SpmdTrainer from axlearn.common.trainer_config_modifier import ( ChainConfigModifier, @@ -65,7 +70,7 @@ class Version(enum.Enum): # Mapping from Fuji versions to vocab sizes. VOCAB_SIZE = { - Version.V1: 32 * 1024, + Version.V1: 128 * 1024, Version.V2: 32 * 1024, Version.V3: 128256, } @@ -79,8 +84,15 @@ class Version(enum.Enum): } +TP_DEGREE = int(os.environ.get('TP_DEGREE', '4')) +TRN_MODEL_AXIS_SIZE = TP_DEGREE + +GRADIENT_ACCUMULATION_MICROBATCHES = int( + os.environ.get("NEURON_GRAD_ACC_COUNT",1)) +NUM_NODES = int(os.environ.get("NEURON_NUM_NODES",1)) + ROPE_THETA = { - Version.V1: 1e4, + Version.V1: 5e5, Version.V2: 1e4, Version.V3: 5e5, } @@ -113,15 +125,20 @@ def get_trainer_kwargs( *, vocab_size: int, version: Version, - flash_attention: bool = False, + flash_attention: bool = True, ) -> dict[str, Any]: """Construct default trainer kwargs given a model size.""" + #NOTE: tokens_per_batch is hardcoded to match apple/axlearn upstream + # This is not a bug. Do not fix unless to get to match upstream baseline tokens_per_batch = 4 * (1024**2) # 4M tokens. if model_size not in TOTAL_TOKENS[version]: return {} max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch max_sequence_length = MAX_SEQUENCE_LENGTH[version] - train_batch_size = tokens_per_batch // max_sequence_length + train_batch_size = int(jax.device_count()/TRN_MODEL_AXIS_SIZE) + train_batch_size *= GRADIENT_ACCUMULATION_MICROBATCHES + train_batch_size *= NUM_NODES + # Whether to use grouped query attention. num_kv_heads = None @@ -149,13 +166,14 @@ def get_trainer_kwargs( flash_attention=flash_attention, ), learner_kwargs=dict(peak_lr=6e-4, weight_decay=0.01), + input_partition_type=DataPartitionType.DATA, max_sequence_length=64, train_batch_size=32, - eval_batch_size=32, max_step=3000, eval_every_n_steps=1500, save_every_n_steps=500, mesh_shape=mesh_shape_from_axes(data=-1), + eval_batch_size=int(jax.device_count()/TRN_MODEL_AXIS_SIZE), ) elif model_size == "1B": trainer_kwargs = dict( @@ -196,19 +214,22 @@ def get_trainer_kwargs( elif model_size == "7B": trainer_kwargs = dict( model_kwargs=dict( - num_layers=32, - hidden_dim=128 * 32, - num_heads=32, - num_kv_heads=num_kv_heads, + num_layers=10, + hidden_dim=8192, + ffn_dim=scaled_hidden_dim(scale=4, round_up_to_multiples_of=16), + num_heads=64, + num_kv_heads=None, rope_theta=rope_theta, shared_lm_head=True, flash_attention=flash_attention, ), learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1), + input_partition_type=DataPartitionType.DATA, + # 1 batch per DP replica + train_batch_size=int((jax.device_count()/TRN_MODEL_AXIS_SIZE)*GRADIENT_ACCUMULATION_MICROBATCHES), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, - max_step=max_step, - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + max_step=500_000, # 2T tokens // 4M tokens/step. 
+ mesh_shape=mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE), mesh_rules=( # Step time: # v1 on tpu-v4-1024 (512 chips): 3.03s @@ -287,7 +308,13 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "trn2", + mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE), + ), ), + eval_batch_size=int(jax.device_count()/TRN_MODEL_AXIS_SIZE), + eval_every_n_steps=5000, ) elif model_size == "8B": trainer_kwargs = dict( @@ -372,7 +399,7 @@ def get_trainer_kwargs( elif model_size == "70B": trainer_kwargs = dict( model_kwargs=dict( - num_layers=80, + num_layers=int(os.environ.get('N_LAYERS', 4)), hidden_dim=128 * 64, num_heads=64, # No GQA support in V1 models, so num_kv_heads is the same as num_heads. @@ -385,8 +412,9 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, - max_step=max_step, + input_partition_type=DataPartitionType.DATA, + train_batch_size=int((jax.device_count()/TRN_MODEL_AXIS_SIZE)*GRADIENT_ACCUMULATION_MICROBATCHES), + max_step=500000, mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( # TPU V5e maximum per device batch is 1. @@ -417,7 +445,48 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)", mesh_shape_from_axes(data=-1, fsdp=128), ), + ( + "trn2", + mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE), + ), + ), + eval_batch_size=int(jax.device_count()/TRN_MODEL_AXIS_SIZE), + eval_every_n_steps=500_000, + ) + elif model_size == "405B": + trainer_kwargs = dict( + model_kwargs=dict( + num_layers=int(os.environ.get('N_LAYERS', 4)), + hidden_dim=16384, + ffn_dim=53248, + num_heads=128, + # No GQA support in V1 models, so num_kv_heads is the same as num_heads. + num_kv_heads=None, #if version == Version.V1 else 8, + rope_theta=rope_theta, + flash_attention=flash_attention, + ), + learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), + max_sequence_length=8192, + input_partition_type=DataPartitionType.DATA, + train_batch_size=int((jax.device_count()/TRN_MODEL_AXIS_SIZE)*GRADIENT_ACCUMULATION_MICROBATCHES), + max_step=500000, + mesh_shape=mesh_shape_from_axes(fsdp=-1), + mesh_rules=( + # tpu-v5e. step time: TBD. + ("tpu-v5litepod-256", mesh_shape_from_axes(data=-1, fsdp=256)), + # H100/A100 80G. Maximum per-node batch size = 16, hence need >= 64 nodes. + # v2 on gpu-p5.48xlarge 8x64, step time: 12.9s. + ( + "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)", + mesh_shape_from_axes(data=-1, fsdp=128), + ), + ( + "trn2", + mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE), + ), ), + eval_batch_size=int(jax.device_count()/TRN_MODEL_AXIS_SIZE), + eval_every_n_steps=500_000, ) else: raise NotImplementedError(f"Unknown model size {model_size}.") @@ -443,7 +512,7 @@ def model_config( shared_lm_head: bool, dropout_rate: float = 0.0, ffn_dim: Optional[Union[int, config.FunctionConfigBase]] = None, - flash_attention: bool = False, + flash_attention: bool = True, stack_cfg: Optional[BaseStackedTransformerLayer.Config] = None, ) -> causal_lm.Model.Config: """Returns an LM model config based on the given hyperparams. 
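To make the Trainium batch-size arithmetic in get_trainer_kwargs above concrete, here is a small worked sketch. The device count and environment values below are hypothetical and are not taken from the patch.

    # Hypothetical values mirroring the train_batch_size computation above.
    device_count = 64          # e.g. jax.device_count() across the trn2 hosts
    tp_degree = 4              # TP_DEGREE, i.e. TRN_MODEL_AXIS_SIZE
    grad_acc_microbatches = 8  # NEURON_GRAD_ACC_COUNT

    dp_replicas = device_count // tp_degree                 # 16 data-parallel replicas
    train_batch_size = dp_replicas * grad_acc_microbatches  # 16 * 8 = 128 sequences per optimizer step
    # mesh_shape_from_axes(data=-1, model=tp_degree) then resolves to a (data=16, model=4) mesh.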
@@ -473,7 +542,8 @@ def model_config( ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256) if num_kv_heads: atten_cfg = GroupedQueryAttention.default_config() - atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads) + # atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads) + atten_input_linear = GroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads) else: atten_cfg = MultiheadAttention.default_config() atten_input_linear = FusedQKVLinear.default_config() @@ -491,7 +561,7 @@ def model_config( hidden_dim=hidden_dim, num_heads=num_heads, vocab_size=vocab_size, - stack_cfg=stack_cfg if stack_cfg is not None else RepeatedTransformerLayer.default_config(), + stack_cfg=stack_cfg if stack_cfg is not None else StackedTransformerLayer.default_config(), activation_fn=activation_fn, ffn_dim=ffn_dim, normalization=RMSNorm.default_config().set(eps=1e-5, forward_dtype=None),
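Finally, a self-contained sketch of the jax.custom_vjp wiring pattern used by flash_attention in axlearn/common/flash_attention/neuron_attention.py earlier in this patch: causal and softmax_scale are non-differentiable arguments, the forward pass returns residuals, and the backward pass returns gradients only for (query, key, value). This toy version substitutes plain softmax attention for the NKI kernels and is illustrative only.

    from functools import partial

    import jax
    import jax.numpy as jnp


    @partial(jax.custom_vjp, nondiff_argnums=(3, 4))
    def toy_attention(q, k, v, causal, softmax_scale):
        out, _ = _toy_fwd(q, k, v, causal, softmax_scale)
        return out


    def _toy_fwd(q, k, v, causal, softmax_scale):
        # q, k, v: [batch, seq, heads, head_dim].
        logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
        if causal:
            q_len, k_len = logits.shape[-2], logits.shape[-1]
            mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
            logits = jnp.where(mask, logits, -1e15)
        probs = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhd->bqhd", probs, v)
        # Second element: residuals saved for the backward pass.
        return out, (q, k, v, probs)


    def _toy_bwd(causal, softmax_scale, res, d_out):
        # Non-differentiable args come first, then residuals, then the output cotangent;
        # gradients are returned only for the differentiable args (q, k, v).
        del causal  # Masking is already reflected in probs (masked entries are ~0).
        q, k, v, probs = res
        d_probs = jnp.einsum("bqhd,bkhd->bhqk", d_out, v)
        d_v = jnp.einsum("bhqk,bqhd->bkhd", probs, d_out)
        # Softmax backward over the last axis.
        d_logits = probs * (d_probs - jnp.sum(d_probs * probs, axis=-1, keepdims=True))
        d_q = jnp.einsum("bhqk,bkhd->bqhd", d_logits, k) * softmax_scale
        d_k = jnp.einsum("bhqk,bqhd->bkhd", d_logits, q) * softmax_scale
        return d_q, d_k, d_v


    toy_attention.defvjp(_toy_fwd, _toy_bwd)

    # Usage: gradients flow through q, k, v only; causal/softmax_scale are treated as static.
    # q = k = v = jnp.ones((2, 8, 4, 16))
    # dq, dk, dv = jax.grad(
    #     lambda q, k, v: toy_attention(q, k, v, True, 0.125).sum(), argnums=(0, 1, 2)
    # )(q, k, v)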