Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving Zygote autodiff perfmance #291

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
DataInterpolationsChainRulesCoreExt = ["ChainRulesCore", "SparseArrays"]
DataInterpolationsOptimExt = "Optim"
DataInterpolationsRegularizationToolsExt = "RegularizationTools"
DataInterpolationsSymbolicsExt = "Symbolics"
Expand All @@ -38,6 +39,7 @@ RecipesBase = "1.3"
Reexport = "1"
RegularizationTools = "0.6"
SafeTestsets = "0.1"
SparseArrays = "1.10"
StableRNGs = "1"
Symbolics = "5.29"
Test = "1"
Expand Down
199 changes: 184 additions & 15 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,202 @@
module DataInterpolationsChainRulesCoreExt

if isdefined(Base, :get_extension)
using DataInterpolations: _interpolate, derivative, AbstractInterpolation,
using DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx,
cumulative_integral, LinearParameterCache,
QuadraticSplineParameterCache,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox
BSplineInterpolation, BSplineApprox, LinearInterpolation,
QuadraticSpline
using ChainRulesCore
using LinearAlgebra
using SparseArrays
using ReadOnlyArrays
else
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx,
cumulative_integral, LinearParameterCache,
QuadraticSplineParameterCache,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox
BSplineInterpolation, BSplineApprox, LinearInterpolation,
QuadraticSpline
using ..ChainRulesCore
using ..LinearAlgebra
using ..SparseArrays
using ..ReadOnlyArrays
end

function ChainRulesCore.rrule(::typeof(_interpolate),
A::Union{
LagrangeInterpolation,
AkimaInterpolation,
BSplineInterpolation,
BSplineApprox
},
t::Number)
deriv = derivative(A, t)
interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ)
return _interpolate(A, t), interpolate_pullback
## Linear interpolation

function ChainRulesCore.rrule(
::Type{LinearParameterCache}, u::AbstractArray, t::AbstractVector)
p = LinearParameterCache(u, t)
du = zeros(eltype(p.slope), length(u))

function LinearParameterCache_pullback(Δp)
df = NoTangent()
du[2:end] += Δp.slope
du[1:(end - 1)] -= Δp.slope
dt = NoTangent()
return (df, du, dt)
end

p, LinearParameterCache_pullback
end

function ChainRulesCore.rrule(
::Type{LinearInterpolation}, u, t, I, p, extrapolate, safetycopy)
A = LinearInterpolation(u, t, I, p, extrapolate, safetycopy)

function LinearInterpolation_pullback(ΔA)
df = NoTangent()
du = ΔA.u
dt = NoTangent()
dI = NoTangent()
dp = ΔA.p
dextrapolate = NoTangent()
dsafetycopy = NoTangent()
return df, du, dt, dI, dp, dextrapolate, dsafetycopy
end

A, LinearInterpolation_pullback
end

function allocate_direct_field_tangents(A::LinearInterpolation)
idx = A.idx_prev[]
u = SparseVector(length(A.u), [idx], zeros(1))
(; u)
end

function allocate_parameter_tangents(A::LinearInterpolation)
idx = A.idx_prev[]
slope = SparseVector(length(A.p.slope), [idx], zeros(1))
return (; slope)
end

function _tangent_direct_fields!(
direct_field_tangents::NamedTuple, A::LinearInterpolation, Δt, Δ)
(; u) = direct_field_tangents
idx = A.idx_prev[]
u[idx] = Δ
end

function _tangent_p!(parameter_tangents::NamedTuple, A::LinearInterpolation, Δt, Δ)
(; slope) = parameter_tangents
idx = A.idx_prev[]
slope[idx] = Δt * Δ
end

## Quadratic Spline

function ChainRulesCore.rrule(::Type{QuadraticSplineParameterCache}, u, t)
p = QuadraticSplineParameterCache(u, t)
n = length(u)

Δt = diff(t)
diagonal_main = 2 ./ Δt
pushfirst!(diagonal_main, zero(eltype(diagonal_main)))
diagonal_down = -diagonal_main[2:end]
diagonal_up = zero(diagonal_down)
∂d_∂u = Tridiagonal(diagonal_down, diagonal_main, diagonal_up)

