From 4764b34b4838a04d5df43c1f9289eb7922463f57 Mon Sep 17 00:00:00 2001
From: Joseph Kleinhenz
Date: Wed, 28 Aug 2024 10:13:25 -0700
Subject: [PATCH] simplify custom_scalar_root + add jacrev test

---
 src/beignet/func/_custom_scalar_root.py | 35 ++++++++++++-------------
 tests/beignet/test__root_scalar.py      | 17 ++++++++++++
 2 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/beignet/func/_custom_scalar_root.py b/src/beignet/func/_custom_scalar_root.py
index 4b3838e4be..dc8ea438fe 100644
--- a/src/beignet/func/_custom_scalar_root.py
+++ b/src/beignet/func/_custom_scalar_root.py
@@ -31,30 +31,29 @@ def backward(ctx, *grad_outputs):
             xstar, *args = ctx.saved_tensors
             nargs = len(args)
 
-            # optimality condition:
-            # f(x^*(theta), theta) = 0
-
-            argnums = tuple(range(nargs + 1))
+            xstar, *args = torch.atleast_1d(xstar, *args)
+            xstar, *args = torch.broadcast_tensors(xstar, *args)
+            shape = xstar.shape
 
-            a, *b = torch.func.jacrev(f, argnums=argnums)(xstar, *args)
+            xstar = xstar.view(-1)
+            args = (arg.view(-1) for arg in args)
 
-            match a.ndim:
-                case 0:
-                    output = ()
+            argnums = tuple(range(nargs + 1))
 
-                    for g, b2 in zip(grad_outputs, b, strict=True):
-                        output = (*output, -g * b2 / a)
+            # optimality condition:
+            # f(x^*(theta), theta) = 0
 
-                    return output
-                case 2:  # NOTE: `a` is diagonal because `f` is scalar
-                    output = ()
+            # because f is applied elementwise just compute diagonal of jacobian
+            a, *b = torch.vmap(
+                torch.func.grad(f, argnums=argnums), in_dims=(0,) * (nargs + 1)
+            )(xstar, *args)
 
-                    for g, b2 in zip(grad_outputs, b, strict=True):
-                        output = (*output, torch.linalg.solve(a, -g * b2))
+            output = tuple(
+                (-g * b2 / a).view(*shape)
+                for g, b2 in zip(grad_outputs, b, strict=True)
+            )
 
-                    return output
-                case _:
-                    raise ValueError
+            return output
 
         @staticmethod
         def vmap(info, in_dims, *args):
diff --git a/tests/beignet/test__root_scalar.py b/tests/beignet/test__root_scalar.py
index a8f2cfbb12..24fc7bb7c4 100644
--- a/tests/beignet/test__root_scalar.py
+++ b/tests/beignet/test__root_scalar.py
@@ -44,3 +44,20 @@ def test_root_scalar_grad(method):
     expected = torch.func.vmap(torch.func.grad(xstar))(c)
 
     torch.testing.assert_close(grad, expected)
+
+
+@pytest.mark.parametrize("method", ["bisect", "chandrupatla"])
+def test_root_scalar_jacrev(method):
+    c = torch.linspace(1.0, 10.0, 101, dtype=torch.float64)
+
+    lower = 0.0
+    upper = 5.0
+    options = {"lower": lower, "upper": upper, "dtype": torch.float64}
+
+    jac = torch.func.jacrev(
+        lambda c: beignet.root_scalar(f, c, method=method, options=options)
+    )(c)
+
+    expected = torch.func.vmap(torch.func.grad(xstar))(c)
+
+    torch.testing.assert_close(torch.diag(jac), expected)