Skip to content

Commit

Permalink
Improve log-like function
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Oct 3, 2024
1 parent e6ab873 commit 1005d83
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,46 @@
using ChainRules
using ChainRulesCore
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow
using Symbolics: @variables, @rule, unwrap, isdiv
using SymbolicUtils.Code: toexpr

dummy = (NoTangent(), 1)
@variables z
"""
Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule.
"""
function get_term_raiser(func)
@variables z
r1 = @rule -1 * (1 / ~x) => (-1) / ~x
der = frule((NoTangent(), true), func, z)[2]
term = unwrap(der)
maybe_rewrite = r1(term)
if maybe_rewrite !== nothing
term = maybe_rewrite
end
if isdiv(term) && (term.num == 1 || term.num == -1)
term.den * term.num, raiseinv
else
term, raise
end
end

function define_unary_function(func, m)
F = typeof(func)
# base case
# First order: call frule directly
@eval m function (op::$F)(t::TaylorScalar{T, 1}) where {T}
t0 = value(t)
t1 = first(partials(t))
f0, f1 = frule((NoTangent(), t1), op, t0)
TaylorScalar{T, 1}(f0, zero_tangent(f0) + f1)
end
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
# recursion by raising
term, raiser = get_term_raiser(func)
# Higher order: recursion by raising
@eval m @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
der_expr = $(QuoteNode(toexpr(term)))
expr = $(QuoteNode(toexpr(term)))
f = $func
quote
$(Expr(:meta, :inline))
z = TaylorScalar{T, N - 1}(t)
f0 = $f(value(t)[1])
df = zero_tangent(z) + $der_expr
df = zero_tangent(z) + $expr
$$raiser(f0, df, t)
end
end
Expand Down

0 comments on commit 1005d83

Please sign in to comment.