The actual use case is that I'm trying to train a NN, where I need to compute the gradient of a loss function, which involves finite differences given by `central_fdm`, with respect to the network's parameters. Is this possible?
Can you just use it the other way, i.e. FiniteDifferences over Zygote instead of Zygote over FiniteDifferences? You could actually define a custom adjoint for FiniteDifferenceMethod which does this automatically for cases like this.
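That direction can be sketched without either package: below, a hand-written derivative stands in for a Zygote-computed gradient, and a simple second-order central difference serves as a minimal stand-in for `central_fdm`. The names `central_diff`, `df`, and `d2f` are illustrative, not part of either package's API.

```julia
f(x) = sin(x)
df(x) = cos(x)  # stand-in for an AD gradient, e.g. Zygote.gradient(f, x)[1]

# Simple 2nd-order central difference; FiniteDifferences.central_fdm computes a
# higher-order, step-adapted version of this.
central_diff(g, x; h=cbrt(eps(typeof(x)))) = (g(x + h) - g(x - h)) / 2h

# "FiniteDifferences over Zygote": finite-difference the AD gradient to get a
# second derivative, instead of AD-ing through a finite-difference estimate.
d2f = central_diff(df, 1.0)
```

Here `d2f` approximates the second derivative of `sin` at 1, whose analytic value is `-sin(1)`.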
Step size adaptation does not appear to work well with Zygote. Unfortunately, I'm not familiar enough with the inner workings of Zygote to see what precisely breaks down. Perhaps @willtebbutt or @oxinabox could provide some insight. From the error message, it appears that a pullback attempts to modify a StaticArray in-place, which won't work.
Once you turn off step size adaptation, things seem to work:
```julia
julia> fdm = central_fdm(5, 1; adapt=0);

julia> cos_(x) = fdm(sin, x)
cos_ (generic function with 1 method)

julia> cos_'(1)
-0.8414709848079269

julia> -sin(1)
-0.8414709848078965
```
I would be careful with AD-ing through finite difference estimates, though. I second @simeonschaub's suggestion of taking finite differences of gradients computed by AD.
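To illustrate why nesting finite differences is risky: each layer of differencing amplifies floating-point roundoff, so an FD-of-FD second derivative is markedly less accurate than a single finite difference of an exact (or AD-computed) first derivative. A minimal, package-free sketch, where `central_diff` is a hand-written stand-in for `central_fdm`:

```julia
central_diff(g, x; h=cbrt(eps(typeof(x)))) = (g(x + h) - g(x - h)) / 2h

fd_first(x) = central_diff(sin, x)        # FD estimate of sin' = cos
nested(x)   = central_diff(fd_first, x)   # FD of an FD estimate
direct(x)   = central_diff(cos, x)        # FD of the exact first derivative

# Both approximate sin''(1) = -sin(1); the nested version is typically
# several digits worse because roundoff in the inner estimate is divided
# by the outer step.
err_nested = abs(nested(1.0) - (-sin(1.0)))
err_direct = abs(direct(1.0) - (-sin(1.0)))
```

The same roundoff amplification affects AD pullbacks through a finite-difference estimate, on top of the step-adaptation issues above.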
EDIT: If we wanted to make this work, we could do something like
```julia
using FiniteDifferences, Zygote

function (m::FiniteDifferences.AdaptedFiniteDifferenceMethod)(f::TF, x::Real) where TF<:Function
    x = float(x)  # Assume that converting to float is desired, if it isn't already.
    step = Zygote.dropgrad(first(FiniteDifferences.estimate_step(m, f, x)))
    return m(f, x, step)
end
```
Then
```julia
julia> fdm = central_fdm(5, 1; adapt=1);

julia> cos_(x) = fdm(sin, x)
cos_ (generic function with 1 method)

julia> cos_'(1)
-0.8414709848078701

julia> -sin(1)
-0.8414709848078965
```
Summary: it seems that a finite difference estimate can't be further differentiated by Zygote.