Add initial support for torch.utils.checkpoint (#1127)
A checkpointed function doesn't save any intermediates from the forward pass for the backward pass. Instead, all required values are recomputed during the backward pass. Because fewer intermediates are saved, peak memory usage usually decreases.
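As a rough illustration of that mechanism, here is a minimal sketch built on a custom `torch.autograd.Function` (an illustration only, not how PyTorch or Thunder actually implement checkpointing): only the input is saved, and the forward computation is rerun when gradients are requested.

```py
import torch

class NaiveCheckpoint(torch.autograd.Function):
    # Illustration only: save just the input; intermediates of `fn` are
    # discarded after forward and recomputed on demand in backward.
    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        with torch.no_grad():
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_req = x.detach().requires_grad_()
            out = ctx.fn(x_req)  # recompute the forward pass
        (grad_in,) = torch.autograd.grad(out, x_req, grad_out)
        return None, grad_in

x = torch.randn(3, 4, requires_grad=True)
y = NaiveCheckpoint.apply(lambda t: t.sin().cos().exp(), x)
y.backward(torch.ones_like(y))
```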

This PR introduces support for recognizing `torch.utils.checkpoint.checkpoint` calls and inserting a new bound symbol into the initial trace. In the forward-backward generation pass this bound symbol is then converted into the augmented-forward and backward parts of the computation. This step requires the function argument to `thunder.torch.checkpoint` to be a Thunder function. Currently there is no PyTorch-to-Thunder conversion implemented, so this works only for simple functions that are recognized by both Thunder and PyTorch, for example when only tensor methods are used.

The PyTorch function needs to be converted to a Thunder function inside Thunder's JIT. Previously we could simply use `thunder.preprocess`, but it is no longer available. When I attempted to implement redispatching/reinterpretation of PyTorch functions using `general_thunder_jit`, I hit the following bug: #1126.

Example:
```py
import thunder
import torch

def f(x):
    return torch.utils.checkpoint.checkpoint(lambda x: x.sin().cos().exp(), x)

jf = thunder.jit(f)
x = torch.randn(3, 4, device="cuda", requires_grad=True)
jf(x).backward(x)
print(thunder.last_traces(jf)[-1])
print(thunder.last_backward_traces(jf)[-1])
```
Forward execution trace:
```py
def augmented_forward_fn(x):
  # x: "cuda:0 f32[3, 4]"
  [t2] = nvFusion0(x)
    # t0 = prims.sin(x)  # t0: "cuda:0 f32[3, 4]"
    # t1 = prims.cos(t0)  # t1: "cuda:0 f32[3, 4]"
    # t2 = prims.exp(t1)  # t2: "cuda:0 f32[3, 4]"
  return {'output': t2, 'flat_args': [x], 'flat_output': (t2,)}, ((x,), ())
```
Backward execution trace:
```py
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t3, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  x, = C0
  clear_mutable_collection(C0)
  del C0
  [t12] = nvFusion0(x, t3)
    # t4 = prims.sin(x)  # t4: "cuda:0 f32[3, 4]"
    # t11 = prims.cos(x)  # t11: "cuda:0 f32[3, 4]"
    # t5 = prims.cos(t4)  # t5: "cuda:0 f32[3, 4]"
    # t8 = prims.sin(t4)  # t8: "cuda:0 f32[3, 4]"
    # t6 = prims.exp(t5)  # t6: "cuda:0 f32[3, 4]"
    # t7 = prims.mul(t3, t6)  # t7: "cuda:0 f32[3, 4]"
    # t9 = prims.neg(t8)  # t9: "cuda:0 f32[3, 4]"
    # t10 = prims.mul(t7, t9)  # t10: "cuda:0 f32[3, 4]"
    # t12 = prims.mul(t10, t11)  # t12: "cuda:0 f32[3, 4]"
  del x, t3
  return (t12,)
```
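Reusing `jf` and `x` from the example above, the memory effect can be checked directly: only the original input should appear among the tensors saved for backward, mirroring the assertion in the new test added by this PR.

```py
out = jf(x)
# Only the input is saved; the sin/cos intermediates are recomputed in backward.
assert len(out.grad_fn.saved_tensors) == 1
assert out.grad_fn.saved_tensors[0] is x
```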
IvanYashchuk authored Oct 18, 2024
1 parent 953a914 commit 3f3d46a
Showing 4 changed files with 140 additions and 0 deletions.
36 changes: 36 additions & 0 deletions thunder/core/jit_ext.py
@@ -52,6 +52,7 @@
)

import torch
import torch.utils.checkpoint
from thunder.core.proxies import (
    DistParallelType,
    proxy,
@@ -763,6 +764,41 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
    return res


@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint)
def _general_jit_torch_checkpoint_lookaside(
    function: Callable,
    *args,
    **kwargs: Any,
):
    """
    This function does preprocessing of the `function` argument before
    dispatching the call to `thunder.torch.checkpoint`. This is necessary
    because the `function` is potentially calling into PyTorch functions that
    are not yet translated to Thunder. `thunder.torch.checkpoint` is a Thunder
    function that can handle only Thunder functions as input.
    Args:
        function: The function to be checkpointed.
        args: Arguments to the function.
        kwargs: Keyword arguments to the function.
    Returns:
        The result of calling `thunder.torch.checkpoint` with the preprocessed
        `function` and its arguments.
    """
    from thunder.torch import checkpoint

    # It should be possible to call the general_thunder_jit here to handle the
    # conversion from torch to thunder but it doesn't work now
    # See https://github.com/Lightning-AI/lightning-thunder/issues/1126
    # TODO: Convert the function to a Thunder function
    def thunder_function(*args, **kwargs):
        return unwrap(function)(*args, **kwargs)

    wrapped_thunder_function = wrap_const(thunder_function)
    return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
1 change: 1 addition & 0 deletions thunder/core/pytree.py
@@ -7,6 +7,7 @@
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.baseutils import ProxyInterface
from types import FunctionType

OPTREE_NAMESPACE = "thunder"

36 changes: 36 additions & 0 deletions thunder/tests/test_grad.py
@@ -1700,6 +1700,42 @@ def func(a, b):
get_saved_for_backward_tensors(execution_trace)


def test_torch_checkpoint():
    import torch.utils.checkpoint
    import torch._higher_order_ops.wrap

    def fn_to_checkpoint(x):
        return x.sin().cos().exp()

    checkpoint_fns = (
        thunder.torch.checkpoint,
        partial(torch.utils.checkpoint.checkpoint, use_reentrant=False),
        torch.ops.higher_order.tag_activation_checkpoint,
    )

    for checkpoint_fn in checkpoint_fns:

        def f(x):
            return checkpoint_fn(fn_to_checkpoint, x)

        x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
        jf = thunder.jit(f)
        out = jf(x)

        # With activation checkpointing, we are saving only the original input.
        # The intermediate values are recomputed during the backward pass.
        assert len(out.grad_fn.saved_tensors) == 1
        assert out.grad_fn.saved_tensors[0] is x

        g = torch.ones_like(out)
        out.backward(g)

        x_ref = x.detach().requires_grad_()
        out_ref = fn_to_checkpoint(x_ref)
        out_ref.backward(g)
        torch.testing.assert_close(x.grad, x_ref.grad)


def test_inconsistent_output_length_grad_transform():
    from thunder.extend import OperatorExecutor
    from thunder.core.proxies import AnyProxy, TensorProxy
67 changes: 67 additions & 0 deletions thunder/torch/__init__.py
@@ -61,6 +61,8 @@

# NOTE torch is a requirement
import torch
import torch.utils.checkpoint
import torch._higher_order_ops.wrap

import warnings

@@ -5199,6 +5201,71 @@ def _unwrap_if_dead(tensor):
register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)


@torchsymbol(
    torch.utils.checkpoint.checkpoint,
    torch.ops.higher_order.tag_activation_checkpoint,
    id="activation_checkpoint",
)
def checkpoint(
    function: Callable[..., TensorLike],
    *args: TensorLike,
    context_fn: None | Callable[..., Any] = None,
    debug: None | bool = None,
    determinism_check: None | str = None,
    preserve_rng_state: None | bool = None,
    use_reentrant: bool = False,
    **kwargs: Any,
) -> TensorLike:
    utils.check(
        not use_reentrant,
        lambda: "torch.checkpoint: use_reentrant=True is not supported in Thunder",
    )
    # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
    # Let's raise a warning if any of these arguments are passed
    if context_fn is not None:
        warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
    if debug is not None:
        warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
    if determinism_check is not None:
        warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
    if preserve_rng_state is not None:
        warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")
    return function(*args, **kwargs)


@register_augmented_forward(
    "activation_checkpoint",
)
def _augmented_forward_checkpoint(
    function: Callable[..., TensorLike],
    *args: TensorLike,
    context_fn: None | Callable[..., Any] = None,
    debug: None | bool = None,
    determinism_check: None | str = None,
    preserve_rng_state: None | bool = None,
    use_reentrant: bool = False,
    **kwargs: Any,
) -> TensorLike:
    result = function(*args, **kwargs)
    saved_for_backward = (function, args, kwargs)
    return result, saved_for_backward


@register_backward(
    "activation_checkpoint",
)
def _backward_checkpoint(
    function,
    args,
    kwargs,
    *grad_outputs,
) -> tuple[None | TensorLike, ...]:
    from thunder.core.transforms import vjp

    result = vjp(function)(args, grad_outputs, **kwargs)
    return result


#
# Distributed operations
#
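A note on the backward rule registered above: `thunder.core.transforms.vjp` re-derives the gradients of the checkpointed region by re-tracing `function` and applying its vector-Jacobian product to the incoming cotangents, instead of reading saved intermediates. A hedged eager-mode analogy using `torch.func.vjp` (an analogy only, not what Thunder calls internally):

```py
import torch

def fn_to_checkpoint(x):
    return x.sin().cos().exp()

x = torch.randn(3, 4)
cotangent = torch.ones(3, 4)

# Re-run the function and build its vector-Jacobian product, then apply it
# to the cotangent arriving from downstream of the checkpointed region.
out, vjp_fn = torch.func.vjp(fn_to_checkpoint, x)
(grad_x,) = vjp_fn(cotangent)
```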
