From c11dd72ce8dc2eaf1ffc2def171edd50ebe72632 Mon Sep 17 00:00:00 2001
From: Songchen Tan
Date: Thu, 26 Sep 2024 15:56:53 -0400
Subject: [PATCH] Fix zygote compat

---
Reviewer notes (text in this region is ignored by `git am`):

* Project.toml: Zygote moves from [deps] to [weakdeps] and gains a
  TaylorDiffZygoteExt entry under [extensions];
  ChainRulesOverloadGeneration is dropped from [deps] and [compat].
* ext/TaylorDiffZygoteExt.jl (new): holds the Zygote-specific pieces, namely
  an @adjoint for the TaylorScalar{T, N}(::TaylorScalar{T, M}) constructor,
  a `_dual_safearg` method returning false for TaylorScalar, and an @opt_out
  of the ZygoteRuleConfig `literal_pow` rrule.
* src/chainrules.jl: removes the Zygote import and the Zygote-based patches,
  restricts the gemv/gemm rrules to `S <: Real, T <: Real`, and adds
  @opt_out declarations for a list of unary functions and for `*`, `/`, `^`.
* test/: re-enables test/zygote.jl and calls Zygote via qualified names.
* NOTE(review): the line structure of this patch was reconstructed from a
  whitespace-collapsed copy. Hunk line counts were checked against the @@
  headers, but indentation depth and inline-comment spacing were restored
  by convention; confirm against the repository before applying.
* NOTE(review): `@opt_out frule(::typeof(f), x)` has no slot for the tangent
  argument that `frule` normally takes first; confirm this is the intended
  signature against the ChainRulesCore `@opt_out` documentation.

 Project.toml               |   5 +-
 ext/TaylorDiffZygoteExt.jl |  22 ++++++++
 src/chainrules.jl          | 109 ++++++++++---------------------------
 test/runtests.jl           |   2 +-
 test/zygote.jl             |  52 +++++++++---------
 5 files changed, 78 insertions(+), 112 deletions(-)
 create mode 100644 ext/TaylorDiffZygoteExt.jl

diff --git a/Project.toml b/Project.toml
index ccb5495..0249b82 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,23 +6,22 @@ version = "0.2.4"
 [deps]
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
 Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [weakdeps]
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
 TaylorDiffNNlibExt = ["NNlib"]
 TaylorDiffSFExt = ["SpecialFunctions"]
+TaylorDiffZygoteExt = ["Zygote"]
 
 [compat]
 ChainRules = "1"
 ChainRulesCore = "1"
-ChainRulesOverloadGeneration = "0.1"
 NNlib = "0.9"
 SpecialFunctions = "2"
 SymbolicUtils = "2, 3"
diff --git a/ext/TaylorDiffZygoteExt.jl b/ext/TaylorDiffZygoteExt.jl
new file mode 100644
index 0000000..cabcf41
--- /dev/null
+++ b/ext/TaylorDiffZygoteExt.jl
@@ -0,0 +1,22 @@
+module TaylorDiffZygoteExt
+
+using TaylorDiff
+import Zygote: @adjoint, Numeric, _dual_safearg, ZygoteRuleConfig
+using ChainRulesCore: @opt_out
+
+# Zygote can't infer this constructor function
+# defining rrule for this doesn't seem to work for Zygote
+# so need to use @adjoint
+@adjoint TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M} = TaylorScalar{T, N}(t),
+x̄ -> (TaylorScalar{T, M}(x̄),)
+
+# Zygote will try to use ForwardDiff to compute broadcast functions
+# However, TaylorScalar is not dual safe, so we opt out of this
+_dual_safearg(::Numeric{<:TaylorScalar}) = false
+
+# Zygote has a rule for literal power, need to opt out of this
+@opt_out rrule(
+    ::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar, ::Val{p}
+) where {p}
+
+end
diff --git a/src/chainrules.jl b/src/chainrules.jl
index fd327c6..8edbf7a 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -1,6 +1,5 @@
-import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
+import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
 using Base.Broadcast: broadcasted
