Avoid NaN (co)tangents for sqrt(0) #599

Open
wants to merge 2 commits into
base: main
18 changes: 17 additions & 1 deletion src/rulesets/Base/fastmath_able.jl
@@ -52,7 +52,23 @@ let
     # exponents
     @scalar_rule cbrt(x) inv(3 * Ω ^ 2)
     @scalar_rule inv(x) -(Ω ^ 2)
-    @scalar_rule sqrt(x) inv(2Ω)  # gradient +Inf at x==0
+    # ensure that at sqrt(0), a zero (co)tangent produces a zero (co)tangent
+    function frule((_, Δx), ::typeof(sqrt), x::Number)
+        Ω = sqrt(x)
+        ∂Ω = Δx / 2Ω
+        return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
Member

Is there a specific reason to use ifelse instead of a ternary operator (which does not require evaluating both branches)?
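
For readers unfamiliar with the distinction, here is a minimal sketch of the difference being asked about. `ifelse` is an ordinary function, so both value arguments are evaluated before the call, whereas the ternary operator only evaluates the selected branch. The `noisy` helper and the `calls` log are hypothetical names used only for this illustration.

```julia
# Hypothetical helper: records its label whenever it is evaluated.
calls = String[]
noisy(label, value) = (push!(calls, label); value)

cond = true

# Ternary operator: only the selected branch is evaluated.
empty!(calls)
t = cond ? noisy("a", 1.0) : noisy("b", 2.0)
ternary_calls = copy(calls)

# ifelse: a regular function call, so both arguments are evaluated first.
empty!(calls)
e = ifelse(cond, noisy("a", 1.0), noisy("b", 2.0))
ifelse_calls = copy(calls)
```

Both forms return the same value here; they differ only in how many of the branch expressions run.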

Member Author

My reasoning was that, as long as the type of ∂Ω is inferrable, the two branches do no extra work, and using & and ifelse could perform better if this rule is called in an inner loop. It could also allow Zygote to perform better for higher-order AD, since Zygote tends to be slow when it hits control flow but has a special rule for ifelse. However, I was unable to devise a benchmark that showed a substantial difference.

Member

I know that in older Julia versions there were cases in SciML where we could improve performance by avoiding zero or moving it out of loops. I couldn't immediately reproduce this with a simple example, though, so maybe it's not relevant here and/or has been fixed in recent Julia versions.

Member

@devmotion devmotion Mar 13, 2022

An example where it matters:

julia> using BenchmarkTools

julia> function f(x)
           s = zero(x)
           for i in 1:10
               s += iseven(i) ? zero(x) : x
           end
           return s
       end
f (generic function with 1 method)

julia> function g(x)
           s = zero(x)
           for i in 1:10
               s += ifelse(iseven(i), zero(x), x)
           end
           return s
       end
g (generic function with 1 method)

julia> @btime f($(big"1.0"));
  45.640 μs (3002 allocations: 164.17 KiB)

julia> @btime g($(big"1.0"));
  56.341 μs (4002 allocations: 218.86 KiB)
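
As a sketch of the "moving zero out of loops" idea mentioned earlier in this thread, the zero can be computed once before the loop and reused in `ifelse`. The function name `g_hoisted` is hypothetical and was not part of the discussion. For a heap-allocated type such as `BigFloat` this might avoid the repeated `zero(x)` calls, but no timings were measured for this variant.

```julia
# Hypothetical variant of g: hoist zero(x) out of the loop so it is
# computed once rather than on every iteration.
function g_hoisted(x)
    z = zero(x)   # single zero value, reused below
    s = zero(x)
    for i in 1:10
        # ifelse still evaluates both arguments, but both are now plain variables
        s += ifelse(iseven(i), z, x)
    end
    return s
end

result_float = g_hoisted(2.0)        # adds x on the five odd iterations
result_big   = g_hoisted(big"1.0")
```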

+    end
+    function rrule(::typeof(sqrt), x::Number)
+        Ω = sqrt(x)
+        function sqrt_pullback(ΔΩ)
+            ∂x = ΔΩ / 2conj(Ω)
+            return (
+                NoTangent(),
+                ProjectTo(x)(ifelse(iszero(ΔΩ) & iszero(x), zero(∂x), ∂x))
+            )
+        end
+        return Ω, sqrt_pullback
+    end
     @scalar_rule exp(x) Ω
     @scalar_rule exp10(x) logten * Ω
     @scalar_rule exp2(x) logtwo * Ω
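
To illustrate the behavior this change targets, here is a minimal sketch in plain Julia, without ChainRules. The function names `naive_sqrt_tangent` and `guarded_sqrt_tangent` are hypothetical stand-ins; the guarded version only mirrors the `ifelse` expression from the diff and is not the package's actual rule.

```julia
# Naive rule: the tangent of sqrt at x is Δx / (2 * sqrt(x)).
# At x == 0 with Δx == 0 this computes 0.0 / 0.0, which is NaN.
naive_sqrt_tangent(x, Δx) = Δx / (2 * sqrt(x))

# Guarded rule (illustrative): return a zero tangent when both the
# input and the incoming tangent are zero; otherwise keep the usual value.
function guarded_sqrt_tangent(x, Δx)
    ∂Ω = Δx / (2 * sqrt(x))
    return ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
end

naive_at_zero         = naive_sqrt_tangent(0.0, 0.0)    # 0.0 / 0.0
guarded_at_zero       = guarded_sqrt_tangent(0.0, 0.0)  # guard applies
guarded_nonzero_input = guarded_sqrt_tangent(0.0, 1.0)  # 1.0 / 0.0, guard does not apply
guarded_regular       = guarded_sqrt_tangent(4.0, 1.0)  # ordinary point
```

The guard only changes the result when both the input and the tangent are zero, so a nonzero tangent at `x == 0` still yields a non-finite value, matching the `!isfinite` checks in the tests.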
13 changes: 13 additions & 0 deletions test/rulesets/Base/fastmath_able.jl
@@ -90,6 +90,19 @@ const FASTABLE_AST = quote
         end
     end

+    # https://github.com/JuliaDiff/ChainRules.jl/issues/576
+    @testset "sqrt(0)" begin
+        @testset for T in (Float64, ComplexF64)
+            z = zero(T)
+            @test frule((NoTangent(), z), sqrt, z)[2] === z
+            @test frule((NoTangent(), ZeroTangent()), sqrt, z)[2] === ZeroTangent()
+            @test !isfinite(frule((NoTangent(), one(z)), sqrt, z)[2])
+            @test rrule(sqrt, z)[2](z)[2] === z
+            @test rrule(sqrt, z)[2](ZeroTangent())[2] === ZeroTangent()
+            @test !isfinite(rrule(sqrt, z)[2](one(z))[2])
+        end
+    end
+
     @testset "Unary complex functions" begin
         for f ∈ (abs, abs2, conj), z ∈ (-4.1-0.02im, 6.4, 3 + im)
             @testset "Unary complex functions f = $f, z = $z" begin