diff --git a/benchmark/pinn.jl b/benchmark/pinn.jl index 5ebc910..c2ba229 100644 --- a/benchmark/pinn.jl +++ b/benchmark/pinn.jl @@ -6,11 +6,14 @@ using Plots const input = 2 const hidden = 16 -model = Chain(Dense(input => hidden, sin), - Dense(hidden => hidden, sin), - Dense(hidden => 1), - first) -trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x) +# model = Chain(Dense(input => hidden, exp), +# Dense(hidden => hidden, exp), +# Dense(hidden => 1), +# first) +# trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x) + +model = Chain(Dense(input => 1, exp), first) +trial(model, x) = model(x) M = 100 data = [rand(Float32, input) for _ in 1:M] @@ -25,7 +28,7 @@ function loss_by_finitediff(model, x) end function loss_by_taylordiff(model, x) f(x) = trial(model, x) - error = derivative(f, x, Float32[1, 0], 2) + derivative(f, x, Float32[0, 1], 2) + + error = derivative(f, x, Float32[1, 0], Val(3)) + derivative(f, x, Float32[0, 1], Val(3)) + sin(π * x[1]) * sin(π * x[2]) abs2(error) end diff --git a/src/chainrules.jl b/src/chainrules.jl index 2b10231..2ea9fca 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,4 +1,4 @@ -import ChainRulesCore: rrule, RuleConfig, ProjectTo +import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing using ZygoteRules: @adjoint function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} @@ -39,7 +39,8 @@ end function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄) # for structural tangent, convert to tuple - value_pullback(v̄) = NoTangent(), TaylorScalar(map(x -> convert(T, x), Tuple(v̄))) + value_pullback(v̄::Tangent{P, NTuple{N, T}}) where P = NoTangent(), TaylorScalar{T, N}(backing(v̄)) + value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(map(x -> convert(T, x), Tuple(v̄))) return value(t), value_pullback end diff --git a/src/derivative.jl b/src/derivative.jl index 33ce200..b035ee8 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -28,8 +28,11 @@ end return extract_derivative(f(t), N) end +# Need to rewrite like this to help Zygote infer types +make_taylor(t0::T, t1::T, ::Val{N}) where {T, N} = TaylorScalar{T, N}(t0, t1) + @inline function derivative(f, x::Vector{T}, l::Vector{T}, - ::Val{N}) where {T <: Number, N} - t = map(TaylorScalar{T, N}, x, l) + vN::Val{N}) where {T <: Number, N} + t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end diff --git a/src/primitive.jl b/src/primitive.jl index d31de26..31b828a 100644 --- a/src/primitive.jl +++ b/src/primitive.jl @@ -15,6 +15,14 @@ import Base: hypot, max, min @inline cbrt(t::TaylorScalar) = ^(t, 1 / 3) @inline inv(t::TaylorScalar) = one(t) / t +exp(t::TaylorScalar{T, 2}) where T = let v = value(t), e1 = exp(v[1]) + TaylorScalar{T, 2}((e1, e1 * v[2])) +end + +exp(t::TaylorScalar{T, 3}) where T = let v = value(t), e1 = exp(v[1]) + TaylorScalar{T, 3}((e1, e1 * v[2], e1 * v[3] + e1 * v[2] * v[2])) +end + for func in (:exp, :expm1, :exp2, :exp10) @eval @generated function $func(t::TaylorScalar{T, N}) where {T, N} ex = quote @@ -37,7 +45,7 @@ for func in (:exp, :expm1, :exp2, :exp10) if $(QuoteNode(func)) == :expm1 ex = :($ex; v1 = expm1(v[1])) end - ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...))) + ex = :($ex; TaylorScalar{T, N}(tuple($([Symbol('v', i) for i in 1:N]...)))) return :(@inbounds $ex) end end