Algorithms which branch to simpler implementations #114
I'm not able to find a GitHub reference for this, but this question is very similar to another one we had a while back about taking the minimum of an array. The gist there was as follows. Suppose you have a function like the following:

```julia
function min1(x)
    result = 100000
    for val in x
        if val <= result
            result = val
        end
    end
    return result
end
```

If we were to differentiate this on an input array with a repeated minimum, such as `[2, 3, 2, 4, 5]`, we would get the subderivative corresponding to the last minimal entry. If instead the code were rewritten to get the "first minimum value" (as below), we would instead get the derivative with respect to the first minimal entry:

```julia
function min2(x)
    result = 100000
    for val in x
        if val < result
            result = val
        end
    end
    return result
end
```

You could also do some crazy things that, in implementation, lead to any normalized linear combination of those two results (which is also a valid subderivative):

```julia
function min3(x)
    results = [100000]
    for val in x
        if val < results[1]
            results = [val]
        elseif val == results[1]
            push!(results, val)
        end
    end
    return mean(results)   # `mean` is from Statistics
end
# Results in a derivative of [0.5, 0, 0.5, 0, 0]
```

As is, Enzyme guarantees to give you a subderivative. In fact, it guarantees to give you the expected subderivative of the code it directly analyzes and differentiates. If optimizations are applied between the original function and the code Enzyme actually sees, it is possible that Enzyme would give you a different (but still correct) subderivative. In the examples here, an LLVM or Julia optimization could change the code from `min1` to `min2` (or vice versa).

More generally, this is an example of the issue where "derivative of approximation != approximation of derivative". If the function we wanted to differentiate from your example was the product of the first numbers up to the first zero, Enzyme and FwdDiff would produce the correct result even on the approximation. The error here is that the approximation is not the actual function you want the code to represent. Enzyme does have a construct to help with this problem (and we can think about other ways as well), illustrated here: https://github.com/wsmoses/Enzyme/blob/main/enzyme/test/Integration/ReverseMode/mycos.c [we haven't exposed this up in Enzyme.jl, but it should be easy to do so]. In essence, the construct lets you specify the intended function/derivative for a piece of code, so Enzyme differentiates the intent rather than the literal approximation as written.

cc @oxinabox
We can also do some other tricks to handle measure-zero spaces [determined by a floating-point equality check] explicitly (or, say, insert things automatically), but I thought this was a good opportunity to raise the larger point, which is that there's a disconnect between the intent of the code and its implementation. As another solution, we could use our LLVM-level function modification/replacement of Julia methods to rewrite the method to conform with the intent of the code.
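For the min case specifically, one way to encode the intent explicitly, rather than relying on whichever branch the implementation happens to take, is a custom reverse rule. A minimal sketch, assuming a recent ChainRulesCore and adopting the mean-over-ties convention from `min3`; this is an illustration, not an existing rule in any package:

```julia
using ChainRulesCore

# Reverse rule for `minimum` that spreads the incoming cotangent evenly over all
# tied minimal entries, so the returned gradient always sums to the cotangent ȳ.
function ChainRulesCore.rrule(::typeof(minimum), x::AbstractVector{<:Real})
    m = minimum(x)
    function minimum_pullback(ȳ)
        ties = x .== m                     # indicator of the tied minimal entries
        x̄ = (ȳ / count(ties)) .* ties      # a convex combination over the ties
        return NoTangent(), x̄
    end
    return m, minimum_pullback
end
```

Whether spreading over ties or committing to a single index is the better convention is a separate question; the point is only that the convention is chosen deliberately rather than falling out of whichever implementation the optimizer happened to produce.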
Agree this is connected to the subgradient story, in that optimisations don't commute with differentiation. But while we could decide to shrug and declare all subgradients acceptable, we can't do this when one of the branches only treats a subspace. I was thinking mostly of manually written branches, as in LinearAlgebra. But if "optimization could change the code from min1 to min2", might it also change one of those branches?

Thanks, I will try to decode a bit of what you say about mycos.c.

For the min functions above:

```julia
julia> using ForwardDiff, FiniteDifferences, Statistics
julia> function min4(x) # tweaked min3
           results = [100000]
           for val in x
               if val < results[1]
                   results = [val]
               elseif val == results[1]
                   push!(results, val)
               end
           end
           return mean(results)
       end;
julia> min4([2, 3, 2, 4, 5]) == min2([2, 3, 2, 4, 5]) == 2
true
julia> ForwardDiff.gradient(min1, [2, 3, 2, 4, 5]) |> println
[0, 0, 1, 0, 0]
julia> ForwardDiff.gradient(min2, [2, 3, 2, 4, 5]) |> println
[1, 0, 0, 0, 0]
julia> ForwardDiff.gradient(min4, [2, 3, 2, 4, 5]) |> println
[0.5, 0.0, 0.5, 0.0, 0.0]
julia> grad(central_fdm(5, 1), min2, [2, 3, 2, 4, 5.0]) # similar for all versions
([0.5000000000050965, 1.2170899983279496e-13, 0.5000000000050965, 1.2170899983279496e-13, 1.2170899983279496e-13],)
```
One thing that I wanted to add, just to illustrate how FDM may completely fail here, consider:
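For instance (a sketch of my own, not necessarily the exact case meant here), take an input with three tied minima and the same `central_fdm` estimator used earlier:

```julia
using FiniteDifferences   # for central_fdm and grad, as above

# Each of the three tied minima sits at a kink, so the central difference gives ≈ 0.5
# for each of them; the estimated gradient then sums to ≈ 1.5 rather than 1.
grad(central_fdm(5, 1), min2, [2, 2, 2, 4, 5.0])
```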
Here this is extremely bad, since now the total gradient isn't conserved. This means that if you coupled this definition of min in a ChainRules-like setting with a fill (e.g. differentiating `z -> minimum(fill(z, n))`, whose derivative should be exactly 1), the composite derivative would come out wrong.
Oh that's weird. I've been confused by this before; for a while I thought I understood some of it as a tangent-vs-cotangent thing and the rest as bugs. However, wouldn't finite differences applied to the whole composed function still give the right answer?
Sure, in this example when composed with fill, numeric differentiation can get the correct derivative of the entire function. The fact that that case works, however, doesn't detract from the issue that finite differences can give results which aren't even subgradients, and that has several bad consequences (such as not being able to compose). I would argue that is far worse than something which gives a correct subgradient, though perhaps not the best one.

From here I wrote out a bit of math for myself, and I'll leave it here in case others find it useful. Suppose we use a simple finite difference algorithm to compute the derivative of each argument. Let

    g_i = (min(x + h*e_i) - min(x - h*e_i)) / (2h)

be the estimate for the i-th argument, and suppose the minimum of x is attained at k of its entries. Thus

    g_i ≈ 1/2 for each of the k minimal entries, g_i ≈ 0 otherwise, so sum_i g_i ≈ k/2.

Let the legal subgradients of min at x be the vectors supported on the minimal entries whose components sum to one (i.e. the convex combinations of the corresponding basis vectors). And this is not a legal subgradient when k ≠ 2. Let's say we then compose with a fill of some scalar function, differentiating F(z) = min(fill(f(z), n)), so the chain rule gives F'(z) = (sum_i g_i) * f'(z).
Now let's say f(z) = z, to make this applicable to our use case (the fill). Since the composite is then just z, we know that the final derivative must be equal to one. Thus, to compute the correct answer when using the chain rule, we must ensure our gradient sums to one. Notably, this still permits putting the whole derivative on a single index which an AD tool, through optimization or otherwise, arbitrarily chose. This doesn't mean there isn't a good choice here for other reasons, but we definitely want to avoid making a choice which breaks composability.
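As a quick sanity check of that composition requirement (my own snippet, reusing the `min2` and `min4` definitions from earlier in the thread):

```julia
using ForwardDiff, Statistics   # min2 and min4 as defined above

# z -> min(fill(z, 5)) is the identity, so its derivative should be exactly 1.
# Both the "first index" convention and the "spread over ties" convention satisfy this:
ForwardDiff.derivative(z -> min2(fill(z, 5)), 2.0)   # expect 1.0
ForwardDiff.derivative(z -> min4(fill(z, 5)), 2.0)   # expect 1.0
```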
I was wondering what happens when there is a branch in the function, to use a simpler algorithm when (say) some variable is zero.
Below is a toy example, which seems to lead to wrong answers. What I was working towards when I hit #112 was the fact that 5-arg `mul!` takes shortcuts when `α == 0`; LinearAlgebra is full of such things which e.g. branch to a simpler algorithm when `ishermitian(A)`.

Is this a problem which Enzyme could conceivably detect and avoid? ForwardDiff does not do so at present, but I think that essentially changing `iszero(::Dual)` can fix it, discussed here: JuliaDiff/ForwardDiff.jl#480 . Enzyme seems like black magic to me, but if it knows which numbers are `Active`, can it somehow use this to avoid measure-zero branches?
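A minimal sketch of the kind of branch described (my own example, not the toy example from the issue): a function that takes a shortcut when a scale factor is exactly zero, which ForwardDiff then differentiates incorrectly at that point.

```julia
using ForwardDiff

# Shortcut branch: skip the work entirely when the scale α is exactly zero.
scaled_sum(α, x) = iszero(α) ? zero(α) : α * sum(x)

# At α = 0 the true derivative with respect to α is sum(x) = 6.0, but the zero branch
# is also taken for the Dual number (its value is 0), dropping the partials:
ForwardDiff.derivative(α -> scaled_sum(α, [1.0, 2.0, 3.0]), 0.0)   # expect 0.0, not 6.0
```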