Add initial support for torch.utils.checkpoint (#1127)
A checkpointed function doesn't save any intermediates from forward to backward. Instead, all required values are recomputed during the backward pass. Because fewer intermediates are saved, peak memory usage usually decreases.

This PR adds support for recognizing `torch.utils.checkpoint.checkpoint` calls and inserting a new bound symbol into the initial trace. In the forward-backward generation pass, this bound symbol is then converted into the augmented forward and backward parts of the computation. This step requires the function argument to `thunder.torch.checkpoint` to be a Thunder function. Currently, no PyTorch->Thunder conversion is implemented, so this works only for simple functions that are recognized by both Thunder and PyTorch, for example when only methods are used. The PyTorch function needs to be converted to a Thunder function in Thunder's JIT. Previously this could be done with `thunder.preprocess`, which is not available today. When I attempted to implement redispatching/reinterpretation of PyTorch functions using `general_thunder_jit`, I hit the following bug: #1126.
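To illustrate the trade-off described above, here is a minimal pure-Python sketch (no PyTorch; `run_with_checkpoint`, `f`, and `df` are hypothetical names invented for this example) of the idea behind checkpointing: the forward pass keeps only the inputs, and the backward pass re-runs the function before differentiating.

```python
import math

calls = {"fwd": 0}

def f(x):
    # Same computation as the example below: exp(cos(sin(x)))
    calls["fwd"] += 1
    return math.exp(math.cos(math.sin(x)))

def df(x):
    # Analytic derivative of exp(cos(sin(x))) via the chain rule.
    return math.exp(math.cos(math.sin(x))) * (-math.sin(math.sin(x))) * math.cos(x)

def run_with_checkpoint(fn, dfn, x):
    out = fn(x)   # forward pass: intermediates are discarded afterwards
    saved = (x,)  # checkpointing saves only the inputs

    def backward(cotangent):
        (x_saved,) = saved
        fn(x_saved)  # recompute the forward from the saved input
        return cotangent * dfn(x_saved)

    return out, backward

out, backward = run_with_checkpoint(f, df, 0.5)
grad = backward(1.0)
print(calls["fwd"])  # 2: once in forward, once recomputed in backward
```

The function runs twice (forward plus recomputation), which is exactly the memory-for-compute trade that checkpointing makes.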
Example:

```py
import thunder
import torch

def f(x):
    return torch.utils.checkpoint.checkpoint(lambda x: x.sin().cos().exp(), x)

jf = thunder.jit(f)
x = torch.randn(3, 4, device="cuda", requires_grad=True)
jf(x).backward(x)
print(thunder.last_traces(jf)[-1])
print(thunder.last_backward_traces(jf)[-1])
```

Forward execution trace:

```py
def augmented_forward_fn(x):
  # x: "cuda:0 f32[3, 4]"
  [t2] = nvFusion0(x)
    # t0 = prims.sin(x)  # t0: "cuda:0 f32[3, 4]"
    # t1 = prims.cos(t0)  # t1: "cuda:0 f32[3, 4]"
    # t2 = prims.exp(t1)  # t2: "cuda:0 f32[3, 4]"
  return {'output': t2, 'flat_args': [x], 'flat_output': (t2,)}, ((x,), ())
```

Backward execution trace:

```py
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t3, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  x, = C0
  clear_mutable_collection(C0)
  del C0
  [t12] = nvFusion0(x, t3)
    # t4 = prims.sin(x)  # t4: "cuda:0 f32[3, 4]"
    # t11 = prims.cos(x)  # t11: "cuda:0 f32[3, 4]"
    # t5 = prims.cos(t4)  # t5: "cuda:0 f32[3, 4]"
    # t8 = prims.sin(t4)  # t8: "cuda:0 f32[3, 4]"
    # t6 = prims.exp(t5)  # t6: "cuda:0 f32[3, 4]"
    # t7 = prims.mul(t3, t6)  # t7: "cuda:0 f32[3, 4]"
    # t9 = prims.neg(t8)  # t9: "cuda:0 f32[3, 4]"
    # t10 = prims.mul(t7, t9)  # t10: "cuda:0 f32[3, 4]"
    # t12 = prims.mul(t10, t11)  # t12: "cuda:0 f32[3, 4]"
  del x, t3
  return (t12,)
```
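Note that the backward trace saves only `x` and recomputes `sin`, `cos`, and `exp` from it, then applies the chain rule for `d/dx exp(cos(sin(x)))`. As a rough sanity check, the fused region can be replayed on plain Python floats (a scalar stand-in I wrote for this note, with names copied from the trace; not Thunder code):

```python
import math

def backward_fn_scalar(x, t3):
    # Scalar re-implementation of the nvFusion0 region in the backward trace.
    t4 = math.sin(x)    # recomputed forward intermediate
    t11 = math.cos(x)   # d sin(x)/dx
    t5 = math.cos(t4)   # recomputed forward intermediate
    t8 = math.sin(t4)   # for d cos(t4)/dt4 = -sin(t4)
    t6 = math.exp(t5)   # d exp(t5)/dt5 = exp(t5)
    t7 = t3 * t6
    t9 = -t8
    t10 = t7 * t9
    t12 = t10 * t11     # cotangent * exp(cos(sin x)) * (-sin(sin x)) * cos(x)
    return t12
```

Comparing `backward_fn_scalar(x, 1.0)` against a finite-difference estimate of `exp(cos(sin(x)))` agrees, confirming the recomputed backward is the correct gradient.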