diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 26aceb797..192b5e537 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -3315,6 +3315,7 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
             set_attn_partition_specs(layer_cfg.cross_attention.attention)
         if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
             set_ffn_partition_specs(layer_cfg.feed_forward)
+    # pytype: enable=attribute-error
diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py
index cc9798afd..1664bc2dc 100644
--- a/axlearn/common/layers.py
+++ b/axlearn/common/layers.py
@@ -56,6 +56,7 @@
     Tensor,
    partial_with_fn_metadata,
    with_sharding_constraint,
+    maybe_shard,
 )
 
 # TODO(dhwang2): remove them.
@@ -331,6 +332,10 @@ class Config(BaseNormalizationLayer.Config):
         eps: float = 1e-8
         # Cast input to this dtype for the 'forward' call. If None, do not cast.
         forward_dtype: Optional[jnp.dtype] = jnp.float32
+        # If not None, how to partition input activation values.
+        input_partition_spec: Optional[tuple[Optional[str], ...]] = None
+        # If not None, how to partition output activation values.
+        output_partition_spec: Optional[tuple[Optional[str], ...]] = None
 
     def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
         cfg = self.config
@@ -341,6 +346,7 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
     def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor:
         del paddings  # paddings do not affect LayerNorm results
         cfg = self.config
+        x = maybe_shard(x, cfg.input_partition_spec)
         x_dtype = x.dtype
         if cfg.forward_dtype is not None:
             x = x.astype(cfg.forward_dtype)
@@ -348,6 +354,7 @@ def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor:
         x = x * jax.lax.rsqrt(moment2 + cfg.eps)
         x = x.astype(x_dtype)
         x = x * self.parameters["scale"]
+        x = maybe_shard(x, cfg.output_partition_spec)
         return x
 
@@ -780,6 +787,12 @@ class Config(BaseLayer.Config):
         num_embeddings: Required[int] = REQUIRED  # Maximum number of embeddings in table.
         dim: Required[int] = REQUIRED  # Embedding vector dimensionality.
+        # If not None, how to partition input activation values.
+        input_partition_spec: Optional[tuple[Optional[str], ...]] = None
+        # If not None, how to partition the embedding table.
+        embedding_partition_spec: Optional[tuple[Optional[str], ...]] = None
+        # If not None, how to partition output activation values.
+        output_partition_spec: Optional[tuple[Optional[str], ...]] = None
 
     @classmethod
     def default_config(cls):
@@ -814,8 +827,13 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
         )
 
     def forward(self, x: Tensor) -> Tensor:
+        cfg = self.config
+        x = maybe_shard(x, cfg.input_partition_spec)
         emb = self.parameters["weight"]
-        return emb[x]
+        emb = maybe_shard(emb, cfg.embedding_partition_spec)
+        activation = emb[x]
+        activation = maybe_shard(activation, cfg.output_partition_spec)
+        return activation
 
     def attend(self, x: Tensor) -> Tensor:
         """Apply query array 'x' to the embedding weight array.
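A possible configuration sketch for the new Embedding fields above (not part of the patch): the class and field names come from the diff, while the mesh axis names ("data", "model"), the sizes, and the layer name are illustrative assumptions. Leaving any of these fields as None keeps the current behavior, since maybe_shard is then a pass-through.

    from axlearn.common.layers import Embedding

    emb_cfg = Embedding.default_config().set(
        name="emb",
        num_embeddings=32768,
        dim=1024,
        # Token ids have shape [batch, seq_len]: shard the batch dim over "data".
        input_partition_spec=("data", None),
        # The embedding table has shape [num_embeddings, dim]: shard dim over "model".
        embedding_partition_spec=(None, "model"),
        # Output activations have shape [batch, seq_len, dim].
        output_partition_spec=("data", None, "model"),
    )
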
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 1401337f8..c7fcef7dd 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -443,6 +443,13 @@ def with_sharding_constraint(x, shardings):
         return x
     return jax.lax.with_sharding_constraint(x, shardings)
 
+
+def maybe_shard(x: Tensor, partition_spec: Optional[tuple[Optional[str], ...]]) -> Tensor:
+    """Constrains the sharding of `x` when `partition_spec` is not None."""
+    if partition_spec is None:
+        return x
+    return with_sharding_constraint(x, PartitionSpec(*partition_spec))
+
 
 def replicate_to_local_data(x: NestedTensor) -> NestedTensor:
     """Replicates and converts Tensors in `x` to local DeviceArrays.
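For reference, a small usage sketch of the new maybe_shard helper under an active mesh (not part of the patch). The 1 x N mesh shape, the axis names ("data", "model"), and the tensor shape are illustrative assumptions; passing None as the spec skips the constraint entirely, and the with_sharding_constraint wrapper it calls has an early-return path for cases where the constraint does not apply.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh

    from axlearn.common.utils import maybe_shard

    # Build a 1 x N device mesh; the axis names are assumptions for this sketch.
    mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

    @jax.jit
    def f(x):
        # Constrain dim 0 of x to the "data" axis and dim 1 to the "model" axis.
        return maybe_shard(x, ("data", "model")) * 2.0

    with mesh:  # The constraint only takes effect while a mesh is active.
        y = f(jnp.ones((8, 16)))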