From 482dde05b7c780f25e1fc162d3f14984f8040442 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Thu, 12 Oct 2023 13:11:59 -0500 Subject: [PATCH] Resolve `NaN` in hypot near zero This checks whether `hypot` of the value component is zero, and if so switches to a next-order method. --- src/dual.jl | 9 ++++++++- test/MiscTest.jl | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/dual.jl b/src/dual.jl index 2ca4683f..1b89da48 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -587,11 +587,18 @@ end #-------# @inline function calc_hypot(x, y, z, ::Type{T}) where T + pm1(x) = signbit(x) ? -1 : 1 + vx = value(x) vy = value(y) vz = value(z) h = hypot(vx, vy, vz) - p = (vx / h) * partials(x) + (vy / h) * partials(y) + (vz / h) * partials(z) + p = if iszero(h) + hp = sqrt(sum(abs2, partials(x)) + sum(abs2, partials(y)) + sum(abs2, partials(z))) + pm1(vx) * partials(x) / hp + pm1(vy) * partials(y) / hp + pm1(vz) * partials(z) / hp + else + (vx / h) * partials(x) + (vy / h) * partials(y) + (vz / h) * partials(z) + end return Dual{T}(h, p) end diff --git a/test/MiscTest.jl b/test/MiscTest.jl index 0ed8039a..99aaaeee 100644 --- a/test/MiscTest.jl +++ b/test/MiscTest.jl @@ -140,6 +140,9 @@ let i, j i != j && @test h[i, j] ≈ 0.0 end end +hypotx(x) = hypot(x, 0.0, 0.0) +@test ForwardDiff.derivative(hypotx, 0.0) ≈ 1 +@test ForwardDiff.derivative(hypotx, -0.0) ≈ -1 ######## # misc #