Skip to content

Commit

Permalink
Fix DiffRules-based definitions for complex-valued functions (#577)
Browse files Browse the repository at this point in the history
* Fix DiffRules-based definitions for complex-valued functions

* Update tests

* Update Project.toml
  • Loading branch information
devmotion authored Apr 27, 2022
1 parent e11936c commit 62d557b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.26"
version = "0.10.27"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand Down
40 changes: 36 additions & 4 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,38 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b
return esc(defs)
end

# Support complex-valued functions such as `hankelh1`
function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T}
return Dual{T}(val, deriv * partial)
end
function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T}
return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2))
end
function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T}
reval, imval = reim(val)
if deriv isa Real
p = deriv * partial
return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
else
rederiv, imderiv = reim(deriv)
return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial))
end
end
function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T}
reval, imval = reim(val)
if deriv1 isa Real && deriv2 isa Real
p = _mul_partials(partial1, partial2, deriv1, deriv2)
return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
else
rederiv1, imderiv1 = reim(deriv1)
rederiv2, imderiv2 = reim(deriv2)
return Complex(
Dual{T}(reval, _mul_partials(partial1, partial2, rederiv1, rederiv2)),
Dual{T}(imval, _mul_partials(partial1, partial2, imderiv1, imderiv2)),
)
end
end

function unary_dual_definition(M, f)
FD = ForwardDiff
Mf = M == :Base ? f : :($M.$f)
Expand All @@ -206,7 +238,7 @@ function unary_dual_definition(M, f)
@inline function $M.$f(d::$FD.Dual{T}) where T
x = $FD.value(d)
$work
return $FD.Dual{T}(val, deriv * $FD.partials(d))
return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d))
end
end
end
Expand Down Expand Up @@ -236,17 +268,17 @@ function binary_dual_definition(M, f)
begin
vx, vy = $FD.value(x), $FD.value(y)
$xy_work
return $FD.Dual{Txy}(val, $FD._mul_partials($FD.partials(x), $FD.partials(y), dvx, dvy))
return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y))
end,
begin
vx = $FD.value(x)
$x_work
return $FD.Dual{Tx}(val, dvx * $FD.partials(x))
return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x))
end,
begin
vy = $FD.value(y)
$y_work
return $FD.Dual{Ty}(val, dvy * $FD.partials(y))
return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y))
end
)
end
Expand Down
48 changes: 38 additions & 10 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)

if V != Int
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if f in (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)
if f in (:/, :rem2pi)
continue # Skip these rules
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope
Expand All @@ -457,9 +457,20 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
end
@eval begin
x = rand() + $modifier
dx = $M.$f(Dual{TestTag()}(x, one(x)))
@test value(dx) == $M.$f(x)
@test partials(dx, 1) == $deriv
dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)))
actualval = $M.$f(x)
@assert actualval isa Real || actualval isa Complex
if actualval isa Real
@test dx isa Dual{TestTag()}
@test value(dx) == actualval
@test partials(dx, 1) == $deriv
else
@test dx isa Complex{<:Dual{TestTag()}}
@test value(real(dx)) == real(actualval)
@test value(imag(dx)) == imag(actualval)
@test partials(real(dx), 1) == real($deriv)
@test partials(imag(dx), 1) == imag($deriv)
end
end
elseif arity == 2
derivs = DiffRules.diffrule(M, f, :x, :y)
Expand All @@ -472,14 +483,31 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
end
@eval begin
x, y = $x, $y
dx = $M.$f(Dual{TestTag()}(x, one(x)), y)
dy = $M.$f(x, Dual{TestTag()}(y, one(y)))
dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)), y)
dy = @inferred $M.$f(x, Dual{TestTag()}(y, one(y)))
actualdx = $(derivs[1])
actualdy = $(derivs[2])
@test value(dx) == $M.$f(x, y)
@test value(dy) == value(dx)
@test partials(dx, 1) actualdx nans=true
@test partials(dy, 1) actualdy nans=true
actualval = $M.$f(x, y)
@assert actualval isa Real || actualval isa Complex
if actualval isa Real
@test dx isa Dual{TestTag()}
@test dy isa Dual{TestTag()}
@test value(dx) == actualval
@test value(dy) == actualval
@test partials(dx, 1) actualdx nans=true
@test partials(dy, 1) actualdy nans=true
else
@test dx isa Complex{<:Dual{TestTag()}}
@test dy isa Complex{<:Dual{TestTag()}}
@test real(value(dx)) == real(actualval)
@test real(value(dy)) == real(actualval)
@test imag(value(dx)) == imag(actualval)
@test imag(value(dy)) == imag(actualval)
@test partials(real(dx), 1) real(actualdx) nans=true
@test partials(real(dy), 1) real(actualdy) nans=true
@test partials(imag(dx), 1) imag(actualdx) nans=true
@test partials(imag(dy), 1) imag(actualdy) nans=true
end
end
end
end
Expand Down

2 comments on commit 62d557b

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/59240

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.27 -m "<description of version>" 62d557bcd51288091a20a284752b200ab2721075
git push origin v0.10.27

Please sign in to comment.