-import Zygote: @adjoint, accum_sum, unbroadcast, Numeric, ∇getindex, _project
 
 function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
     mapreduce(*, +, value(a), value(b))
@@ -31,7 +30,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
 end
 
 function rrule(::typeof(*), A::AbstractMatrix{S},
-        t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
+        t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
     project_A = ProjectTo(A)
     function gemv_pullback(x̄)
         x̂ = reinterpret(reshape, T, x̄)
@@ -42,7 +41,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
 end
 
 function rrule(::typeof(*), A::AbstractMatrix{S},
-        B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
+        B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
     project_A = ProjectTo(A)
     project_B = ProjectTo(B)
     function gemm_pullback(x̄)
@@ -54,88 +53,36 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
     return A * B, gemm_pullback
 end
 
-@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
-    project_v = ProjectTo(v)
-    t + v, x̄ -> (x̄, project_v(x̄))
-end
-
-@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T}
-    project_v = ProjectTo(v)
-    v + t, x̄ -> (project_v(x̄), x̄)
-end
-
-(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T} = primal(dx)
-
-# Not-a-number patches
-
-ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
-(p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
-function ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar}
-    ProjectTo{AbstractArray}(; element = ProjectTo(zero(T)), axes = axes(x))
-end
-(p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x
-accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims)
-
-TaylorNumeric{T <: TaylorScalar} = Union{T, AbstractArray{<:T}}
-
-@adjoint function broadcasted(::typeof(+), xs::TaylorNumeric...)
-    broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
-end
+(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
 
-struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
-    val::T
-    ind::I
-    axes::A
-    function TaylorOneElement(val::T, ind::I,
-            axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
-            A <: NTuple{N, AbstractUnitRange}} where {N}
-        new{T, N, I, A}(val, ind, axes)
-    end
-end
+# opt-outs
 
-Base.size(A::TaylorOneElement) = map(length, A.axes)
-Base.axes(A::TaylorOneElement) = A.axes
-function Base.getindex(A::TaylorOneElement{T, N}, i::Vararg{Int, N}) where {T, N}
-    ifelse(i == A.ind, A.val, zero(T))
-end
-
-function ∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N}
-    dy -> begin
-        dx = TaylorOneElement(dy, inds, axes(x))
-        return (_project(x, dx), map(_ -> nothing, inds)...)
-    end
-end
+# Unary functions
 
-@generated function mul_adjoint(Ω::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N}
-    return quote
-        vΩ, vx = value(Ω), value(x)
-        @inbounds TaylorScalar($([:(+($([:($(binomial(j - 1, i - 1)) * vΩ[$j] *
-                                           vx[$(j + 1 - i)]) for j in i:N]...)))
-                                  for i in 1:N]...))
-    end
+for f in (
+    exp, exp10, exp2, expm1,
+    sin, cos, tan, sec, csc, cot,
+    sinh, cosh, tanh, sech, csch, coth,
+    log, log10, log2, log1p,
+    asin, acos, atan, asec, acsc, acot,
+    asinh, acosh, atanh, asech, acsch, acoth,
+    sqrt, cbrt, inv
+)
+    @eval @opt_out frule(::typeof($f), x::TaylorScalar)
+    @eval @opt_out rrule(::typeof($f), x::TaylorScalar)
 end
 
-rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)
-
-function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
-    function times_pullback2(Ω̇)
-        ΔΩ = unthunk(Ω̇)
-        return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)),
-            ProjectTo(y)(mul_adjoint(ΔΩ, x)))
-    end
-    return x * y, times_pullback2
-end
+# Binary functions
 
-function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
-        more::TaylorScalar...)
-    Ω2, back2 = rrule(*, x, y)
-    Ω3, back3 = rrule(*, Ω2, z)
-    Ω4, back4 = rrule(*, Ω3, more...)
-    function times_pullback4(Ω̇)
-        Δ4 = back4(unthunk(Ω̇)) # (0, ΔΩ3, Δmore...)
-        Δ3 = back3(Δ4[2]) # (0, ΔΩ2, Δz)
-        Δ2 = back2(Δ3[2]) # (0, Δx, Δy)
-        return (Δ2..., Δ3[3], Δ4[3:end]...)
+for f in (
+    *, /, ^
+)
+    for (tlhs, trhs) in (
+        (TaylorScalar, TaylorScalar),
+        (TaylorScalar, Number),
+        (Number, TaylorScalar)
+    )
+        @eval @opt_out frule(::typeof($f), x::$tlhs, y::$trhs)
+        @eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
     end
