Skip to content

Commit

Permalink
use default policy
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin committed Dec 13, 2024
1 parent 68c6ee9 commit 9e130a8
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4190,7 +4190,7 @@ def build_remat_spec(
)
else:
checkpoints.extend(
[f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ["input_to_qkvee", "TransformerAttentionLayer.residual_add", "TransformerFeedForwardLayer.mlp_residual"]
[f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ["TransformerAttentionLayer.residual_add", "TransformerFeedForwardLayer.mlp_residual"]
)
if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
Expand All @@ -4199,14 +4199,9 @@ def build_remat_spec(
else:
checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0", "linear1_1"]])

if backend != "neuron":
policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
)
else:
policy = config_for_function(save_only_these).set(
names_to_save=checkpoints
)
policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
)

if offload_dst:
policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set(
Expand Down

0 comments on commit 9e130a8

Please sign in to comment.