Support fine grained activation sharding. (#21)
patrick-toulme committed Dec 16, 2024
1 parent 73625c9 · commit a3b3024
Showing 3 changed files with 24 additions and 2 deletions.
axlearn/common/attention.py (2 changes: 1 addition & 1 deletion)
@@ -3314,7 +3314,7 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
if layer_cfg.cross_attention is not None:
set_attn_partition_specs(layer_cfg.cross_attention.attention)
if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
set_ffn_partition_specs(layer_cfg.feed_forward)
set_ffn_partition_specs(layer_cfg.feed_forward)
# pytype: enable=attribute-error


axlearn/common/layers.py (20 changes: 19 additions & 1 deletion)
@@ -56,6 +56,7 @@
Tensor,
partial_with_fn_metadata,
with_sharding_constraint,
maybe_shard
)

# TODO(dhwang2): remove them.
@@ -331,6 +332,10 @@ class Config(BaseNormalizationLayer.Config):
eps: float = 1e-8
# Cast input to this dtype for the 'forward' call. If None, do not cast.
forward_dtype: Optional[jnp.dtype] = jnp.float32
# If not None, how to partition input activation values.
input_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition output activation values.
output_partition_spec: Optional[tuple[Optional[str]]] = None

def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
cfg = self.config
@@ -341,13 +346,15 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor:
del paddings # paddings do not affect LayerNorm results
cfg = self.config
x = maybe_shard(x, cfg.input_partition_spec)
x_dtype = x.dtype
if cfg.forward_dtype is not None:
x = x.astype(cfg.forward_dtype)
moment2 = (x * x).mean(axis=-1, keepdims=True)
x = x * jax.lax.rsqrt(moment2 + cfg.eps)
x = x.astype(x_dtype)
x = x * self.parameters["scale"]
x = maybe_shard(x, cfg.output_partition_spec)
return x
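
For context, a minimal usage sketch of the two new norm-layer fields. The field names come from the diff above; the layer name RMSNorm and the mesh axis names "data"/"model" are assumptions for illustration, not part of the commit:

# Hypothetical sketch, not part of the commit: mesh axis names are assumptions.
# Constrains the [batch, seq, hidden] activations entering and leaving the norm.
norm_cfg = RMSNorm.default_config().set(
    input_partition_spec=("data", None, "model"),
    output_partition_spec=("data", None, "model"),
)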


@@ -780,6 +787,12 @@ class Config(BaseLayer.Config):

num_embeddings: Required[int] = REQUIRED # Maximum number of embeddings in table.
dim: Required[int] = REQUIRED # Embedding vector dimensionality.
# If not None, how to partition input activation values.
input_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition embedding table.
embedding_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition output activation values.
output_partition_spec: Optional[tuple[Optional[str]]] = None

@classmethod
def default_config(cls):
@@ -814,8 +827,13 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
)

def forward(self, x: Tensor) -> Tensor:
cfg = self.config
x = maybe_shard(x, cfg.input_partition_spec)
emb = self.parameters["weight"]
return emb[x]
emb = maybe_shard(emb, cfg.embedding_partition_spec)
activation = emb[x]
activation = maybe_shard(activation, cfg.output_partition_spec)
return activation

def attend(self, x: Tensor) -> Tensor:
"""Apply query array 'x' to the embedding weight array.
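A similar hedged sketch for the Embedding layer's new fields. The field names come from the diff above; the axis names, shape comments, and the vocab_size/model_dim placeholders are assumptions:

# Hypothetical sketch, not part of the commit: axis names are assumptions.
emb_cfg = Embedding.default_config().set(
    num_embeddings=vocab_size,                      # size of the [vocab, dim] table
    dim=model_dim,
    input_partition_spec=("data", None),            # token ids: [batch, seq]
    embedding_partition_spec=(None, "model"),       # table: [vocab, dim]
    output_partition_spec=("data", None, "model"),  # lookup result: [batch, seq, dim]
)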
axlearn/common/utils.py (4 changes: 4 additions & 0 deletions)
@@ -443,6 +443,10 @@ def with_sharding_constraint(x, shardings):
return x
return jax.lax.with_sharding_constraint(x, shardings)

def maybe_shard(x, partition_spec) -> Tensor:
if partition_spec is None:
return x
return with_sharding_constraint(x, PartitionSpec(*partition_spec))
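
A brief usage sketch of the new helper; the tensor shape and mesh axis names "data"/"model" are assumptions, and a None spec is a no-op, as the code above shows:

# Hypothetical sketch: x is a [batch, seq, hidden] activation on an assumed ("data", "model") mesh.
y = maybe_shard(x, ("data", None, "model"))  # with_sharding_constraint(x, PartitionSpec("data", None, "model"))
y = maybe_shard(x, None)                     # returns x unchanged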

def replicate_to_local_data(x: NestedTensor) -> NestedTensor:
"""Replicates and converts Tensors in `x` to local DeviceArrays.
