Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve Symbolics.Num - ForwardDiff.Dual ambiguities #1036

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsPreallocationToolsExt = ["ForwardDiff", "PreallocationTools"]
SymbolicsForwardDiffExt = "ForwardDiff"
SymbolicsSymPyExt = "SymPy"

[compat]
Expand Down
256 changes: 256 additions & 0 deletions ext/SymbolicsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
module SymbolicsForwardDiffExt

using ForwardDiff
using ForwardDiff.NaNMath
using ForwardDiff.DiffRules
using ForwardDiff: value, Dual, partials
using Symbolics

# The method generation in this file have been adapted from
# https://github.com/JuliaDiff/ForwardDiff.jl/blob/v0.10.36/src/dual.jl

const AMBIGUOUS_TYPES = (Num,)

####################################
# N-ary Operation Definition Tools #
####################################

macro define_binary_dual_op(f, xy_body, x_body, y_body, Ts)
FD = ForwardDiff
defs = quote end
for R in Ts
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}) where {Ty} = $y_body
end
append!(defs.args, expr.args)
end
return esc(defs)
end

macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body, Ts)
FD = ForwardDiff
defs = quote end
for R in Ts
expr = quote
@inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body
@inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body
@inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body
@inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body
end
append!(defs.args, expr.args)
for Q in Ts
Q === R && continue
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body
@inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body
end
append!(defs.args, expr.args)
end
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body
@inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body
end
append!(defs.args, expr.args)
end
return esc(defs)
end

function binary_dual_definition(M, f, Ts)
FD = ForwardDiff
dvx, dvy = DiffRules.diffrule(M, f, :vx, :vy)
Mf = M == :Base ? f : :($M.$f)
xy_work = FD.qualified_cse!(quote
val = $Mf(vx, vy)
dvx = $dvx
dvy = $dvy
end)
dvx, _ = DiffRules.diffrule(M, f, :vx, :y)
x_work = FD.qualified_cse!(quote
val = $Mf(vx, y)
dvx = $dvx
end)
_, dvy = DiffRules.diffrule(M, f, :x, :vy)
y_work = FD.qualified_cse!(quote
val = $Mf(x, vy)
dvy = $dvy
end)
expr = quote
@define_binary_dual_op(
$M.$f,
begin
vx, vy = $FD.value(x), $FD.value(y)
$xy_work
return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y))
end,
begin
vx = $FD.value(x)
$x_work
return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x))
end,
begin
vy = $FD.value(y)
$y_work
return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y))
end,
$Ts
)
end
return expr
end

###################################
# General Mathematical Operations #
###################################

for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos))
continue # Skip methods which we define elsewhere.
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope
end
if arity == 1
# no-op
elseif arity == 2
eval(binary_dual_definition(M, f, AMBIGUOUS_TYPES))
else
# no-op
end
end

#################
# Special Cases #
#################

# +/- #
#-----#

@eval begin
@define_binary_dual_op(
Base.:+,
begin
vx, vy = value(x), value(y)
Dual{Txy}(vx + vy, partials(x) + partials(y))
end,
Dual{Tx}(value(x) + y, partials(x)),
Dual{Ty}(x + value(y), partials(y)),
$AMBIGUOUS_TYPES
)
end

@eval begin
@define_binary_dual_op(
Base.:-,
begin
vx, vy = value(x), value(y)
Dual{Txy}(vx - vy, partials(x) - partials(y))
end,
Dual{Tx}(value(x) - y, partials(x)),
Dual{Ty}(x - value(y), -partials(y)),
$AMBIGUOUS_TYPES
)
end

# / #
#---#

# We can't use the normal diffrule autogeneration for this because (x/y) === (x * (1/y))
# doesn't generally hold true for floating point; see issue #264
@eval begin
@define_binary_dual_op(
Base.:/,
begin
vx, vy = value(x), value(y)
Dual{Txy}(vx / vy, _div_partials(partials(x), partials(y), vx, vy))
end,
Dual{Tx}(value(x) / y, partials(x) / y),
begin
v = value(y)
divv = x / v
Dual{Ty}(divv, -(divv / v) * partials(y))
end,
$AMBIGUOUS_TYPES
)
end

# exponentiation #
#----------------#

