Add support for AD backends and explicit optimizers #2083
Conversation
My main motivation here was to make the AD-agnostic piece of the puzzle simpler. One concern here is that AbstractDifferentiation.jl needs more time in the oven. I think since we provide our own backends for Zygote implicit/explicit, and since we don't actually rely on any code in AbstractDifferentiation.jl to take gradients (see Step 2 above), the main feature it provides here is a smooth transition for when it is ready.
```julia
# this is a hack to get around
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
```
Could this be value_and_gradient to support changes like #2070?
Not quite, because it runs into the issue you mentioned in the link above the code. I could define both gradient and value_and_gradient to essentially block out AbstractDifferentiation until they sort out the primitives issues.
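For illustration, here is a minimal sketch of what blocking out AbstractDifferentiation for the implicit backend could look like. The ZygoteImplicitBackend placeholder stands in for the type added in this PR, and the value_and_gradient overload is a hypothetical counterpart to the gradient hack quoted above, not code from the diff.

```julia
import AbstractDifferentiation as AD
import Zygote

# Placeholder for the backend type this PR introduces (exact definition may differ).
struct ZygoteImplicitBackend <: AD.AbstractBackend end

# Route both entry points straight to Zygote for implicit params, bypassing
# AbstractDifferentiation's pullback primitives entirely.
AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)

function AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params)
    y, back = Zygote.pullback(f, x)  # f is a zero-argument loss closure in implicit mode
    return y, back(one(y))           # loss value plus a Zygote.Grads
end
```

Called as AD.gradient(ZygoteImplicitBackend(), () -> loss(), ps) with ps::Zygote.Params, this never touches AbstractDifferentiation's own pullback machinery.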
Alternatively, it might make sense to have Flux.gradient and Flux.withgradient default to AD.gradient and AD.value_and_gradient. Right now, Flux.gradient(f, xs...) wouldn't default to ZygoteImplicitBackend; defining our own method would allow us to do this.
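A rough sketch of that alternative, purely for illustration: the wrappers stand in for Flux.gradient / Flux.withgradient, the default_backend helper is invented here, and AD.ZygoteBackend() is used as the default only to keep the sketch self-contained (in this PR the default would be ZygoteImplicitBackend()).

```julia
import AbstractDifferentiation as AD
import Zygote   # loading Zygote makes AD.ZygoteBackend available

# Invented helper: the backend reached for when none is specified.
default_backend() = AD.ZygoteBackend()

# Hypothetical Flux-level entry points that forward everything to AbstractDifferentiation.
gradient(f, xs...) = AD.gradient(default_backend(), f, xs...)
gradient(backend::AD.AbstractBackend, f, xs...) = AD.gradient(backend, f, xs...)

withgradient(f, xs...) = AD.value_and_gradient(default_backend(), f, xs...)
withgradient(backend::AD.AbstractBackend, f, xs...) = AD.value_and_gradient(backend, f, xs...)
```

With this, gradient(x -> sum(abs2, x), [1.0, 2.0]) would go through AbstractDifferentiation rather than calling Zygote directly.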
I wrote some comments which I should post: Complicated Flux models currently use […]. Without […]
If I understand right, this PR wants to introduce a different mechanism for changing what AD is used, and what mode in the case of Zygote. You would do both by passing a token to train!.
We want three things (in order): (1) to use rules from Optimisers.jl, (2) to use explicit gradients from Zygote, and (3) to swap AD backends. If […]
What this PR aims to do is (a) make implicit vs. explicit Zygote equivalent to using two different AD backends, and (b) "pick" a mechanism for (3) that can be ecosystem-wide (in this case using AbstractDifferentiation.jl). AD has always been the thorn in our side, so a robust solution to (3) is in my mind something we want now. Instead of inventing our own, this PR wants to use and improve AbstractDifferentiation.jl. At the same time, doing (a) means that we don't have to also juggle Zygote's mode as a separate axis. The fact that #2082 has […]
Answering the more specific comments separately:
"Alternatively, it might make sense to have Flux.gradient and Flux.withgradient …"

Well, if you don't care to switch ADs, then there is nothing to know or remember. I think this version is easier to learn and use than knowing to overload […].
The point is to coalesce around […].
Unless we plan on killing implicit mode within Zygote itself, I don't see a reason to forbid it. There doesn't need to be […]
@darsnack and I had a productive but unfortunately too-short discussion on this last ML call, so I'm putting some follow-up thoughts here. Both #2082 and #2083 seek to have training with implicit and explicit params use roughly the same interface: same function arity, same number of return values, etc. Given we're planning on removing […]
This means no back-and-forth conversions. The […]
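To make the "same interface" point concrete, here is roughly how the two call forms compare. This is a sketch under assumptions: the explicit signature shown (train!(loss, model, data, opt_state) with an Optimisers.setup state tree) is the shape Flux later settled on, not necessarily the exact form in either PR.

```julia
using Flux, Optimisers

model = Dense(2 => 1)
data  = [(rand(Float32, 2, 16), rand(Float32, 1, 16))]

# Implicit style: Params plus a Flux.Optimise rule; the loss closes over the model.
Flux.train!((x, y) -> Flux.mse(model(x), y), Flux.params(model), data, Flux.Optimise.Adam())

# Explicit style: the model itself plus an Optimisers.jl state tree; same arity,
# but the loss now receives the model as its first argument.
state = Optimisers.setup(Optimisers.Adam(), model)
Flux.train!((m, x, y) -> Flux.mse(m(x), y), model, data, state)
```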
Yes, this is now the present state after #2082. There are doc notes but no deprecation warnings; they could be added in the last version of 0.13? Should be in […]
The reason to drop it entirely from 0.14 is that this lets us delete Flux.Optimise, and all of its duplicate code for how Adam works, etc. Keeping a path for implicit parameters while having only one Adam definition (in Optimisers.jl) means writing some new code to make corresponding […].
Other uses of Zygote can of course use it as they wish. If ripping implicit parameters out of Zygote completely led to some substantial improvement, then that could be considered as a major change. But realistically nobody is going to get around to trying.
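Purely as an illustration of the kind of shim that paragraph alludes to (hypothetical code, not from this PR or from Flux): one could drive implicit Params updates from a single Optimisers.jl rule by keeping per-array state in an IdDict.

```julia
import Optimisers, Zygote

# Hypothetical adapter: apply one Optimisers.jl rule to each implicit parameter,
# storing the per-array optimiser state keyed by the array itself.
struct ImplicitRule{R}
    rule::R
    state::IdDict{Any,Any}
end
ImplicitRule(rule) = ImplicitRule(rule, IdDict{Any,Any}())

function update!(o::ImplicitRule, ps::Zygote.Params, gs::Zygote.Grads)
    for p in ps
        haskey(gs, p) || continue          # parameter never used in the loss
        g = gs[p]
        g === nothing && continue
        st = get!(() -> Optimisers.init(o.rule, p), o.state, p)
        st, dp = Optimisers.apply!(o.rule, st, p, g)
        o.state[p] = st
        p .-= dp                           # Optimisers rules return the step to subtract
    end
    return ps
end
```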
Note that such overloading was never a proposed API. The initial proposal in #2082 (now removed) was that there be a macro […].
Yes. If they do settle on some nice high-level scheme, then one future for […]. The other possible future is that […]
Yeah, I think we are all in complete agreement, so I will close this one now. Most of what is here is already done or belongs in the AD packages whenever the high-level interface comes together.
This is another approach in a similar vein to #2029 and #2082. The primary goal of this PR is to focus on the use of alternate AD backends, since this is necessary for explicit mode training. Note that while this PR does add explicit support to train!, it does not tackle transitioning the optimizers to using only Optimisers.jl. As such, this PR could be merged and #2082 put on top of it (so it is not a complete replacement for the other PRs).
The changes allow Flux to be used with explicit mode by passing Flux.Optimise.ZygoteExplicitBackend() to train!:
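The code block that originally followed here is not preserved in this thread, so the snippet below is only a reconstruction of the intended usage; the argument position of the backend token and the use of an Optimisers.setup state tree are assumptions.

```julia
using Flux, Optimisers
using Flux.Optimise: ZygoteExplicitBackend   # backend type added by this PR

model = Dense(2 => 1)
data  = [(rand(Float32, 2, 16), rand(Float32, 1, 16))]
state = Optimisers.setup(Optimisers.Adam(), model)   # explicit optimizer state tree

loss(m, x, y) = Flux.mse(m(x), y)

# Assumed call shape: the AD backend is passed alongside the usual train! arguments.
Flux.train!(ZygoteExplicitBackend(), loss, model, data, state)
```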
From the user's perspective, the approach taken here is to explicitly (pun not intended) require the correct things to be passed to train! (the AD backend and optimizer state tree). If there's a mismatch, errors are thrown. The default backend is the implicit mode for Zygote. Since AbstractDifferentiation already has backends for other ADs, a user can load the corresponding AD and run the code they want:
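Again, the original snippet is missing; as a sketch of what swapping ADs could look like, assuming the same (hypothetical) train! call shape as above and that the chosen backend can actually differentiate the model in question:

```julia
import AbstractDifferentiation as AD
using Flux, Optimisers, ReverseDiff   # loading ReverseDiff makes AD.ReverseDiffBackend available

model = Dense(2 => 1)
data  = [(rand(Float32, 2, 16), rand(Float32, 1, 16))]
state = Optimisers.setup(Optimisers.Adam(), model)

loss(m, x, y) = Flux.mse(m(x), y)

# Same training call, different AD backend.
Flux.train!(AD.ReverseDiffBackend(), loss, model, data, state)
```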
The main change here is to stop using Zygote.gradient and use AD.gradient (where AD === AbstractDifferentiation) instead. Even though AbstractDifferentiation.jl supports Zygote, it really only supports the explicit mode as a special case of ChainRules.jl-compatible reverse mode ADs. This PR wraps this and instead does the following (a condensed sketch of these pieces follows the list):
- ZygoteImplicitBackend and ZygoteExplicitBackend as AbstractDifferentiation backends with the appropriate primitives defined. Both wrap AD.ZygoteBackend, but a clear type for each allows Flux to specialize dispatch.
- AD.gradient for both of the above, to get around AD failing where Zygote succeeds (JuliaDiff/AbstractDifferentiation.jl#63 (comment)).
- Optimisers.update and Optimisers.update! for Flux.AbstractOptimiser (removing the old Flux.update!).
- train! to utilize everything above correctly, namely constructing the loss correctly for implicit vs. explicit gradients.
- […] @require-ed in AbstractDifferentiation.jl.
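A condensed sketch of the pieces listed above, reconstructed from the bullet descriptions rather than copied from the diff; field names, type parameters, and the leaf-level Optimisers.update! shim are assumptions (the real shim would also need to handle whole nested models):

```julia
import AbstractDifferentiation as AD
import Optimisers, Zygote
import Flux

# Two distinct backend types so Flux can dispatch on implicit vs. explicit mode;
# both wrap AbstractDifferentiation's existing Zygote support.
struct ZygoteImplicitBackend{T} <: AD.AbstractBackend
    core::T
end
ZygoteImplicitBackend() = ZygoteImplicitBackend(AD.ZygoteBackend())

struct ZygoteExplicitBackend{T} <: AD.AbstractBackend
    core::T
end
ZygoteExplicitBackend() = ZygoteExplicitBackend(AD.ZygoteBackend())

# Work around JuliaDiff/AbstractDifferentiation.jl#63 by calling Zygote directly.
AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)

# Leaf-level idea only: let a legacy Flux rule act through the Optimisers.jl update API.
# Legacy optimisers carry their own state, so the rule itself plays the role of the state tree.
function Optimisers.update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, dx)
    x .-= Flux.Optimise.apply!(opt, x, dx)
    return opt, x
end
```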