Add "selected_ops" transformer AC mode (#71)
Showing 4 changed files with 54 additions and 5 deletions.
@@ -0,0 +1,38 @@
from collections import defaultdict
from typing import Dict

import torch


def _get_custom_checkpoint_policy(meta: Dict[str, int]):
    # Adapted from
    # https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
    from torch.utils.checkpoint import CheckpointPolicy

    _save_list = {
        torch.ops.aten.mm.default,  # type: ignore
        torch.ops.aten._scaled_dot_product_efficient_attention.default,  # type: ignore
        torch.ops.aten._scaled_dot_product_flash_attention.default,  # type: ignore
        torch.ops._c10d_functional.reduce_scatter_tensor.default,  # type: ignore
    }

    def _custom_policy(ctx, func, *args, **kwargs):
        del args, kwargs
        mode = "recompute" if ctx.is_recompute else "forward"
        mm_count_key = f"{mode}_mm_count"
        if func == torch.ops.aten.mm.default:  # type: ignore
            meta[mm_count_key] += 1
        # Saves output of all compute ops, except every second mm
        to_save = func in _save_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0  # type: ignore
        )
        return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy


def selective_checkpointing_context_fn():
    from torch.utils.checkpoint import create_selective_checkpoint_contexts

    meta: Dict[str, int] = defaultdict(int)
    return create_selective_checkpoint_contexts(_get_custom_checkpoint_policy(meta))
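
For context, a minimal usage sketch (not part of the commit) of how a context function like selective_checkpointing_context_fn plugs into PyTorch's non-reentrant activation checkpointing. The toy block and shapes are hypothetical, and it assumes a PyTorch version where create_selective_checkpoint_contexts and the context_fn argument are available (roughly 2.4+).

# Usage sketch under the assumptions stated above; the module below is a
# stand-in for a real transformer block, not code from this repository.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
x = torch.randn(8, 512, requires_grad=True)

# context_fn is only honored by the non-reentrant checkpoint implementation.
# The policy above saves attention and reduce-scatter outputs plus every
# other mm, and recomputes everything else during backward.
out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=selective_checkpointing_context_fn,
)
out.sum().backward()

Because checkpoint calls context_fn once per checkpointed region, each region gets a fresh meta counter, so the "every second mm" bookkeeping does not leak across layers.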