From a3b3024fe65afded1aaf21f658742df4e72fe203 Mon Sep 17 00:00:00 2001
From: Patrick Toulme <135739773+ptoulme-aws@users.noreply.github.com>
Date: Fri, 22 Nov 2024 12:18:12 -0800
Subject: [PATCH] Support fine grained activation sharding. (#21)
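
Adds optional partition-spec fields to the normalization and embedding layers
in layers.py (`input_partition_spec`, `output_partition_spec`, and, for the
embedding table, `embedding_partition_spec`), plus a `maybe_shard` helper in
utils.py that applies `with_sharding_constraint` only when a spec is provided.
Fields left as None keep the current behavior.

Rough usage sketch for the embedding layer, assuming the usual
`Embedding.default_config().set(...)` pattern; the mesh axis names "data" and
"model" and the sizes below are illustrative, not part of this change:

    emb_cfg = Embedding.default_config().set(
        num_embeddings=32_000,
        dim=4_096,
        # Token ids are [batch, seq_len]; constrain the batch dim to "data".
        input_partition_spec=("data", None),
        # The embedding table is [num_embeddings, dim]; shard dim over "model".
        embedding_partition_spec=(None, "model"),
        # Output activations are [batch, seq_len, dim].
        output_partition_spec=("data", None, "model"),
    )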

---
 axlearn/common/attention.py |  2 +-
 axlearn/common/layers.py    | 20 +++++++++++++++++++-
 axlearn/common/utils.py     |  7 +++++++
 3 files changed, 27 insertions(+), 2 deletions(-)
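
The new helper in utils.py is a thin wrapper: with this patch applied, a call
such as

    maybe_shard(x, ("data", None, "model"))  # axis names are illustrative

reduces to

    with_sharding_constraint(x, PartitionSpec("data", None, "model"))

while a None spec returns `x` unchanged, so unset config fields add no
sharding constraint.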

diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 26aceb797..15b3c2046 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -3314,7 +3314,7 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
         if layer_cfg.cross_attention is not None:
             set_attn_partition_specs(layer_cfg.cross_attention.attention)
         if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
-            set_ffn_partition_specs(layer_cfg.feed_forward)
+            set_ffn_partition_specs(layer_cfg.feed_forward)        
     # pytype: enable=attribute-error
 
 
diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py
index cc9798afd..1664bc2dc 100644
--- a/axlearn/common/layers.py
+++ b/axlearn/common/layers.py
@@ -56,6 +56,7 @@
     Tensor,
+    maybe_shard,
     partial_with_fn_metadata,
     with_sharding_constraint,
 )
 
 # TODO(dhwang2): remove them.
@@ -331,6 +332,10 @@ class Config(BaseNormalizationLayer.Config):
         eps: float = 1e-8
         # Cast input to this dtype for the 'forward' call. If None, do not cast.
         forward_dtype: Optional[jnp.dtype] = jnp.float32
+        # If not None, how to partition input activation values.
+        input_partition_spec: Optional[tuple[Optional[str], ...]] = None
+        # If not None, how to partition output activation values.
+        output_partition_spec: Optional[tuple[Optional[str], ...]] = None
 
     def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
         cfg = self.config
@@ -341,6 +346,7 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
     def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor:
         del paddings  # paddings do not affect LayerNorm results
         cfg = self.config
+        x = maybe_shard(x, cfg.input_partition_spec)
         x_dtype = x.dtype
         if cfg.forward_dtype is not None:
             x = x.astype(cfg.forward_dtype)
@@ -348,6 +354,7 @@ def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor:
         x = x * jax.lax.rsqrt(moment2 + cfg.eps)
         x = x.astype(x_dtype)
         x = x * self.parameters["scale"]
+        x = maybe_shard(x, cfg.output_partition_spec)
         return x
 
 
@@ -780,6 +787,12 @@ class Config(BaseLayer.Config):
 
         num_embeddings: Required[int] = REQUIRED  # Maximum number of embeddings in table.
         dim: Required[int] = REQUIRED  # Embedding vector dimensionality.
+        # If not None, how to partition input activation values.
+        input_partition_spec: Optional[tuple[Optional[str], ...]] = None
+        # If not None, how to partition the embedding table.
+        embedding_partition_spec: Optional[tuple[Optional[str], ...]] = None
+        # If not None, how to partition output activation values.
+        output_partition_spec: Optional[tuple[Optional[str], ...]] = None
 
     @classmethod
     def default_config(cls):
@@ -814,8 +827,13 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
         )
 
     def forward(self, x: Tensor) -> Tensor:
+        cfg = self.config
+        x = maybe_shard(x, cfg.input_partition_spec)
         emb = self.parameters["weight"]
-        return emb[x]
+        emb = maybe_shard(emb, cfg.embedding_partition_spec)
+        activation = emb[x]
+        activation = maybe_shard(activation, cfg.output_partition_spec)
+        return activation
 
     def attend(self, x: Tensor) -> Tensor:
         """Apply query array 'x' to the embedding weight array.
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 1401337f8..c7fcef7dd 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -443,6 +443,13 @@ def with_sharding_constraint(x, shardings):
         return x
     return jax.lax.with_sharding_constraint(x, shardings)
 
+
+def maybe_shard(x: Tensor, partition_spec: Optional[tuple[Optional[str], ...]]) -> Tensor:
+    """Applies a sharding constraint to `x` if `partition_spec` is not None."""
+    if partition_spec is None:
+        return x
+    return with_sharding_constraint(x, PartitionSpec(*partition_spec))
+
 
 def replicate_to_local_data(x: NestedTensor) -> NestedTensor:
     """Replicates and converts Tensors in `x` to local DeviceArrays.