[activation checkpointing] Add default autocast keys to functional rng wrappers (pytorch#107934)

Pull Request resolved: pytorch#107934
Approved by: https://github.com/xw285cornell
anijain2305 authored and pytorchmergebot committed Aug 25, 2023
1 parent 3992450 commit 78a053b
Showing 2 changed files with 28 additions and 0 deletions.
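For context, a minimal sketch (not part of this PR) of the user-facing pattern the fix targets: an activation-checkpointed region invoked inside an autocast context and compiled with torch.compile. It assumes a CUDA-capable build; the attention helper, shapes, and names here are illustrative stand-ins for the new test below.

import torch


def attention(q, k, v):
    # Stand-in for the scaled-dot-product-attention call exercised by the new test.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


def checkpointed(*args):
    # Recompute-in-backward wrapper around the attention call.
    return torch.utils.checkpoint.checkpoint(attention, *args, use_reentrant=False)


opt = torch.compile(checkpointed)

with torch.autocast("cuda"):
    q, k, v = (
        torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True) for _ in range(3)
    )
    out = opt(q, k, v)
    out.sum().backward()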
24 changes: 24 additions & 0 deletions test/dynamo/test_activation_checkpointing.py
@@ -361,6 +361,30 @@ def fn(x, y):
        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        self.assertEqual(len(wrap_node.args), 3)

    @requires_cuda()
    def test_autocast_flash_attention(self):
        def fn(primals_1, primals_2, primals_3):
            return torch.ops.aten._scaled_dot_product_efficient_attention.default(
                primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
            )[0]

        def gn(*args):
            return torch.utils.checkpoint.checkpoint(fn, *args)

        with torch.cuda.amp.autocast():
            x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
            y = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
            z = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
            args = (x, y, z)

            torch.manual_seed(0)
            ref = gn(*args)

            opt_gn = torch.compile(gn)
            torch.manual_seed(0)
            res = opt_gn(*args)
            self.assertEqual(ref, res)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
4 changes: 4 additions & 0 deletions torch/_prims/rng_prims.py
@@ -166,6 +166,8 @@ def register_run_and_save_rng_state_op():
    run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")

    run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
    run_and_save_rng_state.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
    run_and_save_rng_state.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
    run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
    run_and_save_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]

@@ -219,6 +221,8 @@ def register_run_with_rng_state_op():
    run_with_rng_state = HigherOrderOperator("run_with_rng_state")

    run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
    run_with_rng_state.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
    run_with_rng_state.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
    run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
    run_with_rng_state.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]

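In both wrappers, registering a fallthrough for the Autocast keys tells the dispatcher to skip those keys and keep walking the key order instead of looking for an Autocast kernel that was never registered, so the RNG wrappers used by the functionalized-RNG checkpointing path can be dispatched while an autocast context is active. An illustrative sketch of the same pattern on a hypothetical operator (my_wrapper is not a real PyTorch op; DispatchKey, HigherOrderOperator, and fallthrough are the real symbols used above):

from torch._C import DispatchKey
from torch._ops import HigherOrderOperator

# Hypothetical wrapper, standing in for run_and_save_rng_state / run_with_rng_state.
my_wrapper = HigherOrderOperator("my_wrapper")

# Fallthrough: when dispatch reaches these keys, skip them and continue down the
# dispatch order rather than requiring an autocast-specific kernel.
my_wrapper.fallthrough(DispatchKey.AutocastCPU)
my_wrapper.fallthrough(DispatchKey.AutocastCUDA)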
