match number of backward return values to number of forward args (#1300)
crcrpar authored Oct 15, 2024
1 parent 572c222 commit 40904c7
9 changes: 7 additions & 2 deletions thunder/tests/test_jit_general.py
@@ -1186,7 +1186,7 @@ def backward(ctx, grad_output):
             (x, weight) = ctx.saved_tensors
             assert weight.shape == ctx.shape  # really bogus, just to use ctx.shape
             scaler2 = ctx.shape[0] / ctx.shape[1]
-            return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2
+            return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2, None
 
     class Model(torch.nn.Module):
         def __init__(self):
@@ -1199,9 +1199,14 @@ def forward(self, x):
     x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
     model = Model().to(dtype=torch.float64)
     jitted = thunder.jit(model)
 
     gradcheck(jitted, (x,))
 
+    jitted.zero_grad()
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    out = jitted(x)
+    out.backward(torch.rand_like(out))
+    assert jitted.l1.weight.grad is not None
 
 
 def test_autograd_function_apply():
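Background for the fix: torch.autograd.Function requires backward to return exactly one value per argument that forward received (excluding ctx), in the same order, with None in the positions of non-tensor or non-differentiable inputs. The test's backward previously returned two values for a three-argument forward, hence the added ", None". A minimal, self-contained sketch of the same pattern follows; the names (ScaledLinear, scaler) are hypothetical and not code from this repository:

import torch

class ScaledLinear(torch.autograd.Function):
    # forward takes three args after ctx, so backward must return three values
    @staticmethod
    def forward(ctx, x, weight, scaler):
        ctx.save_for_backward(x, weight)
        ctx.scaler = scaler
        return torch.matmul(x, weight.t()) * scaler

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = torch.matmul(grad_output, weight) * ctx.scaler
        grad_weight = torch.matmul(grad_output.t(), x) * ctx.scaler
        # One return value per forward argument (x, weight, scaler);
        # None for scaler, a non-tensor input with no gradient.
        return grad_x, grad_weight, None

x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
w = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(lambda x, w: ScaledLinear.apply(x, w, 0.5), (x, w))

Eager PyTorch rejects a mismatched number of backward return values at backward time; the commit aligns the test with that contract so the thunder.jit path exercises the same signature.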
