From 9737c508af96c94b49ab8a5ec9f011190a82daf0 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sun, 7 Jul 2024 13:39:36 +0200 Subject: [PATCH 01/13] half a LinearInterpolation POC --- ext/DataInterpolationsChainRulesCoreExt.jl | 50 ++++++++++++++++------ src/parameter_caches.jl | 9 +++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..ebd00743 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -3,26 +3,50 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, LinearInterpolation, + linear_interpolation_parameters using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, LinearInterpolation, + linear_interpolation_parameters using ..ChainRulesCore end -function ChainRulesCore.rrule(::typeof(_interpolate), - A::Union{ - LagrangeInterpolation, - AkimaInterpolation, - BSplineInterpolation, - BSplineApprox - }, - t::Number) - deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) - return _interpolate(A, t), interpolate_pullback +function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u, t, idx) + slope = linear_interpolation_parameters(u, t, idx) + + function linear_interpolation_parameters_pullback(Δslope) + du = @thunk(Δslope*...) # TODO: How to handle sparsity? + dt = @thunk(Δslope*...) # TODO: How to handle sparsity? + df = NoTangent() + didx = NoTangent() + + return (df, du, dt, didx) + end + + return slope, linear_interpolation_parameters_pullback +end + +function _tangent_u(A::LinearInterpolation, t) + ... # TODO: How to handle sparsity? +end + +function _tangent_t(A::LinearInterpolation, t) + ... # TODO: How to handle sparsity? +end + +function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType} + u = _interpolate(A, t, iguess)[1] + function _interpolate_pullback(Δ) + df = NoTangent() + dA = Tangent{AType}(; u = _tangent_u(A, t), t = _tangent_t(A, t)) + dt = @thunk(derivative(A, t)*Δ) + diguess = NoTangent() + return df, dA, dt, diguess + end + u, _interpolate_pullback end function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..b550b621 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -3,11 +3,16 @@ struct LinearParameterCache{pType} end function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) + slope_prototype = linear_interpolation_parameters(u, t, 1) + idxs = 1:(length(t) - 1) + slope = [zero(slope_prototype) for i in idxs] + for idx in idxs + slope[idx] = linear_interpolation_parameters(u, t, idx) + end return LinearParameterCache(slope) end -function linear_interpolation_parameters(u, t, idx) +function linear_interpolation_parameters(u, t, idx::Integer) Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] slope = Δu / Δt From 16e985ee96a5c91261168d3b7e5c8937c274b43d Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sun, 7 Jul 2024 15:09:56 +0200 Subject: [PATCH 02/13] Add computations in POC --- ext/DataInterpolationsChainRulesCoreExt.jl | 39 ++++++++++++++++------ src/parameter_caches.jl | 2 +- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index ebd00743..022dba84 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -14,38 +14,55 @@ else using ..ChainRulesCore end -function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u, t, idx) +function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u::AbstractVector, t::AbstractVector, idx::Integer) slope = linear_interpolation_parameters(u, t, idx) + Δt = t[idx + 1] - t[idx] + Δu = u[idx + 1] - t[idx] + # TODO: use sparse arrays + du = zero(u) + dt = zero(t) - function linear_interpolation_parameters_pullback(Δslope) - du = @thunk(Δslope*...) # TODO: How to handle sparsity? - dt = @thunk(Δslope*...) # TODO: How to handle sparsity? + function linear_interpolation_parameters_pullback(Δ) df = NoTangent() + du .= zero(eltype(u)) + du[idx] = - Δ / Δt + du[idx + 1] = Δslope / Δt + dt .= zero(eltype(t)) + dt[idx] = Δ * Δu / Δt^2 + dt[idx + 1] = - Δ * Δu / Δt^2 didx = NoTangent() - return (df, du, dt, didx) end return slope, linear_interpolation_parameters_pullback end -function _tangent_u(A::LinearInterpolation, t) - ... # TODO: How to handle sparsity? +function _tangent_u!(Δu::AbstractVector, A::LinearInterpolation) + Δu .= zero(eltype(A.u)) + Δu end -function _tangent_t(A::LinearInterpolation, t) - ... # TODO: How to handle sparsity? +function _tangent_t!(Δt::AbstractVector, A::LinearInterpolation) + idx = A.idx_prev[] + Δt .= zero(eltype(Δt)) + Δt[idx] = one(eltype(Δt)) + Δt end function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType} - u = _interpolate(A, t, iguess)[1] + u = _interpolate(A, t) + # TODO: use sparse arrays + Δu = zero(A.u) + Δt = zero(A.t) + function _interpolate_pullback(Δ) df = NoTangent() - dA = Tangent{AType}(; u = _tangent_u(A, t), t = _tangent_t(A, t)) + dA = Tangent{AType}(; u = _tangent_u!(Δu, A), t = _tangent_t!(Δt, A)) dt = @thunk(derivative(A, t)*Δ) diguess = NoTangent() return df, dA, dt, diguess end + u, _interpolate_pullback end diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index b550b621..e61d9325 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -12,7 +12,7 @@ function LinearParameterCache(u, t) return LinearParameterCache(slope) end -function linear_interpolation_parameters(u, t, idx::Integer) +function linear_interpolation_parameters(u::AbstractVector, t::AbstractVector, idx::Integer)::Number Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] slope = Δu / Δt From 25cdddc56d162ca75771d9188ccaf114b6f17557 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sun, 7 Jul 2024 20:14:06 +0200 Subject: [PATCH 03/13] Progress? --- ext/DataInterpolationsChainRulesCoreExt.jl | 17 +++++++++++++---- src/interpolation_utils.jl | 11 +++++------ src/parameter_caches.jl | 8 +++----- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 022dba84..96144f53 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -14,7 +14,8 @@ else using ..ChainRulesCore end -function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u::AbstractVector, t::AbstractVector, idx::Integer) +function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), + u::AbstractVector, t::AbstractVector, idx::Integer) slope = linear_interpolation_parameters(u, t, idx) Δt = t[idx + 1] - t[idx] Δu = u[idx + 1] - t[idx] @@ -25,11 +26,11 @@ function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u::Abst function linear_interpolation_parameters_pullback(Δ) df = NoTangent() du .= zero(eltype(u)) - du[idx] = - Δ / Δt - du[idx + 1] = Δslope / Δt + du[idx] = -Δ / Δt + du[idx + 1] = Δ / Δt dt .= zero(eltype(t)) dt[idx] = Δ * Δu / Δt^2 - dt[idx + 1] = - Δ * Δu / Δt^2 + dt[idx + 1] = -Δ * Δu / Δt^2 didx = NoTangent() return (df, du, dt, didx) end @@ -66,6 +67,14 @@ function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where u, _interpolate_pullback end +function ChainRulesCore.rrule(::typeof(_interpolate), + A::AbstractInterpolation, + t::Number) + deriv = derivative(A, t) + interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) + return _interpolate(A, t), interpolate_pullback +end + function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, t::Number) return _interpolate(A, t), derivative(A, t) * Δt diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 466248b1..77775d3d 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -122,11 +122,10 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end function cumulative_integral(A) - if isempty(methods(_integral, (typeof(A), Any, Any))) - return nothing - end - integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) - for idx in 1:(length(A.t) - 1)] - pushfirst!(integral_values, zero(first(integral_values))) + integral_prototype = _integral(A, 1, A.t[2]) + + integral_values = [zero(integral_prototype), + (_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) + for idx in 1:(length(A.t) - 1))...] return cumsum(integral_values) end diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index e61d9325..dd37c741 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -5,14 +5,12 @@ end function LinearParameterCache(u, t) slope_prototype = linear_interpolation_parameters(u, t, 1) idxs = 1:(length(t) - 1) - slope = [zero(slope_prototype) for i in idxs] - for idx in idxs - slope[idx] = linear_interpolation_parameters(u, t, idx) - end + slope = [linear_interpolation_parameters(u, t, idx) for idx in idxs] return LinearParameterCache(slope) end -function linear_interpolation_parameters(u::AbstractVector, t::AbstractVector, idx::Integer)::Number +function linear_interpolation_parameters( + u::AbstractVector, t::AbstractVector, idx::Integer)::Number Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] slope = Δu / Δt From e67abcdf4fc4aefd65018b8b87e212050c61a5f0 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Mon, 8 Jul 2024 10:21:25 +0200 Subject: [PATCH 04/13] I now get gradients w.r.t. the u of the datapoints, but they are probably not correct yet --- ext/DataInterpolationsChainRulesCoreExt.jl | 33 +++++++++++----------- src/interpolation_utils.jl | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 96144f53..023bfdc1 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -38,43 +38,44 @@ function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), return slope, linear_interpolation_parameters_pullback end -function _tangent_u!(Δu::AbstractVector, A::LinearInterpolation) +function _tangent_u!(Δu::AbstractVector, A::LinearInterpolation, Δ) + idx = A.idx_prev[] Δu .= zero(eltype(A.u)) + Δu[idx] = Δ Δu end -function _tangent_t!(Δt::AbstractVector, A::LinearInterpolation) - idx = A.idx_prev[] +function _tangent_t!(Δt::AbstractVector, A::LinearInterpolation, Δ) Δt .= zero(eltype(Δt)) - Δt[idx] = one(eltype(Δt)) Δt end -function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType} +function _tangent_p!(Δslope, A::LinearInterpolation, t, Δ)::Nothing + idx = A.idx_prev[] + Δslope .= zero(eltype(A.p.slope)) + Δslope[idx] = t * Δ + return nothing +end + +function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} u = _interpolate(A, t) # TODO: use sparse arrays Δu = zero(A.u) Δt = zero(A.t) + Δslope = zero(A.p.slope) function _interpolate_pullback(Δ) df = NoTangent() - dA = Tangent{AType}(; u = _tangent_u!(Δu, A), t = _tangent_t!(Δt, A)) + _tangent_p!(Δslope, A, t, Δ) + dA = Tangent{AType}(; u = _tangent_u!(Δu, A, Δ), t = _tangent_t!(Δt, A, Δ), + p = Tangent{typeof(A.p)}(; slope = Δslope)) dt = @thunk(derivative(A, t)*Δ) - diguess = NoTangent() - return df, dA, dt, diguess + return df, dA, dt end u, _interpolate_pullback end -function ChainRulesCore.rrule(::typeof(_interpolate), - A::AbstractInterpolation, - t::Number) - deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) - return _interpolate(A, t), interpolate_pullback -end - function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, t::Number) return _interpolate(A, t), derivative(A, t) * Δt diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 77775d3d..53fc7efa 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -107,7 +107,7 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) end # Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) +readonly_wrap(a::AbstractArray) = a readonly_wrap(a::ReadOnlyArray) = a function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) From b1b5004b6b19d2a0bc50777db53a634641b31f6d Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Mon, 8 Jul 2024 19:37:56 +0200 Subject: [PATCH 05/13] gradients w.r.t. u PoC! --- ext/DataInterpolationsChainRulesCoreExt.jl | 41 +++++++++++----------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 023bfdc1..97f031c0 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -4,33 +4,31 @@ if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, - linear_interpolation_parameters + linear_interpolation_parameters, get_idx using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, - linear_interpolation_parameters + linear_interpolation_parameters, get_idx using ..ChainRulesCore end +## Linear interpolation + function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u::AbstractVector, t::AbstractVector, idx::Integer) slope = linear_interpolation_parameters(u, t, idx) - Δt = t[idx + 1] - t[idx] - Δu = u[idx + 1] - t[idx] # TODO: use sparse arrays du = zero(u) - dt = zero(t) function linear_interpolation_parameters_pullback(Δ) df = NoTangent() du .= zero(eltype(u)) + Δt = t[idx + 1] - t[idx] du[idx] = -Δ / Δt du[idx + 1] = Δ / Δt - dt .= zero(eltype(t)) - dt[idx] = Δ * Δu / Δt^2 - dt[idx + 1] = -Δ * Δu / Δt^2 + dt = NoTangent() didx = NoTangent() return (df, du, dt, didx) end @@ -45,30 +43,31 @@ function _tangent_u!(Δu::AbstractVector, A::LinearInterpolation, Δ) Δu end -function _tangent_t!(Δt::AbstractVector, A::LinearInterpolation, Δ) - Δt .= zero(eltype(Δt)) - Δt -end - -function _tangent_p!(Δslope, A::LinearInterpolation, t, Δ)::Nothing +function _tangent_p!(parameter_tangents::NamedTuple, A::LinearInterpolation, t, Δ)::Nothing + (; slope) = parameter_tangents idx = A.idx_prev[] - Δslope .= zero(eltype(A.p.slope)) - Δslope[idx] = t * Δ + slope[idx] = (t - A.t[idx]) * Δ return nothing end +function allocate_parameter_tangents(A::LinearInterpolation) + return (; slope = zero(A.p.slope)) +end + +## generic + function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} u = _interpolate(A, t) # TODO: use sparse arrays Δu = zero(A.u) - Δt = zero(A.t) - Δslope = zero(A.p.slope) + parameter_tangents = allocate_parameter_tangents(A) function _interpolate_pullback(Δ) + A.idx_prev[] = get_idx(A.t, t, A.idx_prev[]) df = NoTangent() - _tangent_p!(Δslope, A, t, Δ) - dA = Tangent{AType}(; u = _tangent_u!(Δu, A, Δ), t = _tangent_t!(Δt, A, Δ), - p = Tangent{typeof(A.p)}(; slope = Δslope)) + _tangent_p!(parameter_tangents, A, t, Δ) + dA = Tangent{AType}(; u = _tangent_u!(Δu, A, Δ), t = NoTangent(), + p = Tangent{typeof(A.p)}(; parameter_tangents...)) dt = @thunk(derivative(A, t)*Δ) return df, dA, dt end From c7c88102ce7abd968a48a0db50d25b25c823cafc Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Mon, 8 Jul 2024 21:29:10 +0200 Subject: [PATCH 06/13] Small fixes --- ext/DataInterpolationsChainRulesCoreExt.jl | 5 ++++- src/interpolation_utils.jl | 6 +++++- src/parameter_caches.jl | 5 ++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 97f031c0..e723df69 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -38,7 +38,6 @@ end function _tangent_u!(Δu::AbstractVector, A::LinearInterpolation, Δ) idx = A.idx_prev[] - Δu .= zero(eltype(A.u)) Δu[idx] = Δ Δu end @@ -54,6 +53,10 @@ function allocate_parameter_tangents(A::LinearInterpolation) return (; slope = zero(A.p.slope)) end +## Quadratic Spline + + + ## generic function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 53fc7efa..7ab7b3f4 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -122,10 +122,14 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end function cumulative_integral(A) + if !(A.u isa AbstractVector{<:Number}) + return nothing + end + integral_prototype = _integral(A, 1, A.t[2]) integral_values = [zero(integral_prototype), (_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) for idx in 1:(length(A.t) - 1))...] return cumsum(integral_values) -end +end \ No newline at end of file diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index dd37c741..ea662174 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -2,15 +2,14 @@ struct LinearParameterCache{pType} slope::pType end -function LinearParameterCache(u, t) - slope_prototype = linear_interpolation_parameters(u, t, 1) +function LinearParameterCache(u, t)::LinearParameterCache idxs = 1:(length(t) - 1) slope = [linear_interpolation_parameters(u, t, idx) for idx in idxs] return LinearParameterCache(slope) end function linear_interpolation_parameters( - u::AbstractVector, t::AbstractVector, idx::Integer)::Number + u::AbstractArray, t::AbstractVector, idx::Integer) Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] slope = Δu / Δt From 406c76ec79b8b0c8dee8107fa9d89bc445756cc8 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 9 Jul 2024 16:38:37 +0200 Subject: [PATCH 07/13] QuadraticSpline gradient w.r.t. u --- ext/DataInterpolationsChainRulesCoreExt.jl | 99 ++++++++++++++++------ src/interpolation_utils.jl | 2 +- src/online.jl | 9 +- src/parameter_caches.jl | 24 +++--- 4 files changed, 93 insertions(+), 41 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index e723df69..c18dc26e 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,75 +1,126 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) - using DataInterpolations: _interpolate, derivative, AbstractInterpolation, + using DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, + interpolation_parameters, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, - linear_interpolation_parameters, get_idx + QuadraticSpline using ChainRulesCore else - using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, + using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, + interpolation_parameters, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, - linear_interpolation_parameters, get_idx + QuadraticSpline using ..ChainRulesCore end ## Linear interpolation -function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), +function ChainRulesCore.rrule(::typeof(interpolation_parameters), + ::Val{:LinearInterpolation}, u::AbstractVector, t::AbstractVector, idx::Integer) - slope = linear_interpolation_parameters(u, t, idx) + slope = interpolation_parameters(Val(:LinearInterpolation), u, t, idx) # TODO: use sparse arrays du = zero(u) + Δt = t[idx + 1] - t[idx] - function linear_interpolation_parameters_pullback(Δ) + function interpolation_parameters_pullback(Δ) df = NoTangent() - du .= zero(eltype(u)) - Δt = t[idx + 1] - t[idx] + dmethod = NoTangent() du[idx] = -Δ / Δt du[idx + 1] = Δ / Δt dt = NoTangent() didx = NoTangent() - return (df, du, dt, didx) + return (df, dmethod, du, dt, didx) end - return slope, linear_interpolation_parameters_pullback + return slope, interpolation_parameters_pullback end -function _tangent_u!(Δu::AbstractVector, A::LinearInterpolation, Δ) +function allocate_direct_field_tangents(A::LinearInterpolation) + (; u = zero(A.u)) +end + +function allocate_parameter_tangents(A::LinearInterpolation) + return (; slope = zero(A.p.slope)) +end + +function _tangent_direct_fields!( + direct_field_tangents::NamedTuple, A::LinearInterpolation, Δt, Δ) + (; u) = direct_field_tangents idx = A.idx_prev[] - Δu[idx] = Δ - Δu + u[idx] = Δ end -function _tangent_p!(parameter_tangents::NamedTuple, A::LinearInterpolation, t, Δ)::Nothing +function _tangent_p!(parameter_tangents::NamedTuple, A::LinearInterpolation, Δt, Δ) (; slope) = parameter_tangents idx = A.idx_prev[] - slope[idx] = (t - A.t[idx]) * Δ - return nothing + slope[idx] = Δt * Δ end -function allocate_parameter_tangents(A::LinearInterpolation) - return (; slope = zero(A.p.slope)) +## Quadratic Spline + +function ChainRulesCore.rrule(::typeof(interpolation_parameters), + ::Val{:QuadraticSpline}, + z::AbstractVector, t::AbstractVector, idx::Integer) + σ = interpolation_parameters(Val(:QuadraticSpline), z, t, idx) + # TODO: use sparse arrays + dz = zero(z) + Δt = t[idx + 1] - t[idx] + + function interpolation_parameters_pullback(Δ) + df = NoTangent() + dmethod = NoTangent() + dz[idx] = -1 // 2 * Δ / Δt + dz[idx + 1] = 1 // 2 * Δ / Δt + dt = NoTangent() + didx = NoTangent() + return (df, dmethod, dz, dt, didx) + end + + return σ, interpolation_parameters_pullback end -## Quadratic Spline +function allocate_direct_field_tangents(A::QuadraticSpline) + (; u = zero(A.u), z = zero(A.z)) +end +function allocate_parameter_tangents(A::QuadraticSpline) + return (; σ = zero(A.p.σ)) +end + +function _tangent_direct_fields!( + direct_field_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ) + (; u, z) = direct_field_tangents + idx = A.idx_prev[] + u[idx] = Δ + z[idx] = Δt * Δ +end +function _tangent_p!(parameter_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ) + (; σ) = parameter_tangents + idx = A.idx_prev[] + σ[idx] = Δ * Δt^2 +end ## generic function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} u = _interpolate(A, t) # TODO: use sparse arrays - Δu = zero(A.u) + direct_field_tangents = allocate_direct_field_tangents(A) parameter_tangents = allocate_parameter_tangents(A) function _interpolate_pullback(Δ) - A.idx_prev[] = get_idx(A.t, t, A.idx_prev[]) + idx = get_idx(A.t, t, A.idx_prev[]) + A.idx_prev[] = idx + Δt = t - A.t[idx] df = NoTangent() - _tangent_p!(parameter_tangents, A, t, Δ) - dA = Tangent{AType}(; u = _tangent_u!(Δu, A, Δ), t = NoTangent(), + _tangent_direct_fields!(direct_field_tangents, A, Δt, Δ) + _tangent_p!(parameter_tangents, A, Δt, Δ) + dA = Tangent{AType}(; direct_field_tangents..., p = Tangent{typeof(A.p)}(; parameter_tangents...)) dt = @thunk(derivative(A, t)*Δ) return df, dA, dt diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 7ab7b3f4..915271d7 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -132,4 +132,4 @@ function cumulative_integral(A) (_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) for idx in 1:(length(A.t) - 1))...] return cumsum(integral_values) -end \ No newline at end of file +end diff --git a/src/online.jl b/src/online.jl index dc500611..a1ca710e 100644 --- a/src/online.jl +++ b/src/online.jl @@ -3,7 +3,7 @@ import Base: append!, push! function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} push!(A.u.parent, u) push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + slope = interpolation_parameters(Val(:LinearInterpolation), A.u, A.t, length(A.t) - 1) push!(A.p.slope, slope) A end @@ -11,7 +11,8 @@ end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} push!(A.u.parent, u) push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + l₀, l₁, l₂ = interpolation_parameters( + Val(:QuadraticInterpolation), A.u, A.t, length(A.t) - 2) push!(A.p.l₀, l₀) push!(A.p.l₁, l₁) push!(A.p.l₂, l₂) @@ -31,7 +32,7 @@ function append!( u, t = munge_data(u, t, true) append!(A.u.parent, u) append!(A.t.parent, t) - slope = linear_interpolation_parameters.( + slope = interpolation_parameters.(Val(:LinearInterpolation), Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) append!(A.p.slope, slope) A @@ -53,7 +54,7 @@ function append!( u, t = munge_data(u, t, true) append!(A.u.parent, u) append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( + parameters = interpolation_parameters.(Val(:QuadraticInterpolation), Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) append!(A.p.l₀, l₀) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index ea662174..ab878c40 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -4,11 +4,11 @@ end function LinearParameterCache(u, t)::LinearParameterCache idxs = 1:(length(t) - 1) - slope = [linear_interpolation_parameters(u, t, idx) for idx in idxs] + slope = [interpolation_parameters(Val(:LinearInterpolation), u, t, idx) for idx in idxs] return LinearParameterCache(slope) end -function linear_interpolation_parameters( +function interpolation_parameters(::Val{:LinearInterpolation}, u::AbstractArray, t::AbstractVector, idx::Integer) Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] @@ -24,13 +24,13 @@ struct QuadraticParameterCache{pType} end function QuadraticParameterCache(u, t) - parameters = quadratic_interpolation_parameters.( + parameters = interpolation_parameters.(Val(:QuadraticInterpolation), Ref(u), Ref(t), 1:(length(t) - 2)) l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) return QuadraticParameterCache(l₀, l₁, l₂) end -function quadratic_interpolation_parameters(u, t, idx) +function interpolation_parameters(::Val{:QuadraticInterpolation}, u, t, idx) if u isa AbstractMatrix u₀ = u[:, idx] u₁ = u[:, idx + 1] @@ -57,11 +57,11 @@ struct QuadraticSplineParameterCache{pType} end function QuadraticSplineParameterCache(z, t) - σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) + σ = interpolation_parameters.(Val(:QuadraticSpline), Ref(z), Ref(t), 1:(length(t) - 1)) return QuadraticSplineParameterCache(σ) end -function quadratic_spline_parameters(z, t, idx) +function interpolation_parameters(::Val{:QuadraticSpline}, z, t, idx) σ = 1 // 2 * (z[idx + 1] - z[idx]) / (t[idx + 1] - t[idx]) return σ end @@ -72,13 +72,13 @@ struct CubicSplineParameterCache{pType} end function CubicSplineParameterCache(u, h, z) - parameters = cubic_spline_parameters.( + parameters = interpolation_parameters.(Val(:CubicSpline), Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) return CubicSplineParameterCache(c₁, c₂) end -function cubic_spline_parameters(u, h, z, idx) +function interpolation_parameters(::Val{:CubicSpline}, u, h, z, idx) c₁ = (u[idx + 1] / h[idx + 1] - z[idx + 1] * h[idx + 1] / 6) c₂ = (u[idx] / h[idx + 1] - z[idx] * h[idx + 1] / 6) return c₁, c₂ @@ -90,13 +90,13 @@ struct CubicHermiteParameterCache{pType} end function CubicHermiteParameterCache(du, u, t) - parameters = cubic_hermite_spline_parameters.( + parameters = interpolation_parameters.(Val(:CubicHermiteSpline), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) return CubicHermiteParameterCache(c₁, c₂) end -function cubic_hermite_spline_parameters(du, u, t, idx) +function interpolation_parameters(::Val{:CubicHermiteSpline}, du, u, t, idx) Δt = t[idx + 1] - t[idx] u₀ = u[idx] u₁ = u[idx + 1] @@ -114,13 +114,13 @@ struct QuinticHermiteParameterCache{pType} end function QuinticHermiteParameterCache(ddu, du, u, t) - parameters = quintic_hermite_spline_parameters.( + parameters = interpolation_parameters.(Val(:QuinticHermiteSpline), Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...))) return QuinticHermiteParameterCache(c₁, c₂, c₃) end -function quintic_hermite_spline_parameters(ddu, du, u, t, idx) +function interpolation_parameters(::Val{:QuinticHermiteSpline}, ddu, du, u, t, idx) Δt = t[idx + 1] - t[idx] u₀ = u[idx] u₁ = u[idx + 1] From aa23c78f336e91e3ab63709c601c2c61c8c9547d Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 9 Jul 2024 19:33:28 +0200 Subject: [PATCH 08/13] Use sparse arrays --- Project.toml | 1 + ext/DataInterpolationsChainRulesCoreExt.jl | 30 ++++++++++++++-------- src/interpolation_utils.jl | 8 +++--- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index b06d2adb..dc2b8b2b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index c18dc26e..e949570a 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -16,14 +16,15 @@ else using ..ChainRulesCore end +using SparseArrays + ## Linear interpolation function ChainRulesCore.rrule(::typeof(interpolation_parameters), ::Val{:LinearInterpolation}, u::AbstractVector, t::AbstractVector, idx::Integer) slope = interpolation_parameters(Val(:LinearInterpolation), u, t, idx) - # TODO: use sparse arrays - du = zero(u) + du = SparseVector(length(u), [idx, idx + 1], zeros(2)) Δt = t[idx + 1] - t[idx] function interpolation_parameters_pullback(Δ) @@ -40,11 +41,15 @@ function ChainRulesCore.rrule(::typeof(interpolation_parameters), end function allocate_direct_field_tangents(A::LinearInterpolation) - (; u = zero(A.u)) + idx = A.idx_prev[] + u = SparseVector(length(A.u), [idx], zeros(1)) + (; u) end function allocate_parameter_tangents(A::LinearInterpolation) - return (; slope = zero(A.p.slope)) + idx = A.idx_prev[] + slope = SparseVector(length(A.p.slope), [idx], zeros(1)) + return (; slope) end function _tangent_direct_fields!( @@ -65,9 +70,8 @@ end function ChainRulesCore.rrule(::typeof(interpolation_parameters), ::Val{:QuadraticSpline}, z::AbstractVector, t::AbstractVector, idx::Integer) - σ = interpolation_parameters(Val(:QuadraticSpline), z, t, idx) - # TODO: use sparse arrays - dz = zero(z) + σ = interpolation_parameters(Val(:QuadraticSpline), z, t, idx)s + dz = SparseVector(length(z), [idx, idx + 1], zeros(2)) Δt = t[idx + 1] - t[idx] function interpolation_parameters_pullback(Δ) @@ -84,11 +88,16 @@ function ChainRulesCore.rrule(::typeof(interpolation_parameters), end function allocate_direct_field_tangents(A::QuadraticSpline) - (; u = zero(A.u), z = zero(A.z)) + idx = A.idx_prev[] + u = SparseVector(length(A.u), [idx], zeros(1)) + z = SparseVector(length(A.z), [idx], zeros(1)) + (; u, z) end function allocate_parameter_tangents(A::QuadraticSpline) - return (; σ = zero(A.p.σ)) + idx = A.idx_prev[] + σ = SparseVector(length(A.p.σ), [idx], zeros(1)) + return (; σ) end function _tangent_direct_fields!( @@ -109,12 +118,11 @@ end function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} u = _interpolate(A, t) - # TODO: use sparse arrays + idx = get_idx(A.t, t, A.idx_prev[]) direct_field_tangents = allocate_direct_field_tangents(A) parameter_tangents = allocate_parameter_tangents(A) function _interpolate_pullback(Δ) - idx = get_idx(A.t, t, A.idx_prev[]) A.idx_prev[] = idx Δt = t - A.t[idx] df = NoTangent() diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 915271d7..4fbecfe3 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -121,11 +121,7 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end end -function cumulative_integral(A) - if !(A.u isa AbstractVector{<:Number}) - return nothing - end - +function cumulative_integral(A, ::AbstractVector{<:Number}) integral_prototype = _integral(A, 1, A.t[2]) integral_values = [zero(integral_prototype), @@ -133,3 +129,5 @@ function cumulative_integral(A) for idx in 1:(length(A.t) - 1))...] return cumsum(integral_values) end + +cumulative_integral(A) = nothing From 436ac01d01ccf0bb4d7f5f02d090e67f6256c79d Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 9 Jul 2024 19:42:14 +0200 Subject: [PATCH 09/13] Small fixes --- ext/DataInterpolationsChainRulesCoreExt.jl | 2 +- src/interpolation_caches.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index e949570a..7c10ba94 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -70,7 +70,7 @@ end function ChainRulesCore.rrule(::typeof(interpolation_parameters), ::Val{:QuadraticSpline}, z::AbstractVector, t::AbstractVector, idx::Integer) - σ = interpolation_parameters(Val(:QuadraticSpline), z, t, idx)s + σ = interpolation_parameters(Val(:QuadraticSpline), z, t, idx) dz = SparseVector(length(z), [idx, idx + 1], zeros(2)) Δt = t[idx + 1] - t[idx] diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 83e04fe5..f0008dba 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -299,7 +299,7 @@ function QuadraticSpline( # zero for element type of d, which we don't know yet typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin])) - d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) + d = [2 // 1 * (u[i] - u[max(1, i - 1)]) / (t[i] - t[1 + abs(i - 2)]) for i in eachindex(t)] z = tA \ d p = QuadraticSplineParameterCache(z, t) A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) From edde211cd3ef58d5994736fa80a644b28170c411 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 13 Jul 2024 14:34:33 +0200 Subject: [PATCH 10/13] Pass tests --- Project.toml | 1 + ext/DataInterpolationsChainRulesCoreExt.jl | 101 +++++++++++++-------- src/DataInterpolations.jl | 1 + src/derivatives.jl | 3 +- src/integrals.jl | 5 +- src/interpolation_caches.jl | 74 ++++----------- src/interpolation_methods.jl | 2 +- src/interpolation_utils.jl | 4 +- src/parameter_caches.jl | 41 +++++++-- 9 files changed, 128 insertions(+), 104 deletions(-) diff --git a/Project.toml b/Project.toml index dc2b8b2b..d2b21cef 100644 --- a/Project.toml +++ b/Project.toml @@ -39,6 +39,7 @@ RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" SafeTestsets = "0.1" +SparseArrays = "1.10" StableRNGs = "1" Symbolics = "5.29" Test = "1" diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 7c10ba94..a290be25 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -2,42 +2,53 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, - interpolation_parameters, + interpolation_parameters, LinearParameterCache, + QuadraticSplineParameterCache, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, QuadraticSpline using ChainRulesCore + using LinearAlgebra + using SparseArrays else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, - interpolation_parameters, + interpolation_parameters, LinearParameterCache, + QuadraticSplineParameterCache, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, QuadraticSpline using ..ChainRulesCore + using ..LinearAlgebra + using ..SparseArrays end -using SparseArrays - ## Linear interpolation -function ChainRulesCore.rrule(::typeof(interpolation_parameters), - ::Val{:LinearInterpolation}, - u::AbstractVector, t::AbstractVector, idx::Integer) - slope = interpolation_parameters(Val(:LinearInterpolation), u, t, idx) - du = SparseVector(length(u), [idx, idx + 1], zeros(2)) - Δt = t[idx + 1] - t[idx] +function ChainRulesCore.rrule( + ::Type{LinearParameterCache}, u::AbstractArray, t::AbstractVector) + p = LinearParameterCache(u, t) + du = zeros(eltype(p.slope), length(u)) - function interpolation_parameters_pullback(Δ) + function LinearParameterCache_pullback(Δp) df = NoTangent() - dmethod = NoTangent() - du[idx] = -Δ / Δt - du[idx + 1] = Δ / Δt + du[2:end] += Δp.slope + du[1:(end - 1)] -= Δp.slope dt = NoTangent() - didx = NoTangent() - return (df, dmethod, du, dt, didx) + return (df, du, dt) end - return slope, interpolation_parameters_pullback + p, LinearParameterCache_pullback +end + +function ChainRulesCore.rrule( + ::Type{LinearInterpolation}, u, t, I, p, extrapolate, safetycopy) + A = LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + + function LinearInterpolation_pullback(ΔA) + return ΔA.u, NoTangent(), NoTangent(), ΔA.p, NoTangent(), NoTangent(), NoTangent() + end + + A, LinearInterpolation_pullback end function allocate_direct_field_tangents(A::LinearInterpolation) @@ -67,50 +78,68 @@ end ## Quadratic Spline -function ChainRulesCore.rrule(::typeof(interpolation_parameters), - ::Val{:QuadraticSpline}, - z::AbstractVector, t::AbstractVector, idx::Integer) - σ = interpolation_parameters(Val(:QuadraticSpline), z, t, idx) - dz = SparseVector(length(z), [idx, idx + 1], zeros(2)) - Δt = t[idx + 1] - t[idx] +function ChainRulesCore.rrule(::Type{QuadraticSplineParameterCache}, u, t) + p = QuadraticSplineParameterCache(u, t) + n = length(u) + + ∂z_∂d = inv(p.tA) - function interpolation_parameters_pullback(Δ) + Δt = diff(t) + diagonal_main = [zero(eltype(Δt)), 2 ./ Δt...] + diagonal_down = -diagonal_main[2:end] + diagonal_up = zero(diagonal_down) + ∂d_∂u = Tridiagonal(diagonal_down, diagonal_main, diagonal_up) + + ∂σ_∂z = spzeros(n, n - 1) + for i in 1:(n - 1) + ∂σ_∂z[i, i] = -0.5 / Δt[i] + ∂σ_∂z[i + 1, i] = 0.5 / Δt[i] + end + + function QuadraticSplineParameterCache_pullback(Δp) df = NoTangent() - dmethod = NoTangent() - dz[idx] = -1 // 2 * Δ / Δt - dz[idx + 1] = 1 // 2 * Δ / Δt + du = (Δp.z + ∂σ_∂z * Δp.σ)' * ∂z_∂d * ∂d_∂u dt = NoTangent() - didx = NoTangent() - return (df, dmethod, dz, dt, didx) + return (df, du, dt) + end + + p, QuadraticSplineParameterCache_pullback +end + +function ChainRulesCore.rrule(::Type{QuadraticSpline}, u, t, I, p, extrapolate, safetycopy) + A = QuadraticSpline(u, t, I, p, extrapolate, safetycopy) + + function LinearInterpolation_pullback(ΔA) + return ΔA.u, NoTangent(), NoTangent(), ΔA.p, NoTangent(), NoTangent(), NoTangent() end - return σ, interpolation_parameters_pullback + A, LinearInterpolation_pullback end function allocate_direct_field_tangents(A::QuadraticSpline) idx = A.idx_prev[] u = SparseVector(length(A.u), [idx], zeros(1)) - z = SparseVector(length(A.z), [idx], zeros(1)) - (; u, z) + (; u) end function allocate_parameter_tangents(A::QuadraticSpline) idx = A.idx_prev[] + z = SparseVector(length(A.p.z), [idx], zeros(1)) σ = SparseVector(length(A.p.σ), [idx], zeros(1)) - return (; σ) + return (; z, σ) end function _tangent_direct_fields!( direct_field_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ) - (; u, z) = direct_field_tangents + (; u) = direct_field_tangents idx = A.idx_prev[] u[idx] = Δ - z[idx] = Δt * Δ end function _tangent_p!(parameter_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ) - (; σ) = parameter_tangents + (; z, σ) = parameter_tangents idx = A.idx_prev[] + z[idx] = Δ * Δt σ[idx] = Δ * Δt^2 end diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index c86a6579..d5c31598 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -8,6 +8,7 @@ using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff using ReadOnlyArrays +using SparseArrays # Only used in DataInterpolationsChainRulesCoreExt.jl, but otherwise Aqua complains import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..159f37eb 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -130,7 +130,8 @@ end function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) σ = A.p.σ[idx - 1] - A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx + z = A.p.z[idx - 1] + z + 2σ * (t - A.t[idx - 1]), idx end # CubicSpline Interpolation diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..d85ab27f 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -6,7 +6,8 @@ end function integral(A::AbstractInterpolation, t1::Number, t2::Number) ((t1 < A.t[1] || t1 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) ((t2 < A.t[1] || t2 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) - !hasfield(typeof(A), :I) && throw(IntegralNotFoundError()) + has_I = hasfield(typeof(A), :I) + (!has_I || (has_I && isnothing(A.I))) && throw(IntegralNotFoundError()) # the index less than or equal to t1 idx1 = get_idx(A.t, t1, 0) # the index less than t2 @@ -61,7 +62,7 @@ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + return A.p.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index f0008dba..5e59b842 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -32,7 +32,7 @@ function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) u, t = munge_data(u, t, safetycopy) p = LinearParameterCache(u, t) A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) LinearInterpolation(u, t, I, p, extrapolate, safetycopy) end @@ -74,7 +74,7 @@ function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = tr u, t = munge_data(u, t, safetycopy) p = QuadraticParameterCache(u, t) A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) end @@ -198,7 +198,7 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) end @@ -238,7 +238,7 @@ end function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) u, t = munge_data(u, t, safetycopy) A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) end @@ -263,22 +263,16 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: u::uType t::tType I::IType - p::QuadraticSplineParameterCache{pType} - tA::tAType - d::dType - z::zType + p::QuadraticSplineParameterCache{tAType, dType, zType, pType} extrapolate::Bool idx_prev::Base.RefValue{Int} safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) - new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), - typeof(d), typeof(z), eltype(u)}(u, + function QuadraticSpline(u, t, I, p, extrapolate, safetycopy) + new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(p.tA), + typeof(p.d), typeof(p.z), eltype(u)}(u, t, I, p, - tA, - d, - z, extrapolate, Ref(1), safetycopy @@ -287,45 +281,13 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: end function QuadraticSpline( - u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) - s = length(t) - dl = ones(eltype(t), s - 1) - d_tmp = ones(eltype(t), s) - du = zeros(eltype(t), s - 1) - tA = Tridiagonal(dl, d_tmp, du) - - # zero for element type of d, which we don't know yet - typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin])) - - d = [2 // 1 * (u[i] - u[max(1, i - 1)]) / (t[i] - t[1 + abs(i - 2)]) for i in eachindex(t)] - z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) -end - -function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} + u, t; extrapolate = false, + safetycopy = true) u, t = munge_data(u, t, safetycopy) - s = length(t) - dl = ones(eltype(t), s - 1) - d_tmp = ones(eltype(t), s) - du = zeros(eltype(t), s - 1) - tA = Tridiagonal(dl, d_tmp, du) - d_ = map( - i -> i == 1 ? zeros(eltype(t), size(u[1])) : - 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), - 1:s) - d = transpose(reshape(reduce(hcat, d_), :, s)) - z_ = reshape(transpose(tA \ d), size(u[1])..., :) - z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + p = QuadraticSplineParameterCache(u, t) + A = QuadraticSpline(u, t, nothing, p, extrapolate, safetycopy) + I = cumulative_integral(A, A.u) + QuadraticSpline(u, t, I, p, extrapolate, safetycopy) end """ @@ -391,7 +353,7 @@ function CubicSpline(u::uType, z = tA \ d p = CubicSplineParameterCache(u, h, z) A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) end @@ -413,7 +375,7 @@ function CubicSpline( z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] p = CubicSplineParameterCache(u, h, z) A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) end @@ -735,7 +697,7 @@ function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) u, t = munge_data(u, t, safetycopy) p = CubicHermiteParameterCache(du, u, t) A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) end @@ -779,6 +741,6 @@ function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = t u, t = munge_data(u, t, safetycopy) p = QuinticHermiteParameterCache(ddu, du, u, t) A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..dd2893a7 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -149,7 +149,7 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + return A.p.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 4fbecfe3..ee6db3c9 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -107,7 +107,7 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) end # Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = a +readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) readonly_wrap(a::ReadOnlyArray) = a function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) @@ -130,4 +130,4 @@ function cumulative_integral(A, ::AbstractVector{<:Number}) return cumsum(integral_values) end -cumulative_integral(A) = nothing +cumulative_integral(A, ::AbstractArray) = nothing diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index ab878c40..3648120a 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -52,13 +52,42 @@ function interpolation_parameters(::Val{:QuadraticInterpolation}, u, t, idx) return l₀, l₁, l₂ end -struct QuadraticSplineParameterCache{pType} - σ::pType -end - -function QuadraticSplineParameterCache(z, t) +struct QuadraticSplineParameterCache{tAType, dType, zType, σType} + tA::tAType + d::dType + z::zType + σ::σType +end + +function QuadraticSplineParameterCache(u::AbstractVector{<:Number}, t) + s = length(t) + dl = ones(eltype(t), s - 1) + d_tmp = ones(eltype(t), s) + du = zeros(eltype(t), s - 1) + tA = Tridiagonal(dl, d_tmp, du) + + d = [2 // 1 * (u[i] - u[max(1, i - 1)]) / (t[i] - t[1 + abs(i - 2)]) + for i in eachindex(t)] + z = tA \ d + σ = interpolation_parameters.(Val(:QuadraticSpline), Ref(z), Ref(t), 1:(length(t) - 1)) + return QuadraticSplineParameterCache(tA, d, z, σ) +end + +function QuadraticSplineParameterCache(u::AbstractVector{<:AbstractArray{<:Number}}, t) + s = length(t) + dl = ones(eltype(t), s - 1) + d_tmp = ones(eltype(t), s) + du = zeros(eltype(t), s - 1) + tA = Tridiagonal(dl, d_tmp, du) + d_ = map( + i -> i == 1 ? zeros(eltype(t), size(u[1])) : + 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), + 1:s) + d = transpose(reshape(reduce(hcat, d_), :, s)) + z_ = reshape(transpose(tA \ d), size(u[1])..., :) + z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] σ = interpolation_parameters.(Val(:QuadraticSpline), Ref(z), Ref(t), 1:(length(t) - 1)) - return QuadraticSplineParameterCache(σ) + return QuadraticSplineParameterCache(tA, d, z, σ) end function interpolation_parameters(::Val{:QuadraticSpline}, z, t, idx) From 30c99c794c095d8c2e00649769902039c22c351b Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 13 Jul 2024 17:12:11 +0200 Subject: [PATCH 11/13] Add ReadOnlyArray rrule --- ext/DataInterpolationsChainRulesCoreExt.jl | 47 ++++++++++++++++------ 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index a290be25..3e872bb0 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -2,7 +2,7 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, - interpolation_parameters, LinearParameterCache, + cumulative_integral, LinearParameterCache, QuadraticSplineParameterCache, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, @@ -10,9 +10,10 @@ if isdefined(Base, :get_extension) using ChainRulesCore using LinearAlgebra using SparseArrays + using ReadOnlyArrays else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, - interpolation_parameters, LinearParameterCache, + cumulative_integral, LinearParameterCache, QuadraticSplineParameterCache, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, LinearInterpolation, @@ -20,6 +21,7 @@ else using ..ChainRulesCore using ..LinearAlgebra using ..SparseArrays + using ..ReadOnlyArrays end ## Linear interpolation @@ -45,7 +47,14 @@ function ChainRulesCore.rrule( A = LinearInterpolation(u, t, I, p, extrapolate, safetycopy) function LinearInterpolation_pullback(ΔA) - return ΔA.u, NoTangent(), NoTangent(), ΔA.p, NoTangent(), NoTangent(), NoTangent() + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = ΔA.p + dextrapolate = NoTangent() + dsafetycopy = NoTangent() + return df, du, dt, dI, dp, dextrapolate, dsafetycopy end A, LinearInterpolation_pullback @@ -82,8 +91,6 @@ function ChainRulesCore.rrule(::Type{QuadraticSplineParameterCache}, u, t) p = QuadraticSplineParameterCache(u, t) n = length(u) - ∂z_∂d = inv(p.tA) - Δt = diff(t) diagonal_main = [zero(eltype(Δt)), 2 ./ Δt...] diagonal_down = -diagonal_main[2:end] @@ -98,7 +105,9 @@ function ChainRulesCore.rrule(::Type{QuadraticSplineParameterCache}, u, t) function QuadraticSplineParameterCache_pullback(Δp) df = NoTangent() - du = (Δp.z + ∂σ_∂z * Δp.σ)' * ∂z_∂d * ∂d_∂u + temp1 = Δp.z + ∂σ_∂z * Δp.σ + temp2 = p.tA' \ temp1 + du = ∂d_∂u' * temp2 dt = NoTangent() return (df, du, dt) end @@ -109,11 +118,18 @@ end function ChainRulesCore.rrule(::Type{QuadraticSpline}, u, t, I, p, extrapolate, safetycopy) A = QuadraticSpline(u, t, I, p, extrapolate, safetycopy) - function LinearInterpolation_pullback(ΔA) - return ΔA.u, NoTangent(), NoTangent(), ΔA.p, NoTangent(), NoTangent(), NoTangent() + function QuadraticSpline_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = ΔA.p + dextrapolate = NoTangent() + dsafetycopy = NoTangent() + return df, du, dt, dI, dp, dextrapolate, dsafetycopy end - A, LinearInterpolation_pullback + A, QuadraticSpline_pullback end function allocate_direct_field_tangents(A::QuadraticSpline) @@ -145,8 +161,8 @@ end ## generic -function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} - u = _interpolate(A, t) +function ChainRulesCore.rrule(A::AType, t::Number) where {AType <: AbstractInterpolation} + u = A(t) idx = get_idx(A.t, t, A.idx_prev[]) direct_field_tangents = allocate_direct_field_tangents(A) parameter_tangents = allocate_parameter_tangents(A) @@ -154,13 +170,12 @@ function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t) where {AType} function _interpolate_pullback(Δ) A.idx_prev[] = idx Δt = t - A.t[idx] - df = NoTangent() _tangent_direct_fields!(direct_field_tangents, A, Δt, Δ) _tangent_p!(parameter_tangents, A, Δt, Δ) dA = Tangent{AType}(; direct_field_tangents..., p = Tangent{typeof(A.p)}(; parameter_tangents...)) dt = @thunk(derivative(A, t)*Δ) - return df, dA, dt + return dA, dt end u, _interpolate_pullback @@ -171,4 +186,10 @@ function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractIn return _interpolate(A, t), derivative(A, t) * Δt end +function ChainRulesCore.rrule(::Type{ReadOnlyArray}, parent) + read_only_array = ReadOnlyArray(parent) + ReadOnlyArray_pullback(Δ) = NoTangent(), Δ + read_only_array, ReadOnlyArray_pullback +end + end # module From 353cdff08af1e7ab77d8080fbf892149c931df3b Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 13 Jul 2024 17:34:50 +0200 Subject: [PATCH 12/13] Ignure cumulative_integral --- ext/DataInterpolationsChainRulesCoreExt.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 3e872bb0..30d3fe09 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -192,4 +192,10 @@ function ChainRulesCore.rrule(::Type{ReadOnlyArray}, parent) read_only_array, ReadOnlyArray_pullback end +function ChainRulesCore.rrule(::typeof(cumulative_integral), A, u) + I = cumulative_integral(A, u) + cumulative_integral_pullback(Δ) = NoTangent(), NoTangent() + I, cumulative_integral_pullback +end + end # module From 72790c8c1c0eb5faaa25b8c2491fdac087a14a3d Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 13 Jul 2024 19:08:59 +0200 Subject: [PATCH 13/13] Move SparseArrays to weakdeps --- Project.toml | 4 ++-- ext/DataInterpolationsChainRulesCoreExt.jl | 3 ++- src/DataInterpolations.jl | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index d2b21cef..b3bdcfa1 100644 --- a/Project.toml +++ b/Project.toml @@ -10,16 +10,16 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] -DataInterpolationsChainRulesCoreExt = "ChainRulesCore" +DataInterpolationsChainRulesCoreExt = ["ChainRulesCore", "SparseArrays"] DataInterpolationsOptimExt = "Optim" DataInterpolationsRegularizationToolsExt = "RegularizationTools" DataInterpolationsSymbolicsExt = "Symbolics" diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 30d3fe09..37613347 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -92,7 +92,8 @@ function ChainRulesCore.rrule(::Type{QuadraticSplineParameterCache}, u, t) n = length(u) Δt = diff(t) - diagonal_main = [zero(eltype(Δt)), 2 ./ Δt...] + diagonal_main = 2 ./ Δt + pushfirst!(diagonal_main, zero(eltype(diagonal_main))) diagonal_down = -diagonal_main[2:end] diagonal_up = zero(diagonal_down) ∂d_∂u = Tridiagonal(diagonal_down, diagonal_main, diagonal_up) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index d5c31598..c86a6579 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -8,7 +8,6 @@ using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff using ReadOnlyArrays -using SparseArrays # Only used in DataInterpolationsChainRulesCoreExt.jl, but otherwise Aqua complains import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic