Make params non-differentiable (Closes #2040 & #2048) #2054
Conversation
@ToucheSir can you enable the downstream tests for this PR?
FastAI failure is innocuous and known. Metalhead I think too (cc @darsnack to confirm it's just the ResNet weights loading). GeometricFlux I'm not sure about, but there don't appear to be any AD calls in https://github.com/FluxML/GeometricFlux.jl/blob/master/test/layers/graphlayers.jl. @yuehhua can you confirm?

Edit: I forgot AtomicGraphNets in the CI jumble, but the CI there has failed for all other current Flux PRs which run downstream tests, so I don't think it's related.
The failure in GeometricFlux is not related to AD, so this PR can go ahead. I will check the error further.
src/functor.jl (outdated)

```diff
@@ -88,6 +88,9 @@ function params(m...)
   return ps
 end

+# Allows caching of the parameters when params is called within gradient()
```
Can you add a quick mention of #2040 here? Otherwise this looks good to go!
Just pushed with the edited comment.
Thanks!
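For context, here is a minimal sketch of the kind of declaration this hunk is about, assuming ChainRulesCore's `@non_differentiable` is the mechanism (the authoritative code is the PR diff itself; in the PR this sits next to the `params` definition in src/functor.jl):

```julia
using ChainRulesCore
import Flux: params

# Allows caching of the parameters when params is called within gradient(); see #2040.
# Declaring params non-differentiable means Zygote (via ChainRules) does not trace into
# its internals, so the Params collection built inside a differentiated closure can be
# treated as a constant and cached.
ChainRulesCore.@non_differentiable params(::Any...)
```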
Unfortunately this breaks the use of `Flux.params` inside explicit-style `gradient` calls:

```julia
julia> m = (x=[1,2.0], y=[3.0]);

julia> gradient(m -> (sum(norm, Flux.params(m))), m)
((x = [0.4472135954999579, 0.8944271909999159], y = [1.0]),)  # before, [email protected]
(nothing,)                                                    # after, [email protected]

julia> gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m))
Grads(...)

julia> ans[m.x]  # unchanged
2-element Vector{Float64}:
 0.4472135954999579
 0.8944271909999159
```
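(Illustration, not from the thread: under the new behaviour, an explicit gradient of such a structure can still be taken by differentiating the fields directly rather than going through `params`.)

```julia
julia> using Flux, LinearAlgebra

julia> m = (x = [1, 2.0], y = [3.0]);

julia> gradient(m -> norm(m.x) + norm(m.y), m)  # no params call, so still differentiable
((x = [0.4472135954999579, 0.8944271909999159], y = [1.0]),)
```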
This is the new pull request I mentioned in #2048 to allow `params()` calls from within `gradient()` to be cached by making `params()` non-differentiable. This would close issue #2040 and supersede pull request #2048.

Might be worth turning on downstream tests like in the other PR.