Skip to content

Commit

Permalink
Fix zygote compat
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Sep 26, 2024
1 parent f1f850b commit c11dd72
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 112 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,22 @@ version = "0.2.4"
[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
TaylorDiffNNlibExt = ["NNlib"]
TaylorDiffSFExt = ["SpecialFunctions"]
TaylorDiffZygoteExt = ["Zygote"]

[compat]
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesOverloadGeneration = "0.1"
NNlib = "0.9"
SpecialFunctions = "2"
SymbolicUtils = "2, 3"
Expand Down
22 changes: 22 additions & 0 deletions ext/TaylorDiffZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module TaylorDiffZygoteExt

using TaylorDiff
import Zygote: @adjoint, Numeric, _dual_safearg, ZygoteRuleConfig
using ChainRulesCore: @opt_out

# Zygote can't infer this constructor function
# defining rrule for this doesn't seem to work for Zygote
# so need to use @adjoint
@adjoint TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M} = TaylorScalar{T, N}(t),
-> (TaylorScalar{T, M}(x̄),)

# Zygote will try to use ForwardDiff to compute broadcast functions
# However, TaylorScalar is not dual safe, so we opt out of this
_dual_safearg(::Numeric{<:TaylorScalar}) = false

# Zygote has a rule for literal power, need to opt out of this
@opt_out rrule(
::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar, ::Val{p}
) where {p}

end
109 changes: 28 additions & 81 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
using Base.Broadcast: broadcasted
import Zygote: @adjoint, accum_sum, unbroadcast, Numeric, ∇getindex, _project

function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
mapreduce(*, +, value(a), value(b))
Expand Down Expand Up @@ -31,7 +30,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
end

function rrule(::typeof(*), A::AbstractMatrix{S},
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
project_A = ProjectTo(A)
function gemv_pullback(x̄)
= reinterpret(reshape, T, x̄)
Expand All @@ -42,7 +41,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
end

function rrule(::typeof(*), A::AbstractMatrix{S},
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function gemm_pullback(x̄)
Expand All @@ -54,88 +53,36 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
return A * B, gemm_pullback
end

@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
project_v = ProjectTo(v)
t + v, x̄ -> (x̄, project_v(x̄))
end

@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T}
project_v = ProjectTo(v)
v + t, x̄ -> (project_v(x̄), x̄)
end

(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T} = primal(dx)

# Not-a-number patches

ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
(p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
function ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar}
ProjectTo{AbstractArray}(; element = ProjectTo(zero(T)), axes = axes(x))
end
(p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x
accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims)

TaylorNumeric{T <: TaylorScalar} = Union{T, AbstractArray{<:T}}

@adjoint function broadcasted(::typeof(+), xs::TaylorNumeric...)
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
end
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)

struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
val::T
ind::I
axes::A
function TaylorOneElement(val::T, ind::I,
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
A <: NTuple{N, AbstractUnitRange}} where {N}
new{T, N, I, A}(val, ind, axes)
end
end
# opt-outs

Base.size(A::TaylorOneElement) = map(length, A.axes)
Base.axes(A::TaylorOneElement) = A.axes
function Base.getindex(A::TaylorOneElement{T, N}, i::Vararg{Int, N}) where {T, N}
ifelse(i == A.ind, A.val, zero(T))
end

function ∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N}
dy -> begin
dx = TaylorOneElement(dy, inds, axes(x))
return (_project(x, dx), map(_ -> nothing, inds)...)
end
end
# Unary functions

@generated function mul_adjoint::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N}
return quote
vΩ, vx = value(Ω), value(x)
@inbounds TaylorScalar($([:(+($([:($(binomial(j - 1, i - 1)) * vΩ[$j] *
vx[$(j + 1 - i)]) for j in i:N]...)))
for i in 1:N]...))
end
for f in (
exp, exp10, exp2, expm1,
sin, cos, tan, sec, csc, cot,
sinh, cosh, tanh, sech, csch, coth,
log, log10, log2, log1p,
asin, acos, atan, asec, acsc, acot,
asinh, acosh, atanh, asech, acsch, acoth,
sqrt, cbrt, inv
)
@eval @opt_out frule(::typeof($f), x::TaylorScalar)
@eval @opt_out rrule(::typeof($f), x::TaylorScalar)
end

rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)

function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
function times_pullback2(Ω̇)
ΔΩ = unthunk(Ω̇)
return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)),
ProjectTo(y)(mul_adjoint(ΔΩ, x)))
end
return x * y, times_pullback2
end
# Binary functions

function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
more::TaylorScalar...)
Ω2, back2 = rrule(*, x, y)
Ω3, back3 = rrule(*, Ω2, z)
Ω4, back4 = rrule(*, Ω3, more...)
function times_pullback4(Ω̇)
Δ4 = back4(unthunk(Ω̇)) # (0, ΔΩ3, Δmore...)
Δ3 = back3(Δ4[2]) # (0, ΔΩ2, Δz)
Δ2 = back2(Δ3[2]) # (0, Δx, Δy)
return (Δ2..., Δ3[3], Δ4[3:end]...)
for f in (
*, /, ^
)
for (tlhs, trhs) in (
(TaylorScalar, TaylorScalar),
(TaylorScalar, Number),
(Number, TaylorScalar)
)
@eval @opt_out frule(::typeof($f), x::$tlhs, y::$trhs)
@eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
end
return Ω4, times_pullback4
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ using Test

include("primitive.jl")
include("derivative.jl")
# include("zygote.jl")
include("zygote.jl")
# include("lux.jl")
52 changes: 25 additions & 27 deletions test/zygote.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
using Zygote, LinearAlgebra
using LinearAlgebra
import Zygote # use qualified import to avoid conflict with TaylorDiff

@testset "Zygote for mixed derivative" begin
@testset "Zygote-over-TaylorDiff on same variable" begin
# Scalar functions
some_number = 0.7
some_numbers = [0.3 0.4 0.1;]
for f in (exp, log, sqrt, sin, asin, sinh, asinh)
@test gradient(x -> derivative(f, x, 2), some_number)[1]
for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
@test Zygote.gradient(derivative, f, some_number, 2)[2]
derivative(f, some_number, 3)
derivative_result = vec(derivative.(f, some_numbers, 3))
@test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1]
diagm(derivative_result)
@test Zygote.jacobian(broadcast, derivative, f, some_numbers, 2)[3]
diagm(vec(derivative.(f, some_numbers, 3)))
end

some_matrix = [0.7 0.1; 0.4 0.2]
f = x -> sum(tanh.(x), dims = 1)
dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1)
dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1)
res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x)
grads = Zygote.gradient(some_matrix) do x
sum(res(f, x))
end
expected_grads = x -> -2 * sinh(x) / cosh(x)^3
@test grads[1] [1 0; 0 2] * expected_grads.(some_matrix)

@test gradient(x -> derivative(x -> x * x, x, 1),
5.0)[1] 2.0

# Vector functions
g(x) = x[1] * x[1] + x[2] * x[2]
@test gradient(x -> derivative(g, x, [1.0, 0.0], 1),
[1.0, 2.0])[1] [2.0, 0.0]
@test Zygote.gradient(derivative, g, [1.0, 2.0], [1.0, 0.0], 1)[2] [2.0, 0.0]

# Matrix functions
some_matrix = [0.7 0.1; 0.4 0.2]
f(x) = sum(exp.(x), dims = 1)
dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
grads = Zygote.gradient(res, some_matrix)
@test grads[1] [1 0; 0 2] * exp.(some_matrix)
end

@testset "Zygote for parameter optimization" begin
gradient(p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
gradient(p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
@testset "Zygote-over-TaylorDiff on different variable" begin
Zygote.gradient(
p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
Zygote.gradient(
p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
linear_model(x, p, b) = exp.(b + p * x + b)[1]
some_x, some_v, some_p, some_b = [0.58, 0.36], [0.23, 0.11], [0.49 0.96], [0.88]
loss_taylor(p) = derivative(x -> linear_model(x, p, some_b), some_x, some_v, 1)
Expand All @@ -41,5 +39,5 @@ end
let f = x -> linear_model(x, p, some_b)
(f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
end
@test gradient(loss_taylor, some_p)[1] gradient(loss_finite, some_p)[1]
@test Zygote.gradient(loss_taylor, some_p)[1] Zygote.gradient(loss_finite, some_p)[1]
end

0 comments on commit c11dd72

Please sign in to comment.