diff --git a/Project.toml b/Project.toml index 881541e4d..6764a580c 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" [extensions] SymbolicsGroebnerExt = "Groebner" SymbolicsPreallocationToolsExt = ["ForwardDiff", "PreallocationTools"] +SymbolicsForwardDiffExt = "ForwardDiff" SymbolicsSymPyExt = "SymPy" [compat] diff --git a/ext/SymbolicsForwardDiffExt.jl b/ext/SymbolicsForwardDiffExt.jl new file mode 100644 index 000000000..480fc83f5 --- /dev/null +++ b/ext/SymbolicsForwardDiffExt.jl @@ -0,0 +1,256 @@ +module SymbolicsForwardDiffExt + +using ForwardDiff +using ForwardDiff.NaNMath +using ForwardDiff.DiffRules +using ForwardDiff: value, Dual, partials +using Symbolics + +# The method generation in this file have been adapted from +# https://github.com/JuliaDiff/ForwardDiff.jl/blob/v0.10.36/src/dual.jl + +const AMBIGUOUS_TYPES = (Num,) + +#################################### +# N-ary Operation Definition Tools # +#################################### + +macro define_binary_dual_op(f, xy_body, x_body, y_body, Ts) + FD = ForwardDiff + defs = quote end + for R in Ts + expr = quote + @inline $(f)(x::$FD.Dual{Tx}, y::$R) where {Tx} = $x_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}) where {Ty} = $y_body + end + append!(defs.args, expr.args) + end + return esc(defs) +end + +macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body, Ts) + FD = ForwardDiff + defs = quote end + for R in Ts + expr = quote + @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body + @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body + @inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body + @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body + @inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body + end + append!(defs.args, expr.args) + for Q in Ts + Q === R && continue + expr = quote + @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body + @inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body + end + append!(defs.args, expr.args) + end + expr = quote + @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body + @inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body + end + append!(defs.args, expr.args) + end + return esc(defs) +end + +function binary_dual_definition(M, f, Ts) + FD = ForwardDiff + dvx, dvy = DiffRules.diffrule(M, f, :vx, :vy) + Mf = M == :Base ? f : :($M.$f) + xy_work = FD.qualified_cse!(quote + val = $Mf(vx, vy) + dvx = $dvx + dvy = $dvy + end) + dvx, _ = DiffRules.diffrule(M, f, :vx, :y) + x_work = FD.qualified_cse!(quote + val = $Mf(vx, y) + dvx = $dvx + end) + _, dvy = DiffRules.diffrule(M, f, :x, :vy) + y_work = FD.qualified_cse!(quote + val = $Mf(x, vy) + dvy = $dvy + end) + expr = quote + @define_binary_dual_op( + $M.$f, + begin + vx, vy = $FD.value(x), $FD.value(y) + $xy_work + return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y)) + end, + begin + vx = $FD.value(x) + $x_work + return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x)) + end, + begin + vy = $FD.value(y) + $y_work + return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y)) + end, + $Ts + ) + end + return expr +end + +################################### +# General Mathematical Operations # +################################### + +for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) + if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos)) + continue # Skip methods which we define elsewhere. + elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) + continue # Skip rules for methods not defined in the current scope + end + if arity == 1 + # no-op + elseif arity == 2 + eval(binary_dual_definition(M, f, AMBIGUOUS_TYPES)) + else + # no-op + end +end + +################# +# Special Cases # +################# + +# +/- # +#-----# + +@eval begin + @define_binary_dual_op( + Base.:+, + begin + vx, vy = value(x), value(y) + Dual{Txy}(vx + vy, partials(x) + partials(y)) + end, + Dual{Tx}(value(x) + y, partials(x)), + Dual{Ty}(x + value(y), partials(y)), + $AMBIGUOUS_TYPES + ) +end + +@eval begin + @define_binary_dual_op( + Base.:-, + begin + vx, vy = value(x), value(y) + Dual{Txy}(vx - vy, partials(x) - partials(y)) + end, + Dual{Tx}(value(x) - y, partials(x)), + Dual{Ty}(x - value(y), -partials(y)), + $AMBIGUOUS_TYPES + ) +end + +# / # +#---# + +# We can't use the normal diffrule autogeneration for this because (x/y) === (x * (1/y)) +# doesn't generally hold true for floating point; see issue #264 +@eval begin + @define_binary_dual_op( + Base.:/, + begin + vx, vy = value(x), value(y) + Dual{Txy}(vx / vy, _div_partials(partials(x), partials(y), vx, vy)) + end, + Dual{Tx}(value(x) / y, partials(x) / y), + begin + v = value(y) + divv = x / v + Dual{Ty}(divv, -(divv / v) * partials(y)) + end, + $AMBIGUOUS_TYPES + ) +end + +# exponentiation # +#----------------# + +for f in (:(Base.:^), :(NaNMath.pow)) + @eval begin + @define_binary_dual_op( + $f, + begin + vx, vy = value(x), value(y) + expv = ($f)(vx, vy) + powval = vy * ($f)(vx, vy - 1) + if isconstant(y) + logval = one(expv) + elseif iszero(vx) && vy > 0 + logval = zero(vx) + else + logval = expv * log(vx) + end + new_partials = _mul_partials(partials(x), partials(y), powval, logval) + return Dual{Txy}(expv, new_partials) + end, + begin + v = value(x) + expv = ($f)(v, y) + if y == zero(y) || iszero(partials(x)) + new_partials = zero(partials(x)) + else + new_partials = partials(x) * y * ($f)(v, y - 1) + end + return Dual{Tx}(expv, new_partials) + end, + begin + v = value(y) + expv = ($f)(x, v) + deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x) + return Dual{Ty}(expv, deriv * partials(y)) + end, + $AMBIGUOUS_TYPES + ) + end +end + +# hypot # +#-------# + +@eval begin + @define_ternary_dual_op( + Base.hypot, + calc_hypot(x, y, z, Txyz), + calc_hypot(x, y, z, Txy), + calc_hypot(x, y, z, Txz), + calc_hypot(x, y, z, Tyz), + calc_hypot(x, y, z, Tx), + calc_hypot(x, y, z, Ty), + calc_hypot(x, y, z, Tz), + $AMBIGUOUS_TYPES + ) +end + +# muladd # +#--------# + +@eval begin + @define_ternary_dual_op( + Base.muladd, + calc_muladd_xyz(x, y, z), # xyz_body + calc_muladd_xy(x, y, z), # xy_body + calc_muladd_xz(x, y, z), # xz_body + Base.muladd(y, x, z), # yz_body + Dual{Tx}(muladd(value(x), y, z), partials(x) * y), # x_body + Base.muladd(y, x, z), # y_body + Dual{Tz}(muladd(x, y, value(z)), partials(z)), # z_body + $AMBIGUOUS_TYPES + ) +end + +end diff --git a/test/forwarddiff_symbolic_dual_ops.jl b/test/forwarddiff_symbolic_dual_ops.jl new file mode 100644 index 000000000..1b89c79e6 --- /dev/null +++ b/test/forwarddiff_symbolic_dual_ops.jl @@ -0,0 +1,109 @@ +using ForwardDiff +using Symbolics +using Symbolics.SymbolicUtils +using Symbolics.SymbolicUtils.SpecialFunctions +using Symbolics.NaNMath +using Test + +SF = SymbolicUtils.SpecialFunctions + +@variables x + +# Test functions from Symbolics # +#-------------------------------# + +for f ∈ SymbolicUtils.basic_monadic + fun = eval(:(ξ ->($f)(ξ))) + + fd = ForwardDiff.derivative(fun, x) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test isequal(fd, sym) +end + +for f ∈ SymbolicUtils.monadic + # The polygamma and trigamma functions seem to be missing rules in ForwardDiff. + # The abs rule uses conditionals and cannot be used with Symbolics.Num. + # acsc, asech, NanMath.log2 and NaNMath.log10 are tested separately + if f ∈ (abs, SF.polygamma, SF.trigamma, acsc, acsch, asech, NaNMath.log2, NaNMath.log10) + continue + end + + fun = eval(:(ξ ->($f)(ξ))) + + fd = ForwardDiff.derivative(fun, x) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test isequal(fd, sym) +end + +# These are evaluated numerically. For some reason isequal evaluates to false for the symbolic expressions. +for f ∈ (acsc, asech, NaNMath.log2, NaNMath.log10) + fun = eval(:(ξ ->($f)(ξ))) + + fd = ForwardDiff.derivative(fun, 1.0) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test fd ≈ substitute(sym, Dict(x => 1.0)) +end + +for f ∈ SymbolicUtils.basic_diadic + if f ∈ (//,) + continue + end + + fun = eval(:(ξ ->($f)(ξ, 2.0))) + + fd = ForwardDiff.derivative(fun, x) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test isequal(fd, sym) +end + +for f ∈ SymbolicUtils.diadic + if f ∈ (max, min, NaNMath.atanh, mod, rem, copysign, besselj, bessely, besseli, besselk) + continue + end + + fun = eval(:(ξ ->($f)(ξ, 2.0))) + + fd = ForwardDiff.derivative(fun, x) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test isequal(fd, sym) +end + +for f ∈ (NaNMath.atanh,) + fun = eval(:(ξ ->($f)(ξ))) + + fd = ForwardDiff.derivative(fun, x) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test isequal(fd, sym) +end + +for f ∈ (besselj, bessely, besseli, besselk) + fun = eval(:(ξ ->($f)(ξ, 2))) + + fd = ForwardDiff.derivative(fun, x) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test isequal(fd, sym) +end + +# Additionally test these definitions from ForwardDiff # +#------------------------------------------------------# + +# https://github.com/JuliaDiff/ForwardDiff.jl/blob/d3002093beb88ff0b98ed178377961dfd55c1247/src/dual.jl#L599 +# and +# https://github.com/JuliaDiff/ForwardDiff.jl/blob/d3002093beb88ff0b98ed178377961dfd55c1247/src/dual.jl#L683 +for f ∈ (hypot, muladd) + fun = eval(:(ξ ->($f)(ξ, 2.0, 3.0))) + + fd = ForwardDiff.derivative(fun, 5.0) + sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives + + @test fd ≈ substitute(sym, Dict(x => 5.0)) +end + +# fma is not defined for Symbolics.Num diff --git a/test/runtests.jl b/test/runtests.jl index d8e6e544c..7d42fb555 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Linear Solver Test" begin include("linear_solver.jl") end @safetestset "Algebraic Solver Test" begin include("solver.jl") end @safetestset "Overloading Test" begin include("overloads.jl") end + @safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end @safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end @safetestset "Build Function Test" begin include("build_function.jl") end @safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end