Skip to content

Commit

Permalink
Experiment with Zygote codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Jan 17, 2023
1 parent b1aaf45 commit 36737d5
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
15 changes: 9 additions & 6 deletions benchmark/pinn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ using Plots
const input = 2
const hidden = 16

model = Chain(Dense(input => hidden, sin),
Dense(hidden => hidden, sin),
Dense(hidden => 1),
first)
trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
# model = Chain(Dense(input => hidden, exp),
# Dense(hidden => hidden, exp),
# Dense(hidden => 1),
# first)
# trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)

model = Chain(Dense(input => 1, exp), first)
trial(model, x) = model(x)

M = 100
data = [rand(Float32, input) for _ in 1:M]
Expand All @@ -25,7 +28,7 @@ function loss_by_finitediff(model, x)
end
function loss_by_taylordiff(model, x)
f(x) = trial(model, x)
error = derivative(f, x, Float32[1, 0], 2) + derivative(f, x, Float32[0, 1], 2) +
error = derivative(f, x, Float32[1, 0], Val(3)) + derivative(f, x, Float32[0, 1], Val(3)) +
sin* x[1]) * sin* x[2])
abs2(error)
end
Expand Down
5 changes: 3 additions & 2 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import ChainRulesCore: rrule, RuleConfig, ProjectTo
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
using ZygoteRules: @adjoint

function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
Expand Down Expand Up @@ -39,7 +39,8 @@ end
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄)
# for structural tangent, convert to tuple
value_pullback(v̄) = NoTangent(), TaylorScalar(map(x -> convert(T, x), Tuple(v̄)))
value_pullback(v̄::Tangent{P, NTuple{N, T}}) where P = NoTangent(), TaylorScalar{T, N}(backing(v̄))
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(map(x -> convert(T, x), Tuple(v̄)))
return value(t), value_pullback
end

Expand Down
7 changes: 5 additions & 2 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ end
return extract_derivative(f(t), N)
end

# Need to rewrite like this to help Zygote infer types
make_taylor(t0::T, t1::T, ::Val{N}) where {T, N} = TaylorScalar{T, N}(t0, t1)

@inline function derivative(f, x::Vector{T}, l::Vector{T},
::Val{N}) where {T <: Number, N}
t = map(TaylorScalar{T, N}, x, l)
vN::Val{N}) where {T <: Number, N}
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l)
return extract_derivative(f(t), N)
end
10 changes: 9 additions & 1 deletion src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ import Base: hypot, max, min
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
@inline inv(t::TaylorScalar) = one(t) / t

exp(t::TaylorScalar{T, 2}) where T = let v = value(t), e1 = exp(v[1])
TaylorScalar{T, 2}((e1, e1 * v[2]))
end

exp(t::TaylorScalar{T, 3}) where T = let v = value(t), e1 = exp(v[1])
TaylorScalar{T, 3}((e1, e1 * v[2], e1 * v[3] + e1 * v[2] * v[2]))
end

for func in (:exp, :expm1, :exp2, :exp10)
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
ex = quote
Expand All @@ -37,7 +45,7 @@ for func in (:exp, :expm1, :exp2, :exp10)
if $(QuoteNode(func)) == :expm1
ex = :($ex; v1 = expm1(v[1]))
end
ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...)))
ex = :($ex; TaylorScalar{T, N}(tuple($([Symbol('v', i) for i in 1:N]...))))
return :(@inbounds $ex)
end
end
Expand Down

0 comments on commit 36737d5

Please sign in to comment.