Fix thunder.torch.checkpoint to support multiple arguments (#1391)
IvanYashchuk authored Nov 1, 2024
1 parent a24e86e commit c5f8bf7
Showing 2 changed files with 13 additions and 9 deletions.
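For orientation, here is a minimal usage sketch of what this commit enables, assembled from the updated test below (the function body, shapes, and jit call are taken from the test; the `torch.randn` inputs stand in for the test's `make_tensor`):

```python
import torch
import thunder

def fn_to_checkpoint(x, y):
    # The checkpointed region: its intermediates are recomputed in backward.
    return x.sin().cos().exp().mul(y)

def f(x, y):
    # thunder.torch.checkpoint now forwards multiple tensor arguments.
    return thunder.torch.checkpoint(fn_to_checkpoint, x, y)

jf = thunder.jit(f)
x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)
out = jf(x, y)
out.backward(torch.ones_like(out))  # gradients reach both x and y
```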
thunder/tests/test_grad.py (11 additions, 7 deletions)
```diff
@@ -1704,8 +1704,8 @@ def test_torch_checkpoint():
     import torch.utils.checkpoint
     import torch._higher_order_ops.wrap
 
-    def fn_to_checkpoint(x):
-        return x.sin().cos().exp()
+    def fn_to_checkpoint(x, y):
+        return x.sin().cos().exp().mul(y)
 
     checkpoint_fns = (
         thunder.torch.checkpoint,
@@ -1715,26 +1715,30 @@ def fn_to_checkpoint(x):
 
     for checkpoint_fn in checkpoint_fns:
 
-        def f(x):
-            return checkpoint_fn(fn_to_checkpoint, x)
+        def f(x, y):
+            return checkpoint_fn(fn_to_checkpoint, x, y)
 
         x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
+        y = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
         jf = thunder.jit(f)
-        out = jf(x)
+        out = jf(x, y)
 
         # With activation checkpointing, we are saving only the original input.
         # The intermediate values are recomputed during backward pass.
-        assert len(out.grad_fn.saved_tensors) == 1
+        assert len(out.grad_fn.saved_tensors) == 2
         # We detach the saved tensors (which returns a new Python tensor backed by same storage)
         assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
+        assert out.grad_fn.saved_tensors[1].data_ptr() == y.data_ptr()
 
         g = torch.ones_like(out)
         out.backward(g)
 
         x_ref = x.detach().requires_grad_()
-        out_ref = fn_to_checkpoint(x_ref)
+        y_ref = y.detach().requires_grad_()
+        out_ref = fn_to_checkpoint(x_ref, y_ref)
         out_ref.backward(g)
         torch.testing.assert_close(x.grad, x_ref.grad)
+        torch.testing.assert_close(y.grad, y_ref.grad)
 
 
 def test_inconsistent_output_length_grad_transform():
```
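The `saved_tensors` assertions capture the defining trade-off of activation checkpointing: only the inputs to the checkpointed region are stored, and the `sin`/`cos`/`exp` intermediates are recomputed during backward. The same behavior can be sketched with stock PyTorch (`torch.utils.checkpoint.checkpoint` is among the `checkpoint_fns` the test iterates over; `use_reentrant=False` is the mode current PyTorch recommends, not something the test asserts):

```python
import torch
import torch.utils.checkpoint

def fn_to_checkpoint(x, y):
    return x.sin().cos().exp().mul(y)

x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)

# Forward stores only x and y; intermediates are rebuilt when backward runs.
out = torch.utils.checkpoint.checkpoint(fn_to_checkpoint, x, y, use_reentrant=False)
out.backward(torch.ones_like(out))
print(x.grad.shape, y.grad.shape)  # both (2, 2): gradients flow to every input
```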
thunder/torch/__init__.py (2 additions, 2 deletions)
```diff
@@ -5301,8 +5301,8 @@ def _backward_checkpoint(
 ) -> tuple[None | TensorLike, ...]:
     from thunder.core.transforms import vjp
 
-    result = vjp(function)(args, grad_outputs, **kwargs)
-    return result
+    _, grads = vjp(function)(args, grad_outputs, **kwargs)
+    return grads
 
 
 #
```
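The bug was that `vjp(function)(args, grad_outputs, **kwargs)` returns a pair of forward results and input gradients, and the old code returned the whole pair; the fix discards the forward results and returns only the gradients, one per input. `torch.func.vjp` follows the same output-plus-gradients convention, split across two calls; shown here purely as an analogy, not thunder's internal API:

```python
import torch
from torch.func import vjp

def fn(x, y):
    return x.sin().cos().exp().mul(y)

x, y = torch.randn(2, 2), torch.randn(2, 2)
out, vjp_fn = vjp(fn, x, y)                    # forward: (output, pullback)
grad_x, grad_y = vjp_fn(torch.ones_like(out))  # backward: one grad per input
```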
