Skip to content

Commit

Permalink
simplify custom_scalar_root + add jacrev test
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Kleinhenz committed Aug 28, 2024
1 parent 31d3d07 commit 4764b34
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/beignet/func/_custom_scalar_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,29 @@ def backward(ctx, *grad_outputs):
xstar, *args = ctx.saved_tensors
nargs = len(args)

# optimality condition:
# f(x^*(theta), theta) = 0

argnums = tuple(range(nargs + 1))
xstar, *args = torch.atleast_1d(xstar, *args)
xstar, *args = torch.broadcast_tensors(xstar, *args)
shape = xstar.shape

a, *b = torch.func.jacrev(f, argnums=argnums)(xstar, *args)
xstar = xstar.view(-1)
args = (arg.view(-1) for arg in args)

match a.ndim:
case 0:
output = ()
argnums = tuple(range(nargs + 1))

for g, b2 in zip(grad_outputs, b, strict=True):
output = (*output, -g * b2 / a)
# optimality condition:
# f(x^*(theta), theta) = 0

return output
case 2: # NOTE: `a` is diagonal because `f` is scalar
output = ()
# because f is applied elementwise just compute diagonal of jacobian
a, *b = torch.vmap(
torch.func.grad(f, argnums=argnums), in_dims=(0,) * (nargs + 1)
)(xstar, *args)

for g, b2 in zip(grad_outputs, b, strict=True):
output = (*output, torch.linalg.solve(a, -g * b2))
output = tuple(
(-g * b2 / a).view(*shape)
for g, b2 in zip(grad_outputs, b, strict=True)
)

return output
case _:
raise ValueError
return output

@staticmethod
def vmap(info, in_dims, *args):
Expand Down
17 changes: 17 additions & 0 deletions tests/beignet/test__root_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,20 @@ def test_root_scalar_grad(method):
expected = torch.func.vmap(torch.func.grad(xstar))(c)

torch.testing.assert_close(grad, expected)


@pytest.mark.parametrize("method", ["bisect", "chandrupatla"])
def test_root_scalar_jacrev(method):
    """Check that jacrev through root_scalar matches the analytic gradient."""
    # Batch of parameter values at which the root is differentiated.
    params = torch.linspace(1.0, 10.0, 101, dtype=torch.float64)

    # Bracketing interval and working precision handed to the solver.
    options = {"lower": 0.0, "upper": 5.0, "dtype": torch.float64}

    def solve(theta):
        return beignet.root_scalar(f, theta, method=method, options=options)

    # Full Jacobian of the batched solve via reverse-mode autodiff.
    jacobian = torch.func.jacrev(solve)(params)

    # Reference: per-element gradient of the closed-form root xstar.
    expected = torch.func.vmap(torch.func.grad(xstar))(params)

    # Only the diagonal is compared against the elementwise gradient.
    torch.testing.assert_close(torch.diag(jacobian), expected)

0 comments on commit 4764b34

Please sign in to comment.