Skip to content

Commit

Permalink
Merge pull request #315 from SouthEndMusic/cache_parameters_opt_in
Browse files Browse the repository at this point in the history
Refactor parameter caching + add Zygote tests
  • Loading branch information
ChrisRackauckas authored Jul 28, 2024
2 parents 9f15432 + e0cce0b commit 4366656
Show file tree
Hide file tree
Showing 19 changed files with 602 additions and 342 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

Expand All @@ -16,6 +15,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
Expand All @@ -33,14 +33,14 @@ LinearAlgebra = "1.10"
Optim = "1.6"
PrettyTables = "2"
QuadGK = "2.9.1"
ReadOnlyArrays = "0.2.0"
RecipesBase = "1.3"
Reexport = "1"
RegularizationTools = "0.6"
SafeTestsets = "0.1"
StableRNGs = "1"
Symbolics = "5.29"
Test = "1"
Zygote = "0.6"
julia = "1.10"

[extras]
Expand All @@ -55,6 +55,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"]
test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"]
17 changes: 1 addition & 16 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,7 @@ A2(300.0)

The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data.

The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object.

```@example interface
A3 = QuadraticInterpolation(u, t; safetycopy = false)
# Check for same memory
u === A3.u.parent
```

Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref).

Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl).

```@repl interface
A3.t[2] = 3.14
```
The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evaluations cheaper to compute. This is not compatible with modifying `u` and `t`. Note, however, that the default `cache_parameters = false` does not prevent allocation in every interpolation constructor call.

## Derivatives

Expand Down
78 changes: 74 additions & 4 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,97 @@
module DataInterpolationsChainRulesCoreExt

if isdefined(Base, :get_extension)
using DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox
BSplineInterpolation, BSplineApprox, get_idx, get_parameters,
_quad_interp_indices
using ChainRulesCore
else
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox
BSplineInterpolation, BSplineApprox, get_parameters,
_quad_interp_indices
using ..ChainRulesCore
end

# Reverse-mode rule for the positional `LinearInterpolation` constructor.
#
# The primal is the constructed interpolation object. The pullback forwards only
# the `u` component of the incoming cotangent `ΔA`; the constructor itself and
# every other argument (`t`, `I`, `p`, `extrapolate`, `cache_parameters`) get
# `NoTangent()`.
function ChainRulesCore.rrule(
        ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters)
    A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters)

    # One entry for the constructor, then one per positional argument, in order.
    function LinearInterpolation_pullback(ΔA)
        return (
            NoTangent(),   # constructor
            ΔA.u,          # u
            NoTangent(),   # t
            NoTangent(),   # I
            NoTangent(),   # p
            NoTangent(),   # extrapolate
            NoTangent()    # cache_parameters
        )
    end

    return A, LinearInterpolation_pullback
end

# Reverse-mode rule for the positional `QuadraticInterpolation` constructor.
#
# The primal is the constructed interpolation object. The pullback forwards only
# the `u` component of the incoming cotangent `ΔA`; the constructor itself and
# every other argument (`t`, `I`, `p`, `mode`, `extrapolate`, `cache_parameters`)
# get `NoTangent()`.
function ChainRulesCore.rrule(
        ::Type{QuadraticInterpolation}, u, t, I, p, mode, extrapolate, cache_parameters)
    A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters)

    # Named after the type it belongs to, so stack traces and the closure type
    # name do not point at the linear rule.
    function QuadraticInterpolation_pullback(ΔA)
        df = NoTangent()
        du = ΔA.u
        dt = NoTangent()
        dI = NoTangent()
        dp = NoTangent()
        dmode = NoTangent()
        dextrapolate = NoTangent()
        dcache_parameters = NoTangent()
        return df, du, dt, dI, dp, dmode, dextrapolate, dcache_parameters
    end

    return A, QuadraticInterpolation_pullback
end

# Cotangent with respect to `A.u` of evaluating the linear interpolation `A` at
# `t`, given the output cotangent `Δ`.
#
# The result has the same shape as `A.u` and is zero everywhere except at the
# two nodes `idx` and `idx + 1` selected by `get_idx`, which receive `Δ` scaled
# by the linear weights `1 - w` and `w`, with
# `w = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx])`.
function u_tangent(A::LinearInterpolation, t, Δ)
    ū = zero(A.u)
    idx = get_idx(A.t, t, A.idx_prev[])
    t_left = A.t[idx]
    w = (t - t_left) / (A.t[idx + 1] - t_left)
    ū[idx] = Δ * (one(eltype(ū)) - w)
    ū[idx + 1] = Δ * w
    return ū
