diff --git a/src/utils.jl b/src/utils.jl index 4cc9c23..4dff349 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,32 +1,46 @@ using ChainRules using ChainRulesCore -using Symbolics: @variables -using SymbolicUtils, SymbolicUtils.Code -using SymbolicUtils: Pow +using Symbolics: @variables, @rule, unwrap, isdiv +using SymbolicUtils.Code: toexpr -dummy = (NoTangent(), 1) -@variables z +""" +Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule. +""" +function get_term_raiser(func) + @variables z + r1 = @rule -1 * (1 / ~x) => (-1) / ~x + der = frule((NoTangent(), true), func, z)[2] + term = unwrap(der) + maybe_rewrite = r1(term) + if maybe_rewrite !== nothing + term = maybe_rewrite + end + if isdiv(term) && (term.num == 1 || term.num == -1) + term.den * term.num, raiseinv + else + term, raise + end +end function define_unary_function(func, m) F = typeof(func) - # base case + # First order: call frule directly @eval m function (op::$F)(t::TaylorScalar{T, 1}) where {T} t0 = value(t) t1 = first(partials(t)) f0, f1 = frule((NoTangent(), t1), op, t0) TaylorScalar{T, 1}(f0, zero_tangent(f0) + f1) end - der = frule(dummy, func, z)[2] - term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise) - # recursion by raising + term, raiser = get_term_raiser(func) + # Higher order: recursion by raising @eval m @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N} - der_expr = $(QuoteNode(toexpr(term))) + expr = $(QuoteNode(toexpr(term))) f = $func quote $(Expr(:meta, :inline)) z = TaylorScalar{T, N - 1}(t) f0 = $f(value(t)[1]) - df = zero_tangent(z) + $der_expr + df = zero_tangent(z) + $expr $$raiser(f0, df, t) end end