diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 5a6efa434e..1a3e6daf23 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -52,6 +52,7 @@
 )
 
 import torch
+import torch.utils.checkpoint
 from thunder.core.proxies import (
     DistParallelType,
     proxy,
@@ -763,6 +764,41 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
     return res
 
 
+@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint)
+def _general_jit_torch_checkpoint_lookaside(
+    function: Callable,
+    *args,
+    **kwargs: Any,
+):
+    """
+    This function does preprocessing of the `function` argument before
+    dispatching the call to `thunder.torch.checkpoint`. This is necessary
+    because the `function` is potentially calling into PyTorch functions that
+    are not yet translated to Thunder. `thunder.torch.checkpoint` is a Thunder
+    function that can handle only Thunder functions as input.
+
+    Args:
+        function: The function to be checkpointed.
+        args: Arguments to the function.
+        kwargs: Keyword arguments to the function.
+
+    Returns:
+        The result of calling `thunder.torch.checkpoint` with the preprocessed
+        `function` and its arguments.
+    """
+    from thunder.torch import checkpoint
+
+    # It should be possible to call the general_thunder_jit here to handle the
+    # conversion from torch to thunder, but it doesn't work now
+    # See https://github.com/Lightning-AI/lightning-thunder/issues/1126
+    # TODO: Convert the function to a Thunder function
+    def thunder_function(*args, **kwargs):
+        return unwrap(function)(*args, **kwargs)
+
+    wrapped_thunder_function = wrap_const(thunder_function)
+    return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
+
+
 # Adds proxy methods
 # NOTE These methods map to themselves, which prevents the interpreter from looking into them
 # This is OK because these methods are written in a tracing-safe manner, and trying to
diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py
index 97248c75bf..6a3d3d3a76 100644
--- a/thunder/core/pytree.py
+++ b/thunder/core/pytree.py
@@ -7,6 +7,7 @@
 import thunder.core.dtypes as dtypes
 import thunder.core.devices as devices
 from thunder.core.baseutils import ProxyInterface
+from types import FunctionType
 
 OPTREE_NAMESPACE = "thunder"
 
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 3f8fe0c50b..33b4e4df2c 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1700,6 +1700,42 @@ def func(a, b):
     get_saved_for_backward_tensors(execution_trace)
 
 
+def test_torch_checkpoint():
+    import torch.utils.checkpoint
+    import torch._higher_order_ops.wrap
+
+    def fn_to_checkpoint(x):
+        return x.sin().cos().exp()
+
+    checkpoint_fns = (
+        thunder.torch.checkpoint,
+        partial(torch.utils.checkpoint.checkpoint, use_reentrant=False),
+        torch.ops.higher_order.tag_activation_checkpoint,
+    )
+
+    for checkpoint_fn in checkpoint_fns:
+
+        def f(x):
+            return checkpoint_fn(fn_to_checkpoint, x)
+
+        x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
+        jf = thunder.jit(f)
+        out = jf(x)
+
+        # With activation checkpointing, we are saving only the original input.
+        # The intermediate values are recomputed during the backward pass.
+        assert len(out.grad_fn.saved_tensors) == 1
+        assert out.grad_fn.saved_tensors[0] is x
+
+        g = torch.ones_like(out)
+        out.backward(g)
+
+        x_ref = x.detach().requires_grad_()
+        out_ref = fn_to_checkpoint(x_ref)
+        out_ref.backward(g)
+        torch.testing.assert_close(x.grad, x_ref.grad)
+
+
 def test_inconsistent_output_length_grad_transform():
     from thunder.extend import OperatorExecutor
     from thunder.core.proxies import AnyProxy, TensorProxy
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 2e92c6bbce..ed69b4f096 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -61,6 +61,8 @@
 
 # NOTE torch is a requirement
 import torch
+import torch.utils.checkpoint
+import torch._higher_order_ops.wrap
 
 import warnings
 
@@ -5199,6 +5201,71 @@ def _unwrap_if_dead(tensor):
     register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
 
 
+@torchsymbol(
+    torch.utils.checkpoint.checkpoint,
+    torch.ops.higher_order.tag_activation_checkpoint,
+    id="activation_checkpoint",
+)
+def checkpoint(
+    function: Callable[..., TensorLike],
+    *args: TensorLike,
+    context_fn: None | Callable[..., Any] = None,
+    debug: None | bool = None,
+    determinism_check: None | str = None,
+    preserve_rng_state: None | bool = None,
+    use_reentrant: bool = False,
+    **kwargs: Any,
+) -> TensorLike:
+    utils.check(
+        not use_reentrant,
+        lambda: "torch.checkpoint: use_reentrant=True is not supported in Thunder",
+    )
+    # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
+    # Let's raise a warning if any of these arguments are passed
+    if context_fn is not None:
+        warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
+    if debug is not None:
+        warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
+    if determinism_check is not None:
+        warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
+    if preserve_rng_state is not None:
+        warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")
+    return function(*args, **kwargs)
+
+
+@register_augmented_forward(
+    "activation_checkpoint",
+)
+def _augmented_forward_checkpoint(
+    function: Callable[..., TensorLike],
+    *args: TensorLike,
+    context_fn: None | Callable[..., Any] = None,
+    debug: None | bool = None,
+    determinism_check: None | str = None,
+    preserve_rng_state: None | bool = None,
+    use_reentrant: bool = False,
+    **kwargs: Any,
+) -> TensorLike:
+    result = function(*args, **kwargs)
+    saved_for_backward = (function, args, kwargs)
+    return result, saved_for_backward
+
+
+@register_backward(
+    "activation_checkpoint",
+)
+def _backward_checkpoint(
+    function,
+    args,
+    kwargs,
+    *grad_outputs,
+) -> tuple[None | TensorLike, ...]:
+    from thunder.core.transforms import vjp
+
+    result = vjp(function)(args, grad_outputs, **kwargs)
+    return result
+
+
 #
 # Distributed operations
 #
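Not part of the patch: below is a minimal usage sketch of the activation checkpointing support added above, mirroring the new test in thunder/tests/test_grad.py. The tensor shape and the body of fn_to_checkpoint are illustrative assumptions.

# Sketch only: a torch.utils.checkpoint.checkpoint call inside a thunder.jit-ed
# function is intercepted by the new lookaside and dispatched to
# thunder.torch.checkpoint.
import torch
import torch.utils.checkpoint
import thunder


def fn_to_checkpoint(x):
    # The sin/cos intermediates are recomputed in the backward pass rather than saved.
    return x.sin().cos().exp()


def f(x):
    return torch.utils.checkpoint.checkpoint(fn_to_checkpoint, x, use_reentrant=False)


x = torch.randn(2, 2, requires_grad=True)  # arbitrary shape for illustration
jf = thunder.jit(f)
out = jf(x)

# Only the original input is saved for backward; intermediates are recomputed.
assert len(out.grad_fn.saved_tensors) == 1
out.backward(torch.ones_like(out))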