
Option to return auxiliary data from the primal #720

Open
niklasschmitz opened this issue Feb 7, 2025 · 3 comments

@niklasschmitz
Contributor

A very common use case is to differentiate an objective while also retrieving some auxiliary output (intermediate results, the predictions of an ML model, data structures of a PDE solver, etc.).

For example, JAX provides the has_aux keyword option in jax.value_and_grad, and in my experience this is the most common usage pattern of AD in JAX. The pattern looks like this (see e.g. the Flax docs for a full example in context):

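```python
import jax

def loss_fn(params):
    ...  # compute the objective and any auxiliary outputs
    return loss, extra_data  # (scalar to differentiate, data passed through as-is)

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, extra_data), grads = grad_fn(params)
```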
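For reference, here is a minimal runnable version of this pattern. The toy linear model, the data x and y, and the parameter dictionary are hypothetical placeholders chosen only for illustration. Only jax.value_and_grad and its has_aux option come from the snippet in the issue.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data for a 1D linear model (illustration only)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

def loss_fn(params):
    preds = params["w"] * x + params["b"]  # model predictions
    loss = jnp.mean((preds - y) ** 2)      # scalar objective
    return loss, {"preds": preds}          # (objective, auxiliary data)

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
(loss, extra_data), grads = grad_fn(params)
# Only `loss` is differentiated; `extra_data` comes back from the same evaluation.
```

With has_aux=True, loss_fn must return a pair whose first element is the scalar to differentiate. The second element is returned unchanged and is not differentiated.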

I typically use some hacky workarounds to get similar behavior in Julia, but maybe this pattern is common enough to be supported at the interface level?

@gdalle
Member

gdalle commented Feb 7, 2025

The issue I see is that most backends expect a function with a single output. For instance, to get the gradient of loss_fn with ForwardDiff, I'd have to call ForwardDiff.gradient(p -> loss_fn(p)[1], params). That call throws extra_data away, so we lose the benefit of getting it as a by-product of the same evaluation.
Which backends can actually exclude this extra_data from differentiation but still return it, without calling the function twice?
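To make this cost concrete, here is a sketch in JAX rather than ForwardDiff, assuming eager execution (no jit). Keeping only the first output is the analogue of p -> loss_fn(p)[1] with Python's 0-based indexing. It discards the auxiliary data, so a second primal call is needed to recover it, whereas has_aux returns both from one evaluation. The call counter and the tanh stand-in for an expensive forward pass are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}  # counts how often the Python body of loss_fn runs

def loss_fn(p):
    calls["n"] += 1
    preds = jnp.tanh(p)                # stand-in for an expensive forward pass
    return jnp.sum(preds ** 2), preds  # (objective, auxiliary output)

p = jnp.array([0.1, -0.5, 2.0])

# Workaround: differentiate only the first output, then call again for the aux data.
calls["n"] = 0
g_naive = jax.grad(lambda q: loss_fn(q)[0])(p)
_, preds_naive = loss_fn(p)
naive_calls = calls["n"]

# has_aux: objective, auxiliary output and gradient from a single evaluation.
calls["n"] = 0
(val, preds_aux), g_aux = jax.value_and_grad(loss_fn, has_aux=True)(p)
aux_calls = calls["n"]
```

In this sketch naive_calls should end up as 2 and aux_calls as 1. Under jit the Python body runs only while tracing, so such a counter would no longer reflect runtime evaluations.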

@niklasschmitz
Contributor Author

Fair point! I guess the only way to get extra_data out of a single call, other than returning it, is indeed to extract it through a side effect (e.g. global state, saving to a file, ...). That does seem tricky to do cleanly for all AD backends.

@gdalle
Member

gdalle commented Feb 10, 2025

Idea in passing: maybe we could define a new type of Context which, unlike Cache, would be guaranteed not to be overwritten, and would therefore allow returning auxiliary data.
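Here is a rough, backend-agnostic sketch of why such a context would need a dedicated guarantee. It is written in Python to match the JAX snippet. A finite-difference "backend" evaluates the function at the primal point and then at perturbed points, so anything the function writes into a reusable scratch context ends up holding data from the last perturbed call. The Cache and Aux classes and value_and_gradient_fd are hypothetical stand-ins for illustration, not the package's actual API.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class Cache:
    """Scratch context: the backend may reuse and overwrite it on every call."""
    value: Any = None


@dataclass
class Aux:
    """Hypothetical context: only ever holds what the primal evaluation wrote."""
    value: Any = None


def value_and_gradient_fd(
    f: Callable[[List[float], Any], float], x: List[float], ctx: Any, h: float = 1e-6
) -> Tuple[float, List[float]]:
    """Forward finite differences; f(x, ctx) returns a scalar and may write into ctx."""
    y = f(x, ctx)  # primal evaluation at the unperturbed point
    grad = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += h
        # A Cache is simply reused, so each perturbed call overwrites it.
        # An Aux context is protected by giving perturbed calls a throwaway copy.
        c = ctx if isinstance(ctx, Cache) else Aux()
        grad.append((f(xp, c) - y) / h)
    return y, grad


def loss_fn(x: List[float], ctx: Any) -> float:
    preds = [2.0 * xi for xi in x]  # intermediate result we want to keep
    ctx.value = preds               # side-channel write into the context
    return sum(p * p for p in preds)


x = [1.0, 2.0]
cache, aux = Cache(), Aux()
y_cache, g_cache = value_and_gradient_fd(loss_fn, x, cache)
y_aux, g_aux = value_and_gradient_fd(loss_fn, x, aux)
print(cache.value)  # predictions at the last perturbed point, not at x
print(aux.value)    # [2.0, 4.0], the predictions at x itself
```

Whether a real backend can offer this guarantee without a second call depends on how it evaluates the primal. The sketch only illustrates the contract such a context would have to satisfy.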
