
Inequality, ε-balls, and accidental structural zeros #480

Open
mcabbott opened this issue Nov 27, 2020 · 9 comments

@mcabbott
Member

Consider this function:

sq(x) = x==1 ? one(x) : x^2

@test FiniteDifferences.central_fdm(5, 1)(sq, 1) ≈ 2.0
@test_broken ForwardDiff.derivative(sq, 1.0) == 2.0

Here ForwardDiff gets the wrong answer, according to your first calculus class: the derivative is defined by taking limits, evaluating sq(x + ε) for some small ε, and these evaluations always see the continuum x^2, not the special point.

One way to think about this is to say that x==1 really means abs(x-1) < ζ with some tiny ζ, which we keep finite until we are sure we aren't confused. The calculus class assumption is that ζ << ε.

The assumption of ForwardDiff is the opposite. Its Dual(x,1) encodes a perturbation x + 1ε with ε smaller than everything else around, and in particular ε << ζ. In other words, sq is viewed as being piecewise continuous, with a small flat area of width 2ζ, which is still large enough for us to see that its slope is zero.
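
To see this directly (sq as defined above; behaviour of the current release at the time of writing):

using ForwardDiff
using ForwardDiff: Dual, value, partials

d = Dual(1.0, 1.0)
d == 1                           # true: the comparison looks only at the primal value
value(sq(d)), partials(sq(d))    # the x == 1 branch is taken, the perturbation is dropped
ForwardDiff.derivative(sq, 1.0)  # 0.0, not 2.0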

Of course nobody really writes contrived examples like sq. But they do write things like this:

function prod1(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
    end
    p
end

function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end

@test ForwardDiff.gradient(prod1, [1,2,0,4,0,6]) == zeros(6)
@test_broken ForwardDiff.gradient(prod2, [1,2,0,4,0,6]) == zeros(6)

This has almost the same problem as #197, where det(A) tests for istriu(A) || istril(A) before calling a simpler branch. The fact that f(x,y) == g(x,y) when y==0 does not imply that df/dy == dg/dy. So it seems AD ought not to take that branch.

In which case we want something like this:

Base.:(==)(x::Dual, y::Int) = x.value == y && iszero(x.partials)
Base.:(!=)(x::Dual, y::Int) = x.value != y || !iszero(x.partials)

This fixes the tests above, and (a slightly more careful version) fixes #197 and #407.
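
For instance, under the proposed rule a zero that carries a perturbation no longer matches a literal 0, so prod2 stops taking the early exit for active inputs:

ForwardDiff.Dual(0.0, 1.0) == 0   # false under the proposed rule — the break is not taken
ForwardDiff.Dual(0.0, 0.0) == 0   # true — genuinely constant zeros still compare equal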

However, it means that fun(Dual(x,1)).value need not be equal to fun(x) for a discontinuous function. On the other hand, fun(Dual(x,0)).value is still equal, @assert zero(x) == 0 isn't broken, and there should be no problems where functions use things like zero(eltype(xs)) for type stability.

The idea that the forward evaluation is unchanged is often thought of as an axiom of AD, but for discontinuous functions, I think that's another way of saying ε << ζ. Which is a choice. And one that your calculus teacher would disapprove of. The point of evaluating a function with dual numbers is, presumably, to find derivatives, so finding them correctly ought to have a higher priority.

There are other comparisons to think about, for example:

sq2(x) = x>1 ? x^2 : x<1 ? x^2 : one(x)

clamp2(x, lo=0, hi=1) = x>hi ? oftype(x,hi) : x<lo ? oftype(x,lo) : x
clamp3(x, lo=0, hi=1) = x>=hi ? oftype(x,hi) : x<=lo ? oftype(x,lo) : x

[ForwardDiff.derivative(cl, 1.0) for cl in [x->clamp(x,0,1), clamp2, clamp3]] == [1,1,0]
[central_fdm(5, 1)(cl, 1.0) for cl in [x->clamp(x,0,1), clamp2, clamp3]] ≈ [0.5, 0.5, 0.5]

I'm not sure how often simulating x==1 as in sq2(x) happens in the wild. Perhaps from some combination like f(x) = relu(x) + 0.1*relu(-x)?

But clamping parameters to some range is routine. Your calculus teacher would throw an error here, but that's probably not the most helpful response for the computer.

Returning a nonzero derivative here is useful because, if this is some parameter being optimised, it means gradient descent won't get stuck against the wall when the gradient points away from it. So you can argue that the ability to choose which sub-gradient ForwardDiff will use is a feature. The 0.5 gradient à la FiniteDifferences would also be fine for getting un-stuck, but it's very difficult to picture how ForwardDiff could evaluate both branches, and easy to picture doing so having awful side effects.

Here is one way to relate the present rule for >(::Dual, ::Real) and >=(::Dual, ::Real) to the finite-everything ζ << ε story. We can say that while the ε-ball overlaps with both sides, the vote from the longer side (longer by about 2ζ) always wins by a hair:

----------(==========1==========)----------  abs(x-1) < ε
---------------------1-(===================  x > 1+ζ
          +++++++++++++..........            gradient votes, clamp2(1.0)

Trying out the above ==(::Dual, ::Real) rule, it looks like the tests of this package all pass, except for the ones explicitly testing such rules. It would be interesting to know if this breaks any other uses in the wild. It would also be interesting to think up other pathological examples, maybe I've missed something important.

Also:

  • Another way to talk about this: the problem with prod2 above, and with det in #197 (support for the determinant function), is that they promote accidental zeros of the input to structural zeros. AD then respects these and gets the wrong answer. What looked like a simple optimisation when writing for real numbers has been unintentionally promoted to a constraint on what derivatives are allowed. This is the reverse of the discussion about preserving structural zeros in things like Zygote.gradient(sum∘exp, Diagonal(ones(3)))[1].

  • Some other packages get this right, such as TaylorSeries.jl, this C++ code, and this Ruby code. Some get it wrong (according to me), like this Mathematica code, this paper with Matlab, and this blog post (although its author changed his mind from right to wrong). More mathematical treatments seem to regard the tuple (x, δx) as inheriting == from tuples, i.e. they get it right.

  • Similar things were also discussed in #377 (Problem with abs), where the example is this:

sq3(z) = abs(-z^2)
sq4(z) = abs2(z)
sq5(z) = z^2

[ForwardDiff.derivative(f, 0.0) for f in [sq, sq2, sq3, sq4, sq5]] == [0,0,0,0,0]
[ForwardDiff.hessian(x -> f(x[1]), [0.0])[1] for f in [sq, sq2, sq3, sq4, sq5]] == [2,2,-2,2,2]

A rule was suggested there in which x > y behaves differently for x.value == y.value, breaking such ties by comparing x.partials > y.partials. In the clamp2 example, whether you get stuck against the wall presumably shouldn't depend on whether you minimise loss(x) or maximise -loss(x), so we probably don't want to compare x.partials .> 0 when only x is a dual number. But the rule when both x and y are dual might be worth some more thought.
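
A rough sketch of what such a tie-breaking comparison might look like, written here only for Duals with a single partial (gt_tiebreak is hypothetical, not a rule defined by the package):

using ForwardDiff: Dual, value, partials

# Break a tie in the primal values by comparing the perturbations instead:
function gt_tiebreak(x::Dual, y::Dual)
    vx, vy = value(x), value(y)
    vx == vy ? partials(x, 1) > partials(y, 1) : vx > vy
end

gt_tiebreak(Dual(1.0, 2.0), Dual(1.0, 1.0))  # true: equal values, tie broken by the partials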

@mcabbott
Member Author

mcabbott commented Nov 27, 2020

There's another example here: TuringLang/DistributionsAD.jl#23 (comment). Behaviour on tagged ForwardDiff:

using ForwardDiff, LinearAlgebra
A = Matrix(I, 2,2)

ForwardDiff.gradient(x -> sum(x \ [1,3]), A)           # [-1  0; 0 -3]
ForwardDiff.gradient(x -> sum(cholesky(x) \ [1,3]), A) # [-1 -4; 0 -3] 

Aplus2 = [1 0; 0.001 1]  # perturb A[2]
sum(Aplus2 \ [1,3]) - sum(A \ [1,3])                      # -0.001
sum(cholesky(Aplus2) \ [1,3]) - sum(cholesky(A) \ [1,3])  # PosDefException: matrix is not Hermitian

Aplus23 = [1 0.001; 0.001 1]  # perturb A[2] and A[3] together
sum(Aplus23 \ [1,3]) - sum(A \ [1,3])                     # -0.004
sum(cholesky(Aplus23) \ [1,3]) - sum(cholesky(A) \ [1,3]) # -0.004

Perhaps one could claim the cholesky result is a particular sub-gradient. But will the algorithm always behave this way? I don't know. The one for simple \ is silently wrong --- or at least, it has promoted A to be structurally diagonal. The Jacobians are:

ForwardDiff.jacobian(x -> x \ [1,3], A)
 # -1.0  -0.0  -0.0  -0.0
 # -0.0  -0.0  -0.0  -3.0

ForwardDiff.jacobian(x -> cholesky(x) \ [1,3], A)
 # -1.0  0.0  -3.0   0.0
 #  0.0  0.0  -1.0  -3.0

With the proposed change to ==, the results instead match finite differences. Here \ may choose a different algorithm, but this shouldn't cause a discontinuity of the forward pass.

ForwardDiff.gradient(x -> sum(x \ [1,3]), A)           # [-1 -3; -1 -3]
ForwardDiff.gradient(x -> sum(cholesky(x) \ [1,3]), A) # ERROR: PosDefException

ForwardDiff.jacobian(x -> x \ [1,3], A)
 # -1.0   0.0  -3.0   0.0
 #  0.0  -1.0   0.0  -3.0

Aplus2 = [1 0; 0.001 1]
Aplus2 \ [1, 3] - A \ [1, 3]  # [0, -0.001]

Aplus3 = [1 0.001; 0 1]
Aplus3 \ [1, 3] - A \ [1, 3]  # [-0.003, 0]

The example just above that, in TuringLang/DistributionsAD.jl#23 (comment), also gives an error on this branch, also from cholesky. As does a finite perturbation.

@mateuszbaran

I didn't fully think it through, but wouldn't it break functions like f(x) = x == 0.0 ? 1.0 : sin(x)/x? What if I replaced == 0.0 with iszero or ≈ 0.0?

@mateuszbaran

Right now, if I wanted correct derivatives of such functions up to a certain order, I'd just put a Taylor approximation of the function in the "measure zero branch".

@mcabbott
Member Author

This f would indeed change. However, it has problems on a region bigger than one point, and so it would be better anyway for the smoothed branch to be used within some finite interval. For example:

julia> fpi(x) = x==0 ? 1.0 : sin(pi*x) / (pi*x);

julia> fpi(1e-40)
1.0

julia> ForwardDiff.derivative(fpi, 1e-40)
1.2089258196146292e24

This is a variant of #466, for which the solution is

julia> cosc(1e-40) 
-3.2898681336964526e-40

Even if you don't have an exact closed-form derivative, I think that you would usually want to replace f(x) with a constant / polynomial within some small interval, not at a single point. Perhaps always; are there exceptions?
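
For example, a sketch of that idea for sin(x)/x; sinc_safe is a made-up name, and the window sqrt(eps()) and the order of the polynomial are arbitrary choices for the sketch:

# Use the Taylor polynomial on a small interval around zero, not only at the
# single point x == 0 (window hard-coded for Float64 here):
sinc_safe(x) = abs(x) < sqrt(eps()) ? 1 - x^2/6 : sin(x)/x

ForwardDiff.derivative(sinc_safe, 0.0)    # 0.0
ForwardDiff.derivative(sinc_safe, 1e-40)  # ≈ -3.3e-41, no blow-up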

Since the fallback is iszero(x) = x == zero(x) I'd be hesitant to mess with that, but isapprox could be given a different behaviour for dual numbers. At zero it doesn't by default widen the interval:

julia> ForwardDiff.Dual(1e-40,0) ≈ 0
false

julia> nextfloat(0.0) ≈ 0.0
false

@mcabbott
Member Author

One more data point is that https://github.com/JuliaPhysics/Measurements.jl gets this right according to my argument above. While its error bars aren't exactly dual numbers, they are a related species. Here's a version of #536 in which mul! does not take the shortcut:

julia> using Measurements

julia> λ = measurement(0, 0.1)
0.0 ± 0.1

julia> iszero(λ)
false

julia> A = measurement.([1 2; 3 4], 0.1);

julia> B = measurement.([5 6; 7 8], 0.1);

julia> A * (B * λ)
2×2 Matrix{Measurement{Float64}}:
 0.0±1.9  0.0±2.2
 0.0±4.3  0.0±5.0

julia> mul!(similar(A), A, B, λ, 0)  # like issue 536
2×2 Matrix{Measurement{Float64}}:
 0.0±1.9  0.0±2.2
 0.0±4.3  0.0±5.0

@ChrisRackauckas
Member

The idea that the forward evaluation is unchanged is often thought of as an axiom of AD

And it's definitely a false one. It's an axiom for computer scientists, but not for numerical analysts 😅. A nice example is in the space of ODEs. Automatic differentiation is equivalent to solving the expanded ODE known as the forward sensitivity equations, essentially:

u' = f(u,p,t)
d/dt (du/dp) = df/du du/dp + df/dp

Straight automatic differentiation is equivalent to solving the expanded ODE with the adaptive error controls only applying to the first part of the equation u' = f(u,p,t). Are there ODEs for which the second part is unstable when adaptivity is only applied to the first part? Yes. SciML/SciMLSensitivity.jl#273 is a real-world example where this came up. The solution was SciML/SciMLSensitivity.jl#273, i.e. the norm used in the ODE solver has to account for the pseudo-ODEs if you want it to be stable, and so the default norm that is used adds the partials to the primal part.

https://github.com/SciML/DiffEqBase.jl/blob/v6.83.1/src/forwarddiff.jl#L31-L34
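
Roughly, the idea there is a norm that also sees the partials; a minimal sketch (dual_norm is a sketch name, not the package's actual definition — see the link above for the real code):

using ForwardDiff: Dual, value, partials

# Error norm that includes the derivative components, so adaptive stepping
# also reacts to instability in the sensitivity part of the expanded ODE:
dual_norm(u::Dual) = sqrt(value(u)^2 + sum(abs2, partials(u)))
dual_norm(u::Real) = abs(u)
dual_norm(u::AbstractArray) = sqrt(sum(x -> dual_norm(x)^2, u) / length(u))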

This means that solving with ForwardDiff gives different stepping behavior, but if you don't do that, then there will be cases where you get an "infinite" derivative because of numerical instability in the derivative calculation, even when the primal is stable. So definitely, this axiom does not hold for the realities of numerical computing.

So back to the core of the thread, I definitely agree with you. In fact, DiffEq specializes its interpolation computation in order to work around this kind of issue. Normally it would just pull sol.u[i] if the requested time matches sol.t[i] exactly, but it needs to still use the interpolation if the time is a dual number, since otherwise the derivative is zero. If it had this epsilon-ball definition, the workaround to force Dual numbers to not take the sol.t[i] == t branch would be eliminated.
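
Schematically, the workaround is something like this (interp_at is hypothetical code, not DiffEq's actual implementation):

# Only take the exact-hit shortcut when the time carries no derivative
# information; a Dual time must go through the interpolant, otherwise its
# derivative with respect to t silently becomes zero.
function interp_at(sol_t, sol_u, t, interpolant)
    i = findfirst(==(t), sol_t)
    if i !== nothing && !(t isa ForwardDiff.Dual)
        return sol_u[i]        # exact hit on a stored step
    end
    return interpolant(t)      # general case
end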

@mcabbott
Member Author

Perhaps it's worth noting that many other AD systems have the same problem. On the example above:

julia> Zygote.gradient(prod1, [1,2,0,4,0,6])
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> Zygote.gradient(prod2, [1,2,0,4,0,6])  # wrong
([0.0, 0.0, 2.0, 0.0, 0.0, 0.0],)

julia> Tracker.gradient(prod1, [1,2,0,4,0,6])  # (after removing the ::Vector restriction)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked),)

julia> Tracker.gradient(prod2, [1,2,0,4,0,6])  # wrong
([0.0, 0.0, 2.0, 0.0, 0.0, 0.0] (tracked),)

julia> dx = zeros(6); Enzyme.autodiff(prod1, Duplicated([1,2,0,4,0,6.], dx)); (dx,)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> dx = zeros(6); Enzyme.autodiff(prod2, Duplicated([1,2,0,4,0,6.], dx)); (dx,)  # wrong
([0.0, 0.0, 2.0, 0.0, 0.0, 0.0],)

I think Zygote cannot fix this, as it does not know which variables are active when it transforms code. Tracker could surely change == for TrackedReal much like #481 here. Enzyme does know about activity; the relevant issue there is EnzymeAD/Enzyme.jl#114.

@devmotion
Member

I think @wsmoses's comment is quite interesting. I've never viewed it that way but I guess one could say that all these AD systems are correct as they correctly return the gradient of the function that is implemented/defined by prod2 - and they seem incorrect only if one considers prod2 just as an optimization of prod1 and actually would like to obtain the gradients of the function implemented by prod1. In any case, I guess these differences are surprising for most users who, I assume, don't expect to see any difference here and have not thought about the implications of different implementations for AD.

@bvdmitri

bvdmitri commented Jan 26, 2023

It would be best if it were only surprising, but this behaviour is at the least frustrating for users. From this discussion and a discussion on Slack it turned out that this is a known issue and a property of many other AD backends, but for some reason it is not communicated well to end users, who may rely on this in critical systems or extensive simulations. That is important, because it happens not only in toy examples but in real code as well. A good example from our case is dot(x, A * x) versus dot(x, A, x), which produce different Hessians if x is a zeroed vector (and the 3-argument function produces not only a different but a completely incorrect result too). This happens all the time when you evaluate the Gaussian logpdf at its mean. And the dot function is not even written by us, but has been taken from Julia's LinearAlgebra. Distributions.jl is not affected by pure luck, because PDMats uses the 2-argument dot function.
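
For concreteness, a minimal sketch of that dot example (the matrix A here is just an arbitrary symmetric example chosen for illustration):

using ForwardDiff, LinearAlgebra

A = [2.0 1.0; 1.0 2.0]

# As described above, these disagree when evaluated at a zero vector, because
# the generic 3-argument dot skips terms where an entry compares equal to zero:
ForwardDiff.hessian(x -> dot(x, A * x), zeros(2))   # A + A'
ForwardDiff.hessian(x -> dot(x, A, x),  zeros(2))   # reported to be wrong at x = 0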

These kinds of limitations should be communicated better (e.g. in the documentation) to end users, who indeed have not thought about the implications of different implementations for AD. AD systems position themselves as a faster and more accurate alternative to finite differences, but do not document clear (known!) pitfalls. That is bad, and it is not specific to ForwardDiff. ForwardDiff is an amazing library, but why should a user have to start thinking about the potential implications of different AD backends if there is no indication that something may go wrong in the first place? That's an open question, of course.

It's great that ForwardDiff has a solution, and it looks like the fix has been merged into master. I'm looking forward to the fix being released.