end

# Cotangent with respect to `A.u` of evaluating the quadratic interpolation `A`
# at `t`, given the output cotangent `Δ`.
#
# The result has the same shape as `A.u` and is zero everywhere except at the
# three nodes returned by `_quad_interp_indices`, each of which receives `Δ`
# times its quadratic Lagrange basis weight at `t`.
function u_tangent(A::QuadraticInterpolation, t, Δ)
    ū = zero(A.u)
    i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[])
    t₀, t₁, t₂ = A.t[i₀], A.t[i₁], A.t[i₂]

    # Pairwise node spacings used as the basis denominators.
    Δt₀ = t₁ - t₀
    Δt₁ = t₂ - t₁
    Δt₂ = t₂ - t₀

    ū[i₀] = Δ * (t - t₁) * (t - t₂) / (Δt₀ * Δt₂)
    # Negative sign: the true denominator is (t₁ - t₀) * (t₁ - t₂) = -Δt₀ * Δt₁.
    ū[i₁] = -Δ * (t - t₀) * (t - t₂) / (Δt₀ * Δt₁)
    ū[i₂] = Δ * (t - t₀) * (t - t₁) / (Δt₂ * Δt₁)
    return ū
end

# Fallback for interpolation types without a dedicated `u_tangent` method:
# no cotangent is reported for `A.u`.
u_tangent(A, t, Δ) = NoTangent()

# Reverse-mode rule for evaluating an interpolation `A` at a scalar `t` via
# `_interpolate`. Returns the primal value `_interpolate(A, t)` together with a
# pullback that maps the output cotangent `Δ` to cotangents for
# (`_interpolate`, `A`, `t`).
function ChainRulesCore.rrule(::typeof(_interpolate),
A::Union{
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
AkimaInterpolation,
BSplineInterpolation,
BSplineApprox
},
t::Number)
# Derivative of the interpolation at `t`; scaled by `Δ` it is the cotangent for `t`.
deriv = derivative(A, t)
# NOTE(review): `interpolate_pullback` is defined twice. The one-liner below
# returns `NoTangent()` for `A`; the `function` form after it has the same
# signature and so replaces it. The one-liner looks like the pre-change line
# kept by the diff view — confirm and drop it.
interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ)
function interpolate_pullback(Δ)
# The cotangent for `A` carries only a `u` field, computed by `u_tangent`.
(NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ)
end
return _interpolate(A, t), interpolate_pullback
end