-    return Ω4, times_pullback4
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 23ec468..a28c383 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,5 +3,5 @@ using Test
 
 include("primitive.jl")
 include("derivative.jl")
-# include("zygote.jl")
+include("zygote.jl")
 # include("lux.jl")
diff --git a/test/zygote.jl b/test/zygote.jl
index 1ef667f..3dccedd 100644
--- a/test/zygote.jl
+++ b/test/zygote.jl
@@ -1,38 +1,36 @@
-using Zygote, LinearAlgebra
+using LinearAlgebra
+import Zygote # use qualified import to avoid conflict with TaylorDiff
 
-@testset "Zygote for mixed derivative" begin
+@testset "Zygote-over-TaylorDiff on same variable" begin
+    # Scalar functions
     some_number = 0.7
     some_numbers = [0.3 0.4 0.1;]
-    for f in (exp, log, sqrt, sin, asin, sinh, asinh)
-        @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈
+    for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
+        @test Zygote.gradient(derivative, f, some_number, 2)[2] ≈
               derivative(f, some_number, 3)
-        derivative_result = vec(derivative.(f, some_numbers, 3))
-        @test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1] ≈
-              diagm(derivative_result)
+        @test Zygote.jacobian(broadcast, derivative, f, some_numbers, 2)[3] ≈
+              diagm(vec(derivative.(f, some_numbers, 3)))
     end
 
-    some_matrix = [0.7 0.1; 0.4 0.2]
-    f = x -> sum(tanh.(x), dims = 1)
-    dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1)
-    dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1)
-    res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x)
-    grads = Zygote.gradient(some_matrix) do x
-        sum(res(f, x))
-    end
-    expected_grads = x -> -2 * sinh(x) / cosh(x)^3
-    @test grads[1] ≈ [1 0; 0 2] * expected_grads.(some_matrix)
-
-    @test gradient(x -> derivative(x -> x * x, x, 1),
-        5.0)[1] ≈ 2.0
-
+    # Vector functions
     g(x) = x[1] * x[1] + x[2] * x[2]
-    @test gradient(x -> derivative(g, x, [1.0, 0.0], 1),
-        [1.0, 2.0])[1] ≈ [2.0, 0.0]
+    @test Zygote.gradient(derivative, g, [1.0, 2.0], [1.0, 0.0], 1)[2] ≈ [2.0, 0.0]
+
+    # Matrix functions
+    some_matrix = [0.7 0.1; 0.4 0.2]
+    f(x) = sum(exp.(x), dims = 1)
+    dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
+    dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
+    res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
+    grads = Zygote.gradient(res, some_matrix)
+    @test grads[1] ≈ [1 0; 0 2] * exp.(some_matrix)
 end
 
-@testset "Zygote for parameter optimization" begin
-    gradient(p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
-    gradient(p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
+@testset "Zygote-over-TaylorDiff on different variable" begin
+    Zygote.gradient(
+        p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
+    Zygote.gradient(
+        p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
     linear_model(x, p, b) = exp.(b + p * x + b)[1]
     some_x, some_v, some_p, some_b = [0.58, 0.36], [0.23, 0.11], [0.49 0.96], [0.88]
     loss_taylor(p) = derivative(x -> linear_model(x, p, some_b), some_x, some_v, 1)
@@ -41,5 +39,5 @@ end
     let f = x -> linear_model(x, p, some_b)
         (f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
     end
-    @test gradient(loss_taylor, some_p)[1] ≈ gradient(loss_finite, some_p)[1]
+    @test Zygote.gradient(loss_taylor, some_p)[1] ≈ Zygote.gradient(loss_finite, some_p)[1]
 end