Low_pt axlearn changes Jun23-Nov20
Changes:
- Remove learner.AccumulatedLearner
apivovarov committed Dec 2, 2024
1 parent c20387c commit 1b0004b
Showing 15 changed files with 422 additions and 74 deletions.
84 changes: 72 additions & 12 deletions axlearn/common/attention.py
@@ -52,13 +52,15 @@
import enum
import functools
import math
import os
from collections.abc import Sequence
from enum import Enum, unique
from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union

import einops
import jax
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

from axlearn.common import ops, param_init
@@ -111,6 +113,7 @@
get_or_none,
shapes,
split_prng_key,
with_sharding_constraint,
)

NEG_INF = -1e15
@@ -1082,6 +1085,7 @@ def forward(
# N.B. this branch (with just the query inputs) is required in
# order to get the best step time on TPU for self-attention.
inputs = query # [batch, target_length, target_dim].
inputs = checkpoint_name(inputs, name='input_to_qkv')
proj = self.qkv_proj.einsum_maybe_quantized(
"btd,pdnh->pbtnh", activation=inputs, kernel=params["weight"]
)
@@ -1260,24 +1264,26 @@ def apply_rotary_position_embeddings(
"""
# sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]

def _rotate_half(x: jnp.ndarray) -> jnp.ndarray:
halves = jnp.split(x, 2, axis=-1)
return jnp.concatenate((-halves[1], halves[0]), axis=-1)

sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape)
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
rotate_half_query = jnp.reshape(
jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
)
rotate_half_query = _rotate_half(query)

query = query * cos_pos + rotate_half_query * sin_pos
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
rotate_half_key = _rotate_half(key)
key = key * cos_pos + rotate_half_key * sin_pos
if rotary_value:
# rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
rotate_half_value = jnp.reshape(
jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape
)
rotate_half_value = _rotate_half(value)
value = value * cos_pos + rotate_half_value * sin_pos
return query, key, value
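A side note on the rotary rework above: the removed inline code rotates adjacent even/odd pairs, while the new _rotate_half helper rotates the two contiguous halves, and each form has to be paired with the matching sin/cos layout. A tiny standalone comparison (names other than _rotate_half are illustrative):

import jax.numpy as jnp

def rotate_half_split(x):
    # New helper's convention: [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
    halves = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-halves[1], halves[0]), axis=-1)

def rotate_half_interleaved(x):
    # Replaced inline convention: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]
    return jnp.reshape(jnp.stack([-x[..., 1::2], x[..., ::2]], axis=-1), x.shape)

x = jnp.arange(4.0)
assert (rotate_half_split(x) == jnp.array([-2.0, -3.0, 0.0, 1.0])).all()
assert (rotate_half_interleaved(x) == jnp.array([-1.0, 0.0, -3.0, 2.0])).all()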

@@ -1330,6 +1336,7 @@ def forward(
time_step: Optional[Tensor] = None,
) -> BaseQKVLinear.Output:
cfg = self.config
query = self._remat_name(query, "input_qkv_ag")
# Query should have shape of [batch_size, seq_len, num_heads, per_head_dim].
query, key, value = self.i_proj(query, key=key, value=value)
query_pos = jnp.arange(query.shape[1])[None] # [batch_size=1, seq_len].
@@ -1776,6 +1783,7 @@ def _forward_for_mode(
ValueError: If key & value are an invalid combination.
ValueError: If `mode` is unsupported.
"""
query = self._remat_name(query, "input_qkv_ag")
# Validate key & value combination.
if (key is None) != (value is None):
raise ValueError(
@@ -1969,6 +1977,7 @@ def forward(
segment_ids=segment_ids,
return_aux=return_aux,
)
output = with_sharding_constraint(output, PartitionSpec('data', None, None))
return output

def _cap_logits(self, logits: Tensor) -> Tensor:
@@ -2648,10 +2657,17 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
return dict(attention=atten_state), atten_output

if cfg.structure == "prenorm":
target = with_sharding_constraint(target, PartitionSpec('data','model',None))
skip_input = target # pre-norm: where normalization happens within the residual part.
skip_input = self._remat_name(skip_input, 'residual_skip')
norm_target = self.norm(target)
norm_target = with_sharding_constraint(norm_target, PartitionSpec('data',None,None))
norm_target = checkpoint_name(norm_target, name='before_thunk')
#norm_target = self._remat_name(norm_target, 'attention_norm')
atten_state, atten_output = attention_thunk(norm_target)
atten_output = with_sharding_constraint(atten_output, PartitionSpec('data','model',None))
data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
data = self._remat_name(data, 'residual_add')
elif cfg.structure == "postnorm":
# This is the structure used by the original Transformer, BERT, and RoBERTa.
atten_state, atten_output = attention_thunk(target)
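The with_sharding_constraint / PartitionSpec annotations added in this hunk (and in the feed-forward hunk below) follow the standard JAX pattern of pinning an activation's layout inside jit; the axlearn helper imported at the top of this diff is called with a bare PartitionSpec, whereas plain JAX outside a mesh context wants a NamedSharding. A minimal standalone sketch in plain JAX (the mesh shape and axis names here are illustrative):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional "data" mesh over all local devices (illustrative).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@jax.jit
def f(x):
    # Pin the batch dimension to the "data" axis and replicate the rest,
    # mirroring the PartitionSpec('data', None, None) constraints above.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, PartitionSpec("data", None, None))
    )
    return x * 2

y = f(jnp.ones((jax.device_count() * 2, 16, 32)))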
@@ -2941,18 +2957,24 @@ def _linear2(x):

remat_pt1 = "activation"
remat_pt2 = "linear2"
inputs = self._remat_name(inputs, 'residual_input')
if cfg.structure == "prenorm":
x = with_sharding_constraint(inputs, PartitionSpec('data','model',None))
x = self.norm(x)
x = self._remat_name(x, 'mlp_norm')
x = with_sharding_constraint(x, PartitionSpec('data',None,None))
x = self._linear1_activation(x)
x = self._remat_name(x, remat_pt1)
x = self.dropout1(x)
x = _linear2(x)
x = self._remat_name(x, remat_pt2)
x = with_sharding_constraint(x, PartitionSpec('data','model',None))
x = self.dropout2(x)
x = self.stochastic_depth(x)
if cfg.residual_weight != 1:
x *= cfg.residual_weight
x += inputs
x = self._remat_name(x, 'mlp_residual')
elif cfg.structure == "postnorm":
x = self._linear1_activation(inputs)
x = self._remat_name(x, remat_pt1)
@@ -3289,6 +3311,7 @@ def forward(
"""
inputs = data
data = self.norm(data)
data = checkpoint_name(data, name='before_attention')
self_atten_outputs = self.self_attention(
query=data,
key=data,
@@ -3476,8 +3499,8 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names)
# Encourage the right activation sharding.
ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear1.output_partition_spec = (batch_axis_names, None, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, None, None)

if not isinstance(cfg, Sequence):
cfg = [cfg]
@@ -4071,6 +4094,14 @@ def forward(

# TODO(sneha): extend_step

def save_all_names_but_these(*names_not_to_save):
# Save all values, including unnamed ones, excluding the specified names.
names_not_to_save = frozenset(names_not_to_save)
def policy(prim, *_, **params):
if 'name' in params and params['name'] in names_not_to_save:
return False
return True
return policy

def build_remat_spec(
stack_cfg: Union[
@@ -4105,15 +4136,43 @@ def build_remat_spec(
# TODO(markblee): Switch to using isinstance everywhere.
if stack_cfg.klass is PipelinedTransformerLayer:
return None

print(f'Stack_cfg {stack_cfg}')
if jax.default_backend() == 'neuron':
remat_style = os.getenv('REMAT_STYLE', 'default')
if remat_style == 'none':
# REMAT_STYLE=none: save every value, tagged or not (only values literally named "noname" would be recomputed), effectively disabling rematerialization.
return RematSpec(
prevent_cse=True,
policy=config_for_function(save_all_names_but_these).set(
names_not_to_save=["noname"]
),
)
else:
fused_qkv_name = stack_cfg.layer.self_attention.attention.input_linear.klass.__name__
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
attention_name = stack_cfg.layer.self_attention.attention.klass.__name__
print(stack_cfg.layer.self_attention.attention)
return RematSpec(
prevent_cse=stack_cfg.klass is StackedTransformerLayer,
# If we are running inside a jax.lax.scan (Repeated/Pipelined transformers
# or Repeated Conformers) we can enable common subexpression elimination optimizations.
policy=config_for_function(jax.checkpoint_policies.save_any_names_but_these).set(
names_not_to_save=(["all_gather","before_attention", "before_thunk", "input_to_qkv"] +
[f"{attention_name}.{el}"
for el in ['input_qkv_ag', 'o_proj']] +
[f"{ffn_name}.{el}" for el in ["mlp_norm", "linear2"]]
)
),
)
checkpoints = []
if self_attention:
attention_name = stack_cfg.layer.self_attention.attention.klass.__name__
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)

if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
if False and feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])

@@ -4184,7 +4243,8 @@ class CausalAttentionLogitBiasLayer(AttentionLogitBiasLayer):
def forward(self, *, segment_ids: Tensor, positions: Tensor) -> Tensor:
"""Refer to AttentionLogitBiasLayer.forward for docstring."""
# Note: padding tokens are not explicitly masked.
causal_bias = (positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF
segment_ids = jnp.asarray(segment_ids, dtype=jnp.bfloat16)
causal_bias = jnp.asarray((positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF, dtype=jnp.bfloat16)
return apply_attention_logit_biases(
causal_bias, make_segment_mask(source_segments=segment_ids, target_segments=segment_ids)
)
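On the bfloat16 casts above: NEG_INF = -1e15 remains representable in bfloat16 (it shares float32's exponent range), so the bias survives the downcast, presumably to keep it in the attention compute dtype. A small standalone rendering of the same broadcasting (illustrative only):

import jax.numpy as jnp

NEG_INF = -1e15
positions = jnp.arange(4)[None]  # [batch=1, seq_len]
# [batch, num_heads=1, target_len, source_len]; entries with source > target get NEG_INF.
causal_bias = jnp.asarray(
    (positions[:, None, :, None] < positions[:, None, None, :]) * NEG_INF,
    dtype=jnp.bfloat16,
)
# Row i keeps 0 for source positions <= i and ~-1e15 (after bf16 rounding) above the diagonal.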
3 changes: 3 additions & 0 deletions axlearn/common/causal_lm.py
@@ -298,6 +298,9 @@ def _metrics(
target_labels: Tensor,
target_num_bytes: Optional[Tensor],
) -> dict[str, Tensor]:
if logits.dtype in (jnp.bfloat16, jnp.float16):
logits = logits.astype(jnp.float32)

live_targets = (target_labels != self.decoder.config.pad_token_id) & (target_labels >= 0)
num_targets = live_targets.sum()
accuracy = (
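The float32 upcast above is the usual guard for metric math: log-softmax and large-vocabulary reductions lose precision quickly in bfloat16/float16. A minimal sketch of the pattern (the helper below is illustrative, not axlearn's _metrics):

import jax
import jax.numpy as jnp

def token_cross_entropy(logits, labels):
    # Upcast half-precision logits so log-softmax and the reduction run in float32.
    if logits.dtype in (jnp.bfloat16, jnp.float16):
        logits = logits.astype(jnp.float32)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)

loss = token_cross_entropy(
    jnp.zeros((2, 5, 32000), jnp.bfloat16), jnp.zeros((2, 5), jnp.int32)
).mean()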
1 change: 1 addition & 0 deletions axlearn/common/decoder.py
@@ -544,6 +544,7 @@ def _forward_for_mode(

if "output_norm" in self.children:
x = self.output_norm(x)
x = self._remat_name(x, 'output_norm')
self._add_tensor_stats("norm_outputs", x)
x = self.output_dropout(x)
if "lm_head" in self.children:
9 changes: 5 additions & 4 deletions axlearn/common/evaler.py
@@ -37,6 +37,7 @@
input_partition_spec,
replicate_to_local_data,
with_sharding_constraint,
DataPartitionType,
)


@@ -188,11 +189,11 @@ def _pjit(self, fn: Callable) -> Callable:
in_shardings=(
self._model_param_partition_specs, # model_params.
None, # replicated_inputs (e.g., prng_key).
utils.input_partition_spec(), # per_example_inputs.
utils.input_partition_spec(DataPartitionType.DATA), # per_example_inputs.
),
out_shardings=dict(
replicated=None,
per_example=utils.input_partition_spec(),
per_example=utils.input_partition_spec(DataPartitionType.DATA),
),
)

@@ -691,7 +692,7 @@ def eval_step(

with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step):
with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"):
global_input_batch = utils.host_to_global_device_array(input_batch)
global_input_batch = utils.host_to_global_device_array(input_batch, partition=DataPartitionType.DATA)
forward_outputs = self.metric_calculator.forward(
global_input_batch,
model_params=model_params,
@@ -864,7 +865,7 @@ def _compute_metrics_in_pjit(
concatenated_outputs = jax.tree.map(
lambda xs: (
with_sharding_constraint(
jnp.reshape(xs, (-1, *xs.shape[2:])), input_partition_spec()
jnp.reshape(xs, (-1, *xs.shape[2:])), input_partition_spec(DataPartitionType.DATA)
)
if all(dim > 0 for dim in xs.shape[2:])
else None
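The evaler edits thread a DataPartitionType.DATA value through input_partition_spec and host_to_global_device_array; the matching change to axlearn/common/utils.py is not among the hunks shown here. Purely as a hypothetical sketch of what such a switch could look like (the enum members and return values below are assumptions, not the actual implementation):

import enum
from jax.sharding import PartitionSpec

class DataPartitionType(enum.Enum):
    FULL = "full"  # assumed default: shard the batch over every mesh axis
    DATA = "data"  # shard the batch over the "data" axis only

def input_partition_spec(partition: "DataPartitionType") -> PartitionSpec:
    # Sketch only: choose how the leading (batch) dimension of host inputs is partitioned.
    if partition == DataPartitionType.DATA:
        return PartitionSpec("data")
    return PartitionSpec(("data", "fsdp", "model"))  # assumed full set of mesh axis names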
13 changes: 7 additions & 6 deletions axlearn/common/flash_attention/layer.py
@@ -252,12 +252,13 @@ def _compute_attention(
else PartitionSpec(None)
)

attention_logit_biases_spec = cfg.mha_dim_to_partition_spec["bnts"]
if attention_logit_biases is not None:
attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases)
attention_logit_biases = with_sharding_constraint(
attention_logit_biases, attention_logit_biases_spec
)
# attention_logit_biases_spec = cfg.mha_dim_to_partition_spec["bnts"]
attention_logit_biases_spec = PartitionSpec(None, None, None, None)
# if attention_logit_biases is not None:
# attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases)
# attention_logit_biases = with_sharding_constraint(
# attention_logit_biases, attention_logit_biases_spec
# )

# Scale query and key.
q_proj = self.scale_query(q_proj)