Algorithms which branch to simpler implementations #114

Open

mcabbott opened this issue Jul 30, 2021 · 6 comments

@mcabbott (Contributor) commented Jul 30, 2021

I was wondering what happens when there is a branch in the function which switches to a simpler algorithm when (say) some variable is zero.

Below is a toy example, which seems to lead to wrong answers. What I was working towards when I hit #112 was the fact that 5-arg mul! takes shortcuts when α == 0; LinearAlgebra is full of such things which e.g. branch to a simpler algorithm when ishermitian(A).

Is this a problem which Enzyme could conceivably detect and avoid? ForwardDiff does not do so at present, but I think that essentially changing iszero(::Dual) could fix it, as discussed here: JuliaDiff/ForwardDiff.jl#480 . Enzyme seems like black magic to me, but if it knows which numbers are Active, can it somehow use this to avoid measure-zero branches?
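
As an aside, the kind of α == 0 shortcut referred to above looks roughly like this toy sketch (addmul_toy! is a made-up name, not the actual LinearAlgebra code). In the shortcut branch the dependence on α vanishes, so AD through that branch loses ∂C/∂α = A*B, even though the zero gradients for A and B there are legitimate:

julia> function addmul_toy!(C, A, B, α, β)   # toy sketch only, not the real mul!
           if iszero(α)
               C .*= β       # shortcut branch: the α-dependence (∂C/∂α = A*B) disappears
               return C
           end
           C .= α .* (A * B) .+ β .* C
           return C
       end;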

julia> using Enzyme, ForwardDiff

julia> x = Float32[1 2 0 3 4];

julia> dx = zero(x); autodiff(prod, Active, Duplicated(x, dx)); dx  # ok
1×5 Matrix{Float32}:
 0.0  0.0  24.0  0.0  0.0

# implementations of prod from https://github.com/JuliaDiff/ForwardDiff.jl/issues/480

julia> function prod1(xs::Array)
           p = one(eltype(xs))
           for x in xs
               p = p * x
           end
           p
       end;

julia> function prod2(xs::Array)
           p = one(eltype(xs))
           for x in xs
               p = p * x
               p == 0 && break  # exit early once you know the answer
           end
           p
       end;

julia> dx = zero(x); autodiff(prod1, Active, Duplicated(x, dx)); dx  # ok
1×5 Matrix{Float32}:
 0.0  0.0  24.0  0.0  0.0

julia> dx = zero(x); autodiff(prod2, Active, Duplicated(x, dx)); dx  # wrong
1×5 Matrix{Float32}:
 0.0  0.0  2.0  0.0  0.0

julia> ForwardDiff.gradient(prod2, x)  # same mistake
1×5 Matrix{Float32}:
 0.0  0.0  2.0  0.0  0.0
@wsmoses (Member) commented Jul 30, 2021

I'm not able to find a Github reference for this, but this question is very similar to another one we had a while back about taking the minimum of an array. The gist there was as follows:

Suppose you have a function like the following:

function min1(x)
    result = 100000
    for val in x
       if val <= result
         result = val
       end
    end
    return result
end

If we were to differentiate this on an input array x = [2, 3, 2, 4, 5], we would get the derivative dx = [0, 0, 1, 0, 0], since our implementation happens to return the "last minimum value". As such, Enzyme running on this (unoptimized) code returns a gradient for that corresponding entry.

If instead the code were rewritten to take the "first minimum value" (as below), we would instead get the derivative dx = [1, 0, 0, 0, 0]:

function min2(x)
    result = 100000
    for val in x
       if val < result
         result = val
       end
    end
    return result
end

You could also write implementations which lead to any normalized linear combination of those two results (which is also a valid subderivative):

function min3(x)
    results = [100000]
    for val in x
       if val < results[1]
         results = [val]
       elseif val == results[1]
         push!(results, val)
       end
    end
    return mean(results)   # mean requires Statistics
end
# Results in a derivative of [0.5, 0, 0.5, 0, 0]

As is, Enzyme guarantees to give you a subderivative. In fact, it guarantees to give you the expected subderivative of the code it directly analyzes and differentiates. If optimizations are applied to the original function before Enzyme sees it, it is possible that Enzyme would give you a different (but still correct) subderivative. In the examples here, an LLVM or Julia optimization could change the code from min1 to min2 (as far as I've checked, there's nothing that would modify min3).

More generally, this is an example of the issue where "derivative of approximation != approximation of derivative". If the function we wanted to differentiate in your example were the product of the elements up to and including the first zero, Enzyme and ForwardDiff would produce the correct result even on the approximation. The error here is that the approximation is not the actual function you want the code to represent.

Enzyme does have a construct to help with this problem (and we can think about other ways as well), illustrated here: https://github.com/wsmoses/Enzyme/blob/main/enzyme/test/Integration/ReverseMode/mycos.c [we haven't exposed this in Enzyme.jl yet, but it should be easy to do so].

In essence, the __enzyme_iter(a, b) function returns a + (# of derivatives taken) * b. This is especially useful for maintaining the accuracy of Taylor series approximations, as shown in the code there. For this example, one could imagine using __enzyme_iter to take the early-exit branch only when not being differentiated (i.e. purely as an optimization for the primal), whereas when differentiating it won't take the branch and thus picks the subderivative you'd prefer.
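
As a purely hypothetical sketch of what that could look like in Julia (prod3 and enzyme_iter are made-up names; enzyme_iter is not an existing Enzyme.jl function, and the sketch only assumes the semantics described above, returning a in the primal and a + n*b after n derivatives):

function prod3(xs::Array)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        # hypothetical: enzyme_iter(1, -1) == 1 only when no derivative is taken,
        # so the early exit is a primal-only optimization and is skipped under AD
        if enzyme_iter(1, -1) == 1 && p == 0
            break
        end
    end
    p
end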

cc @oxinabox

@wsmoses (Member) commented Jul 30, 2021

We can also do some other tricks to handle measure-zero spaces [determined by a floating-point equality check] explicitly (or, say, insert things automatically), but I thought this was a good opportunity to raise the larger point: there's a disconnect between the intent of the code and its implementation.

As another solution, we could use our LLVM-level function modification/replacement of Julia methods to rewrite the method to conform with the intent of the code.

@mcabbott (Contributor, Author) commented Jul 30, 2021

Agree this is connected to the subgradient story, in that optimisations don't commute with differentiation. But while we could decide to shrug and declare all subgradients acceptable, we can't do this when one of the branches only treats a subspace.

I was thinking mostly of manually written branches, as in LinearAlgebra. But if "optimization could change the code from min1 to min2", might it also change prod1 into prod2? I suppose that without @fastmath, zero isn't a fixed point of *: a later element could still produce a NaN or an Inf, so the compiler can't stop early.
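
For instance, even after the running product hits zero, a later element can still change the answer:

julia> 0.0 * Inf
NaN

julia> 0.0 * NaN
NaN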

Thanks, I will try to decode a bit of what you say about __enzyme_iter; it sounds interesting.

For minimum, ForwardDiff of course takes the same paths through these examples, xref also JuliaDiff/ChainRules.jl#480 which wants to pick a convention:

julia> using ForwardDiff, FiniteDifferences, Statistics

julia> function min4(x)  # tweaked min3
           results = [100000]
           for val in x
              if val < results[1]
                results = [val]
              elseif val == results[1]
                push!(results, val)
              end
           end
           return mean(results)
       end;

julia> min4([2, 3, 2, 4, 5]) == min2([2, 3, 2, 4, 5]) == 2
true

julia> ForwardDiff.gradient(min1, [2, 3, 2, 4, 5]) |> println
[0, 0, 1, 0, 0]
 
julia> ForwardDiff.gradient(min2, [2, 3, 2, 4, 5]) |> println
[1, 0, 0, 0, 0]

julia> ForwardDiff.gradient(min4, [2, 3, 2, 4, 5]) |> println
[0.5, 0.0, 0.5, 0.0, 0.0]

julia> grad(central_fdm(5, 1), min2, [2, 3, 2, 4, 5.0])  # similar for all versions
([0.5000000000050965, 1.2170899983279496e-13, 0.5000000000050965, 1.2170899983279496e-13, 1.2170899983279496e-13],)

@wsmoses (Member) commented Jul 30, 2021

One thing I wanted to add, just to illustrate how FDM may completely fail here; consider:

julia> grad(central_fdm(5, 1), min1, [2, 3, 2, 4, 5.0, 2])
([0.5000000000050965, 1.2170899983279496e-13, 0.5000000000050965, 1.2170899983279496e-13, 1.2170899983279496e-13, 0.5000000000050965],)

This is extremely bad, since now the total gradient isn't conserved. This means that if you coupled this definition of min with fill in a ChainRules-like setting (e.g. min(fill(x, 10))), instead of getting 1 you would get back 5, which is clearly wrong and not even a subgradient.
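
As a sketch of that bookkeeping (not any particular package's rule), composing the finite-difference "gradient" of min at a 10-fold tie with the pullback of fill, the way reverse-mode AD would:

julia> dmin = fill(0.5, 10);   # FDM-style result for min at ten equal inputs

julia> sum(dmin)               # pullback of fill sums the cotangents; should be 1
5.0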

@mcabbott (Contributor, Author) commented

Oh that's weird. I've been confused by this before; for a while I thought I understood some of it as tangent vs cotangent and the rest as bugs. However, grad(central_fdm(5, 1), x -> minimum(fill(x,10)), 1.0) is indeed 1, because it's differentiating right through the whole thing, not composing pieces like the one isolated above, the way reverse-mode AD would.

@wsmoses (Member) commented Jul 30, 2021

Sure, in this example, when composed with fill, numeric differentiation can get the correct derivative of the entire function.

The fact that that case works, however, doesn't detract from the issue that finite differences can give results which aren't even subgradients, with several bad consequences as a result (such as not being able to compose). I would argue this is far worse than something which gives a correct subgradient, though perhaps not the best one.

From here I wrote out a bit of math for myself, and I'll leave it here in case others find it useful.

Suppose we use a simple finite-difference algorithm to compute the derivative with respect to each argument. Let x be equal to the (non-unique) minimum value of all arguments; since the minimum is non-unique, min(x+δ, ...) = x while min(x-δ, ...) = x-δ. A simple finite-difference method could define

d/dx min(x, ...) = [min(x+δ, ...) - min(x-δ, ...)] / (2δ)
                 = [x - (x-δ)] / (2δ)
                 = 1/2

Thus ∇min(2, 2, 2, 2, ...) = (1/2, 1/2, 1/2, 1/2, ...). A subgradient must satisfy the property below:

f(y) ≥ f(x) + ∇f(x) * (y − x)

Let x = (2, ...) and y = (3, ...). Thus we find (with d as the dimension):

min(3, ...) ≥ min(2, ...) + ∇min(2, ...) * (1, ...)
3 ≥ 2 + (1/2, ...) * (1, ...)
3 ≥ 2 + d/2

And this is not a legal subgradient when d > 2.
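
For concreteness, the failure at d = 4 can be checked directly:

julia> d = 4; x = fill(2.0, d); y = fill(3.0, d); g = fill(0.5, d);

julia> minimum(y) >= minimum(x) + sum(g .* (y .- x))   # 3 ≥ 2 + d/2
false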

Let's say x[1] = f(z), x[2] = f(z), ..., and let's compute r = min(x) where x = vec of f(z).
By applying the chain rule,

dr/dz = \sum_i (∂min/∂x[i]) * f'(z)

Now let's say f(z) = z to make this applicable to our use case (the fill):

dr/dz = \sum_i (∂min/∂x[i])

Since r = min(z, z, ...) = z, we know the final derivative must be equal to one:

1 = \sum_i (∂min/∂x[i])

Thus, to compute the correct answer when using the chain rule, we must ensure our gradient sums to one. Notably, this still permits an arbitrary choice of which index is set to 1, whether an AD tool makes that choice through optimization or otherwise. This doesn't mean there isn't a good choice here for other reasons, but we definitely want to avoid making a choice which breaks composability.
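
For example, one convention that keeps the sum equal to one is to assign the whole cotangent to the first minimizing index (a sketch only; min_gradient_first is a made-up name, not a proposal for any particular package):

julia> function min_gradient_first(x)
           g = zeros(length(x))
           g[findfirst(==(minimum(x)), x)] = 1   # all weight on the first minimum
           return g
       end;

julia> min_gradient_first([2, 3, 2, 4, 5]) |> println
[1.0, 0.0, 0.0, 0.0, 0.0]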