Expand Down
5 changes: 2 additions & 3 deletions ext/DataInterpolationsOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ function Curvefit(u,
box = false,
lb = nothing,
ub = nothing;
extrapolate = false,
safetycopy = false)
u, t = munge_data(u, t, safetycopy)
extrapolate = false)
u, t = munge_data(u, t)
errfun(t, u, p) = sum(abs2.(u .- model(t, p)))
if box == false
mfit = optimize(p -> errfun(t, u, p), p0, alg)
Expand Down
28 changes: 14 additions & 14 deletions ext/DataInterpolationsRegularizationToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd)
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
wls::AbstractVector, wr::AbstractVector, d::Int = 2;
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
M = _mapping_matrix(t̂, t)
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = LA.diagm(sqrt.(wr))
Expand All @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false)
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2;
λ::Real = 1.0,
alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down Expand Up @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd,
extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
extrapolate::Bool = false)
u, t = munge_data(u, t)
N, N̂ = length(t), length(t̂)
M = _mapping_matrix(t̂, t)
Wls½ = Array{Float64}(LA.I, N, N)
Expand All @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
N, N̂ = length(t), length(t̂)
M = _mapping_matrix(t̂, t)
Wls½ = LA.diagm(sqrt.(wls))
Expand Down Expand Up @@ -172,8 +172,8 @@ A = RegularizationSmooth(
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down Expand Up @@ -202,8 +202,8 @@ A = RegularizationSmooth(
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
wls::AbstractVector, wr::AbstractVector, d::Int = 2;
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down Expand Up @@ -232,8 +232,8 @@ A = RegularizationSmooth(
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd,
extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down
19 changes: 6 additions & 13 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end
using LinearAlgebra, RecipesBase
using PrettyTables
using ForwardDiff
using ReadOnlyArrays
import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated,
bracketstrictlymontonic

Expand Down Expand Up @@ -90,12 +89,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError)
print(io, INTEGRAL_NOT_INVERTIBLE_ERROR)
end

const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data"
struct MustCopyError <: Exception end
function Base.showerror(io::IO, e::MustCopyError)
print(io, MUST_COPY_ERROR)
end

export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline,
BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation,
Expand Down Expand Up @@ -128,12 +121,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T}
Aitp,
extrapolate)
new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}(
readonly_wrap(u),
readonly_wrap(û),
readonly_wrap(t),
readonly_wrap(t̂),
readonly_wrap(oftype(u.parent, wls)),
readonly_wrap(oftype(u.parent, wr)),
u,
û,
t,
t̂,
wls,
wr,
d,
λ,
alg,
Expand Down
23 changes: 14 additions & 9 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ end

# Derivative of a `LinearInterpolation` at `t`.
# Returns a tuple `(slope, idx)`, where `idx` is the index chosen by `get_idx`
# (starting from the guess `iguess`) and `slope` is what `get_parameters(A, idx)`
# returns for it.
function _derivative(A::LinearInterpolation, t::Number, iguess)
idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first)
# NOTE(review): the next line reads the cached `A.p.slope` directly and its value
# is discarded; it looks like the pre-change line kept by the diff view, replaced
# by the `get_parameters` lookup below — confirm.
A.p.slope[idx], idx
slope = get_parameters(A, idx)
slope, idx
end

# Derivative of a `QuadraticInterpolation` at `t`.
# Returns a tuple of the summed derivative contributions of the three nodes
# chosen by `_quad_interp_indices` and the first of those indices, `i₀`.
function _derivative(A::QuadraticInterpolation, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
# NOTE(review): the three assignments below read `A.p.l₀/l₁/l₂` directly and are
# overwritten further down; they look like the pre-change lines kept by the diff
# view, replaced by the `get_parameters` versions — confirm.
du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂])
du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂])
du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁])
# The three coefficients for the stencil starting at `i₀`.
l₀, l₁, l₂ = get_parameters(A, i₀)
# Each term is a coefficient times a factor linear in `t`.
du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂])
du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂])
du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁])
# `@.` broadcasts the sum, so the terms may be arrays as well as scalars.
return @views @. du₀ + du₁ + du₂, i₀
end

Expand Down Expand Up @@ -129,7 +131,7 @@ end
# QuadraticSpline Interpolation
# Derivative of a vector-valued-data `QuadraticSpline` at `t`.
# Returns a tuple `(value, idx)` with `value = A.z[idx - 1] + 2σ * (t - A.t[idx - 1])`,
# where `idx` comes from `get_idx` with a lower bound of 2 and `σ` is the
# parameter for index `idx - 1`.
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
# NOTE(review): `σ` is assigned twice. The first assignment reads `A.p.σ`
# directly and is overwritten; it looks like the pre-change line kept by the diff
# view, replaced by the `get_parameters` lookup — confirm.
σ = A.p.σ[idx - 1]
σ = get_parameters(A, idx - 1)
A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx
end

Expand All @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
Δt₁ = t - A.t[idx]
Δt₂ = A.t[idx + 1] - t
dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1])
dC = A.p.c₁[idx]
dD = -A.p.c₂[idx]
c₁, c₂ = get_parameters(A, idx)
dC = c₁
dD = -c₂
dI + dC + dD, idx
end

Expand Down Expand Up @@ -193,7 +196,8 @@ function _derivative(
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = A.du[idx]
out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]))
c₁, c₂ = get_parameters(A, idx)
out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂))
out, idx
end

Expand All @@ -204,7 +208,8 @@ function _derivative(
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = A.du[idx] + A.ddu[idx] * Δt₀
c₁, c₂, c₃ = get_parameters(A, idx)
out += Δt₀^2 *
(3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx])
(3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃)
out, idx
end
Loading

0 comments on commit 4366656

Please sign in to comment.