From 55be25d3b7824aeeacc2b25c05fa8083ca395fe8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 11 Mar 2022 23:25:39 +0100 Subject: [PATCH 1/2] Special-case sqrt(0) --- src/rulesets/Base/fastmath_able.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 552392cb7..cb4ed7ac3 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -52,7 +52,23 @@ let # exponents @scalar_rule cbrt(x) inv(3 * Ω ^ 2) @scalar_rule inv(x) -(Ω ^ 2) - @scalar_rule sqrt(x) inv(2Ω) # gradient +Inf at x==0 + # ensure that at sqrt(0), a zero (co)tangent produces a zero (co)tangent + function frule((_, Δx), ::typeof(sqrt), x::Number) + Ω = sqrt(x) + ∂Ω = Δx / 2Ω + return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω) + end + function rrule(::typeof(sqrt), x::Number) + Ω = sqrt(x) + function sqrt_pullback(ΔΩ) + ∂x = ΔΩ / 2conj(Ω) + return ( + NoTangent(), + ProjectTo(x)(ifelse(iszero(ΔΩ) & iszero(x), zero(∂x), ∂x)) + ) + end + return Ω, sqrt_pullback + end @scalar_rule exp(x) Ω @scalar_rule exp10(x) logten * Ω @scalar_rule exp2(x) logtwo * Ω From e43c03bef8f6b50e1f6c7856d3b6597d966c56cb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 11 Mar 2022 23:25:55 +0100 Subject: [PATCH 2/2] Test sqrt(0) --- test/rulesets/Base/fastmath_able.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 68b0e2f67..843d7f834 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -90,6 +90,19 @@ const FASTABLE_AST = quote end end + # https://github.com/JuliaDiff/ChainRules.jl/issues/576 + @testset "sqrt(0)" begin + @testset for T in (Float64, ComplexF64) + z = zero(T) + @test frule((NoTangent(), z), sqrt, z)[2] === z + @test frule((NoTangent(), ZeroTangent()), sqrt, z)[2] === ZeroTangent() + @test !isfinite(frule((NoTangent(), one(z)), sqrt, z)[2]) + @test rrule(sqrt, z)[2](z)[2] === z + @test rrule(sqrt, z)[2](ZeroTangent())[2] === ZeroTangent() + @test !isfinite(rrule(sqrt, z)[2](one(z))[2]) + end + end + @testset "Unary complex functions" begin for f ∈ (abs, abs2, conj), z ∈ (-4.1-0.02im, 6.4, 3 + im) @testset "Unary complex functions f = $f, z = $z" begin