for f in (:(Base.:^), :(NaNMath.pow))
@eval begin
@define_binary_dual_op(
$f,
begin
vx, vy = value(x), value(y)
expv = ($f)(vx, vy)
powval = vy * ($f)(vx, vy - 1)
if isconstant(y)
logval = one(expv)
elseif iszero(vx) && vy > 0
logval = zero(vx)
else
logval = expv * log(vx)
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
end,
begin
v = value(x)
expv = ($f)(v, y)
if y == zero(y) || iszero(partials(x))
new_partials = zero(partials(x))
else
new_partials = partials(x) * y * ($f)(v, y - 1)
end
return Dual{Tx}(expv, new_partials)
end,
begin
v = value(y)
expv = ($f)(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
return Dual{Ty}(expv, deriv * partials(y))
end,
$AMBIGUOUS_TYPES
)
end
end

# hypot #
#-------#

@eval begin
@define_ternary_dual_op(
Base.hypot,
calc_hypot(x, y, z, Txyz),
calc_hypot(x, y, z, Txy),
calc_hypot(x, y, z, Txz),
calc_hypot(x, y, z, Tyz),
calc_hypot(x, y, z, Tx),
calc_hypot(x, y, z, Ty),
calc_hypot(x, y, z, Tz),
$AMBIGUOUS_TYPES
)
end

# muladd #
#--------#

@eval begin
@define_ternary_dual_op(
Base.muladd,
calc_muladd_xyz(x, y, z), # xyz_body
calc_muladd_xy(x, y, z), # xy_body
calc_muladd_xz(x, y, z), # xz_body
Base.muladd(y, x, z), # yz_body
Dual{Tx}(muladd(value(x), y, z), partials(x) * y), # x_body
Base.muladd(y, x, z), # y_body
Dual{Tz}(muladd(x, y, value(z)), partials(z)), # z_body
$AMBIGUOUS_TYPES
)
end

end
109 changes: 109 additions & 0 deletions test/forwarddiff_symbolic_dual_ops.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using ForwardDiff
using Symbolics
using Symbolics.SymbolicUtils
using Symbolics.SymbolicUtils.SpecialFunctions
using Symbolics.NaNMath
using Test

SF = SymbolicUtils.SpecialFunctions

@variables x

# Test functions from Symbolics #
#-------------------------------#

for f ∈ SymbolicUtils.basic_monadic
fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f ∈ SymbolicUtils.monadic
# The polygamma and trigamma functions seem to be missing rules in ForwardDiff.
# The abs rule uses conditionals and cannot be used with Symbolics.Num.
# acsc, asech, NanMath.log2 and NaNMath.log10 are tested separately
if f ∈ (abs, SF.polygamma, SF.trigamma, acsc, acsch, asech, NaNMath.log2, NaNMath.log10)
continue
end

fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

# These are evaluated numerically. For some reason isequal evaluates to false for the symbolic expressions.
for f ∈ (acsc, asech, NaNMath.log2, NaNMath.log10)
fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, 1.0)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test fd ≈ substitute(sym, Dict(x => 1.0))
end

for f ∈ SymbolicUtils.basic_diadic
if f ∈ (//,)
continue
end

fun = eval(:(ξ ->($f)(ξ, 2.0)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f ∈ SymbolicUtils.diadic
if f ∈ (max, min, NaNMath.atanh, mod, rem, copysign, besselj, bessely, besseli, besselk)
continue
end

fun = eval(:(ξ ->($f)(ξ, 2.0)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f ∈ (NaNMath.atanh,)
fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f ∈ (besselj, bessely, besseli, besselk)
fun = eval(:(ξ ->($f)(ξ, 2)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

# Additionally test these definitions from ForwardDiff #
#------------------------------------------------------#

# https://github.com/JuliaDiff/ForwardDiff.jl/blob/d3002093beb88ff0b98ed178377961dfd55c1247/src/dual.jl#L599
# and
# https://github.com/JuliaDiff/ForwardDiff.jl/blob/d3002093beb88ff0b98ed178377961dfd55c1247/src/dual.jl#L683
for f ∈ (hypot, muladd)
fun = eval(:(ξ ->($f)(ξ, 2.0, 3.0)))

fd = ForwardDiff.derivative(fun, 5.0)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test fd ≈ substitute(sym, Dict(x => 5.0))
end

# fma is not defined for Symbolics.Num
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Linear Solver Test" begin include("linear_solver.jl") end
@safetestset "Algebraic Solver Test" begin include("solver.jl") end
@safetestset "Overloading Test" begin include("overloads.jl") end
@safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end
@safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
Expand Down
Loading