Maybe shard bug fixes
patrick-toulme committed Dec 5, 2024
1 parent 002e63e commit b027d07
Showing 3 changed files with 28 additions and 31 deletions.
39 changes: 27 additions & 12 deletions axlearn/common/attention.py
@@ -2683,15 +2683,15 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
     return atten_state, atten_output

 if cfg.structure == "prenorm":
-    # target = maybe_shard(target, *cfg.prenorm_partition_spec)
-    target = with_sharding_constraint(target, PartitionSpec('fsdp','model',None))
+    target = maybe_shard(target, cfg.prenorm_partition_spec)
+    #target = with_sharding_constraint(target, PartitionSpec('fsdp','model',None))
     skip_input = target  # pre-norm: where normalization happens within the residual part.
     norm_target = self.norm(target)
-    norm_target = with_sharding_constraint(norm_target, PartitionSpec('fsdp',None,None))
-    # norm_target = maybe_shard(norm_target, *cfg.preattention_partition_spec)
+    #norm_target = with_sharding_constraint(norm_target, PartitionSpec('fsdp',None,None))
+    norm_target = maybe_shard(norm_target, cfg.preattention_partition_spec)
     atten_state, atten_output = attention_thunk(norm_target)
-    atten_output = with_sharding_constraint(atten_output, PartitionSpec('fsdp','model',None))
-    # atten_output = maybe_shard(atten_output, *cfg.postattention_partition_spec)
+    #atten_output = with_sharding_constraint(atten_output, PartitionSpec('fsdp','model',None))
+    atten_output = maybe_shard(atten_output, cfg.postattention_partition_spec)
     data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
 elif cfg.structure == "postnorm":
     # This is the structure used by the original Transformer, BERT, and RoBERTa.
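The two prenorm hunks in this file swap the hard-coded with_sharding_constraint calls for the config-driven maybe_shard helper, so per-layer activation sharding can be configured or disabled rather than being pinned to the 'fsdp'/'model' axis names. A self-contained sketch of the pattern follows (illustrative only, not axlearn code: it assumes a toy "fsdp" x "model" mesh and a local stand-in for maybe_shard that mirrors the utils.py fix further down):

# Sketch only -- not axlearn code. Assumes a toy "fsdp" x "model" mesh and a
# local stand-in for maybe_shard matching the utils.py fix further down.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("fsdp", "model"))

def maybe_shard(x, partition_spec):
    # None disables the constraint; otherwise the tuple is splatted into a
    # PartitionSpec with one entry per tensor dimension.
    if partition_spec is None:
        return x
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, PartitionSpec(*partition_spec)))

@jax.jit
def prenorm_stub(target):  # target: [batch, seq, hidden]
    target = maybe_shard(target, ("fsdp", "model", None))  # constrain the block input
    return maybe_shard(target * 2.0, None)                 # unset spec: no-op

out = prenorm_stub(jnp.ones((devices.shape[0] * 2, 16, 32)))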
@@ -3011,18 +3011,18 @@ def _linear2(x):
 remat_pt1 = "activation"
 remat_pt2 = "linear2"
 if cfg.structure == "prenorm":
-    # inputs = maybe_shard(inputs, *cfg.prenorm_partition_spec)
-    x = with_sharding_constraint(inputs, PartitionSpec('fsdp','model',None))
+    inputs = maybe_shard(inputs, cfg.prenorm_partition_spec)
+    # x = with_sharding_constraint(inputs, PartitionSpec('fsdp','model',None))
     x = self.norm(inputs)
-    # x = maybe_shard(x, *cfg.premlp_partition_spec)
-    x = with_sharding_constraint(x, PartitionSpec('fsdp',None,None))
+    x = maybe_shard(x, cfg.premlp_partition_spec)
+    #x = with_sharding_constraint(x, PartitionSpec('fsdp',None,None))
     x = self._linear1_activation(x)
     x = self._remat_name(x, remat_pt1)
     x = self.dropout1(x)
     x = _linear2(x)
     x = self._remat_name(x, remat_pt2)
-    x = with_sharding_constraint(x, PartitionSpec('fsdp','model',None))
-    # x = maybe_shard(x, *cfg.postmlp_partition_spec)
+    #x = with_sharding_constraint(x, PartitionSpec('fsdp','model',None))
+    x = maybe_shard(x, cfg.postmlp_partition_spec)
     x = self.dropout2(x)
     x = self.stochastic_depth(x)
     if cfg.residual_weight != 1:
@@ -3570,6 +3570,21 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
         set_attn_partition_specs(layer_cfg.cross_attention.attention)
     if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
         set_ffn_partition_specs(layer_cfg.feed_forward)
+
+    # Neuron backend needs fine grained activation sharding.
+    if jax.default_backend() == 'neuron':
+        prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None)
+        preattention_partition_spec = (fsdp_axis_names, None, None)
+        postattention_partition_spec = (fsdp_axis_names, tp_axis_names, None)
+
+        layer_cfg.self_attention.set(
+            prenorm_partition_spec=prenorm_partition_spec,
+            preattention_partition_spec=preattention_partition_spec,
+            postattention_partition_spec=postattention_partition_spec)
+        layer_cfg.feed_forward.set(
+            prenorm_partition_spec=prenorm_partition_spec,
+            premlp_partition_spec=preattention_partition_spec,
+            postmlp_partition_spec=postattention_partition_spec)
 # pytype: enable=attribute-error
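With the default axis names shown in the removed fuji.py comments below (fsdp_axis_names="fsdp", tp_axis_names="model"), the Neuron branch above amounts to setting the specs explicitly on a layer config, roughly as sketched here. This is not part of the diff; it assumes TransformerLayer is importable from axlearn.common.attention (the file being edited) and that the specs cover [batch, seq, hidden] activations, with None meaning replicated:

# Rough equivalent of the Neuron branch with concrete axis names; not part of
# the diff. Assumes TransformerLayer comes from axlearn.common.attention.
from axlearn.common.attention import TransformerLayer

layer_cfg = TransformerLayer.default_config()
prenorm_spec = ("fsdp", "model", None)  # block input, before the norm
pre_spec = ("fsdp", None, None)         # norm output fed to attention / the MLP
post_spec = ("fsdp", "model", None)     # attention / MLP output
layer_cfg.self_attention.set(
    prenorm_partition_spec=prenorm_spec,
    preattention_partition_spec=pre_spec,
    postattention_partition_spec=post_spec,
)
layer_cfg.feed_forward.set(
    prenorm_partition_spec=prenorm_spec,
    premlp_partition_spec=pre_spec,
    postmlp_partition_spec=post_spec,
)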


3 changes: 1 addition & 2 deletions axlearn/common/utils.py
@@ -446,8 +446,7 @@ def with_sharding_constraint(x, shardings):
 def maybe_shard(x, partition_spec) -> Tensor:
     if partition_spec is None:
         return x
-    assert len(x.shape) == len(partition_spec)
-    return with_sharding_constraint(x, PartitionSpec(partition_spec))
+    return with_sharding_constraint(x, PartitionSpec(*partition_spec))

 def replicate_to_local_data(x: NestedTensor) -> NestedTensor:
     """Replicates and converts Tensors in `x` to local DeviceArrays.
17 changes: 0 additions & 17 deletions axlearn/experiments/text/gpt/fuji.py
@@ -513,23 +513,6 @@ def model_config(
     )
     atten_qkv_linear.rope_pos_emb_layer.theta = rope_theta

-    # batch_axis_names=("data", ("replica", "data", "fsdp"))
-    # fsdp_axis_names=("fsdp")
-    # tp_axis_names=("model")
-    # seq_axis_names=("seq",)
-    # prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None)
-    # preattention_partition_spec = (fsdp_axis_names, None, None)
-    # postattention_partition_spec = (fsdp_axis_names, tp_axis_names, None)
-    # layer_cfg=TransformerLayer.default_config()
-    # layer_cfg.self_attention.set(
-    # prenorm_partition_spec=prenorm_partition_spec,
-    # preattention_partition_spec=preattention_partition_spec,
-    # postattention_partition_spec=postattention_partition_spec)
-    # layer_cfg.feed_forward.set(
-    # prenorm_partition_spec=prenorm_partition_spec,
-    # premlp_partition_spec=preattention_partition_spec,
-    # postmlp_partition_spec=postattention_partition_spec),
-
     cfg = common_model_config(
         num_layers=num_layers,
         hidden_dim=hidden_dim,