∂σ_∂z = spzeros(n, n - 1)
for i in 1:(n - 1)
∂σ_∂z[i, i] = -0.5 / Δt[i]
∂σ_∂z[i + 1, i] = 0.5 / Δt[i]
end

function QuadraticSplineParameterCache_pullback(Δp)
df = NoTangent()
temp1 = Δp.z + ∂σ_∂z * Δp.σ
temp2 = p.tA' \ temp1
du = ∂d_∂u' * temp2
dt = NoTangent()
return (df, du, dt)
end

p, QuadraticSplineParameterCache_pullback
end

function ChainRulesCore.rrule(::Type{QuadraticSpline}, u, t, I, p, extrapolate, safetycopy)
A = QuadraticSpline(u, t, I, p, extrapolate, safetycopy)

function QuadraticSpline_pullback(ΔA)
df = NoTangent()
du = ΔA.u
dt = NoTangent()
dI = NoTangent()
dp = ΔA.p
dextrapolate = NoTangent()
dsafetycopy = NoTangent()
return df, du, dt, dI, dp, dextrapolate, dsafetycopy
end

A, QuadraticSpline_pullback
end

function allocate_direct_field_tangents(A::QuadraticSpline)
idx = A.idx_prev[]
u = SparseVector(length(A.u), [idx], zeros(1))
(; u)
end

function allocate_parameter_tangents(A::QuadraticSpline)
idx = A.idx_prev[]
z = SparseVector(length(A.p.z), [idx], zeros(1))
σ = SparseVector(length(A.p.σ), [idx], zeros(1))
return (; z, σ)
end

function _tangent_direct_fields!(
direct_field_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ)
(; u) = direct_field_tangents
idx = A.idx_prev[]
u[idx] = Δ
end

function _tangent_p!(parameter_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ)
(; z, σ) = parameter_tangents
idx = A.idx_prev[]
z[idx] = Δ * Δt
σ[idx] = Δ * Δt^2
end

## generic

function ChainRulesCore.rrule(A::AType, t::Number) where {AType <: AbstractInterpolation}
u = A(t)
idx = get_idx(A.t, t, A.idx_prev[])
direct_field_tangents = allocate_direct_field_tangents(A)
parameter_tangents = allocate_parameter_tangents(A)

function _interpolate_pullback(Δ)
A.idx_prev[] = idx
Δt = t - A.t[idx]
_tangent_direct_fields!(direct_field_tangents, A, Δt, Δ)
_tangent_p!(parameter_tangents, A, Δt, Δ)
dA = Tangent{AType}(; direct_field_tangents...,
p = Tangent{typeof(A.p)}(; parameter_tangents...))
dt = @thunk(derivative(A, t)*Δ)
return dA, dt
end

u, _interpolate_pullback
end

function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation,
t::Number)
return _interpolate(A, t), derivative(A, t) * Δt
end

function ChainRulesCore.rrule(::Type{ReadOnlyArray}, parent)
read_only_array = ReadOnlyArray(parent)
ReadOnlyArray_pullback(Δ) = NoTangent(), Δ
read_only_array, ReadOnlyArray_pullback
end

function ChainRulesCore.rrule(::typeof(cumulative_integral), A, u)
I = cumulative_integral(A, u)
cumulative_integral_pullback(Δ) = NoTangent(), NoTangent()
I, cumulative_integral_pullback
end

end # module
3 changes: 2 additions & 1 deletion src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ end
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
σ = A.p.σ[idx - 1]
A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx
z = A.p.z[idx - 1]
z + 2σ * (t - A.t[idx - 1]), idx
end

# CubicSpline Interpolation
Expand Down
5 changes: 3 additions & 2 deletions src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ end
function integral(A::AbstractInterpolation, t1::Number, t2::Number)
((t1 < A.t[1] || t1 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
((t2 < A.t[1] || t2 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
!hasfield(typeof(A), :I) && throw(IntegralNotFoundError())
has_I = hasfield(typeof(A), :I)
(!has_I || (has_I && isnothing(A.I))) && throw(IntegralNotFoundError())
# the index less than or equal to t1
idx1 = get_idx(A.t, t1, 0)
# the index less than t2
Expand Down Expand Up @@ -61,7 +62,7 @@ end
function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Cᵢ = A.u[idx]
Δt = t - A.t[idx]
return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt
return A.p.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt
end

function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Expand Down
Loading
Loading