From 62d557bcd51288091a20a284752b200ab2721075 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 27 Apr 2022 17:01:47 +0200 Subject: [PATCH] Fix DiffRules-based definitions for complex-valued functions (#577) * Fix DiffRules-based definitions for complex-valued functions * Update tests * Update Project.toml --- Project.toml | 2 +- src/dual.jl | 40 ++++++++++++++++++++++++++++++++++++---- test/DualTest.jl | 48 ++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 75 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 61e31efa..31508c9c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.26" +version = "0.10.27" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" diff --git a/src/dual.jl b/src/dual.jl index 08199bc4..c3dc48a0 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -195,6 +195,38 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b return esc(defs) end +# Support complex-valued functions such as `hankelh1` +function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T} + return Dual{T}(val, deriv * partial) +end +function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T} + return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2)) +end +function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T} + reval, imval = reim(val) + if deriv isa Real + p = deriv * partial + return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p))) + else + rederiv, imderiv = reim(deriv) + return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial)) + end +end +function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T} + reval, imval = reim(val) + if deriv1 isa Real && deriv2 isa Real + p = _mul_partials(partial1, partial2, deriv1, deriv2) + return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p))) + else + rederiv1, imderiv1 = reim(deriv1) + rederiv2, imderiv2 = reim(deriv2) + return Complex( + Dual{T}(reval, _mul_partials(partial1, partial2, rederiv1, rederiv2)), + Dual{T}(imval, _mul_partials(partial1, partial2, imderiv1, imderiv2)), + ) + end +end + function unary_dual_definition(M, f) FD = ForwardDiff Mf = M == :Base ? f : :($M.$f) @@ -206,7 +238,7 @@ function unary_dual_definition(M, f) @inline function $M.$f(d::$FD.Dual{T}) where T x = $FD.value(d) $work - return $FD.Dual{T}(val, deriv * $FD.partials(d)) + return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d)) end end end @@ -236,17 +268,17 @@ function binary_dual_definition(M, f) begin vx, vy = $FD.value(x), $FD.value(y) $xy_work - return $FD.Dual{Txy}(val, $FD._mul_partials($FD.partials(x), $FD.partials(y), dvx, dvy)) + return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y)) end, begin vx = $FD.value(x) $x_work - return $FD.Dual{Tx}(val, dvx * $FD.partials(x)) + return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x)) end, begin vy = $FD.value(y) $y_work - return $FD.Dual{Ty}(val, dvy * $FD.partials(y)) + return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y)) end ) end diff --git a/test/DualTest.jl b/test/DualTest.jl index 3cbbaba0..fd5ac210 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -440,7 +440,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32) if V != Int for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) - if f in (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi) + if f in (:/, :rem2pi) continue # Skip these rules elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) continue # Skip rules for methods not defined in the current scope @@ -457,9 +457,20 @@ for N in (0,3), M in (0,4), V in (Int, Float32) end @eval begin x = rand() + $modifier - dx = $M.$f(Dual{TestTag()}(x, one(x))) - @test value(dx) == $M.$f(x) - @test partials(dx, 1) == $deriv + dx = @inferred $M.$f(Dual{TestTag()}(x, one(x))) + actualval = $M.$f(x) + @assert actualval isa Real || actualval isa Complex + if actualval isa Real + @test dx isa Dual{TestTag()} + @test value(dx) == actualval + @test partials(dx, 1) == $deriv + else + @test dx isa Complex{<:Dual{TestTag()}} + @test value(real(dx)) == real(actualval) + @test value(imag(dx)) == imag(actualval) + @test partials(real(dx), 1) == real($deriv) + @test partials(imag(dx), 1) == imag($deriv) + end end elseif arity == 2 derivs = DiffRules.diffrule(M, f, :x, :y) @@ -472,14 +483,31 @@ for N in (0,3), M in (0,4), V in (Int, Float32) end @eval begin x, y = $x, $y - dx = $M.$f(Dual{TestTag()}(x, one(x)), y) - dy = $M.$f(x, Dual{TestTag()}(y, one(y))) + dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)), y) + dy = @inferred $M.$f(x, Dual{TestTag()}(y, one(y))) actualdx = $(derivs[1]) actualdy = $(derivs[2]) - @test value(dx) == $M.$f(x, y) - @test value(dy) == value(dx) - @test partials(dx, 1) ≈ actualdx nans=true - @test partials(dy, 1) ≈ actualdy nans=true + actualval = $M.$f(x, y) + @assert actualval isa Real || actualval isa Complex + if actualval isa Real + @test dx isa Dual{TestTag()} + @test dy isa Dual{TestTag()} + @test value(dx) == actualval + @test value(dy) == actualval + @test partials(dx, 1) ≈ actualdx nans=true + @test partials(dy, 1) ≈ actualdy nans=true + else + @test dx isa Complex{<:Dual{TestTag()}} + @test dy isa Complex{<:Dual{TestTag()}} + @test real(value(dx)) == real(actualval) + @test real(value(dy)) == real(actualval) + @test imag(value(dx)) == imag(actualval) + @test imag(value(dy)) == imag(actualval) + @test partials(real(dx), 1) ≈ real(actualdx) nans=true + @test partials(real(dy), 1) ≈ real(actualdy) nans=true + @test partials(imag(dx), 1) ≈ imag(actualdx) nans=true + @test partials(imag(dy), 1) ≈ imag(actualdy) nans=true + end end end end