Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor parameter caching + add Zygote tests #315

Merged
merged 11 commits into from
Jul 28, 2024
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

Expand All @@ -16,6 +15,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
Expand All @@ -33,14 +33,14 @@ LinearAlgebra = "1.10"
Optim = "1.6"
PrettyTables = "2"
QuadGK = "2.9.1"
ReadOnlyArrays = "0.2.0"
RecipesBase = "1.3"
Reexport = "1"
RegularizationTools = "0.6"
SafeTestsets = "0.1"
StableRNGs = "1"
Symbolics = "5.29"
Test = "1"
Zygote = "0.6"
julia = "1.10"

[extras]
Expand All @@ -55,6 +55,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"]
test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"]
17 changes: 1 addition & 16 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,7 @@ A2(300.0)

The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data.

The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object.

```@example interface
A3 = QuadraticInterpolation(u, t; safetycopy = false)

# Check for same memory
u === A3.u.parent
```

Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref).

Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl).

```@repl interface
A3.t[2] = 3.14
```
The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evaluations cheaper to compute. This is not compatible with modifying `u` and `t`. Note, however, that the default `cache_parameters = false` does not prevent allocation in every interpolation constructor call.

## Derivatives

Expand Down
78 changes: 74 additions & 4 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,97 @@
module DataInterpolationsChainRulesCoreExt

if isdefined(Base, :get_extension)
    using DataInterpolations: _interpolate, derivative, AbstractInterpolation,
                              LinearInterpolation, QuadraticInterpolation,
                              LagrangeInterpolation, AkimaInterpolation,
                              BSplineInterpolation, BSplineApprox, get_idx,
                              get_parameters, _quad_interp_indices
    using ChainRulesCore
else
    using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
                                LinearInterpolation, QuadraticInterpolation,
                                LagrangeInterpolation, AkimaInterpolation,
                                BSplineInterpolation, BSplineApprox, get_idx,
                                get_parameters, _quad_interp_indices
    using ..ChainRulesCore
end

# Reverse rule for the `LinearInterpolation` constructor. The primal output is
# the constructed interpolation object; the pullback forwards the `u` field of
# the incoming cotangent to the `u` argument and reports `NoTangent()` for the
# constructor itself and for every other argument.
function ChainRulesCore.rrule(
        ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters)
    interp = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters)

    # Slots, in order: constructor, u, t, I, p, extrapolate, cache_parameters.
    function LinearInterpolation_pullback(ΔA)
        ∅ = NoTangent()
        return (∅, ΔA.u, ∅, ∅, ∅, ∅, ∅)
    end

    return interp, LinearInterpolation_pullback
end

# Reverse rule for the `QuadraticInterpolation` constructor. The primal output
# is the constructed interpolation object; the pullback forwards the `u` field
# of the incoming cotangent to the `u` argument and reports `NoTangent()` for
# the constructor itself and for every other argument.
function ChainRulesCore.rrule(
        ::Type{QuadraticInterpolation}, u, t, I, p, mode, extrapolate, cache_parameters)
    A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters)

    # Named after the constructor it belongs to so stack traces and profiles
    # point at the right rule.
    function QuadraticInterpolation_pullback(ΔA)
        df = NoTangent()
        du = ΔA.u
        dt = NoTangent()
        dI = NoTangent()
        dp = NoTangent()
        dmode = NoTangent()
        dextrapolate = NoTangent()
        dcache_parameters = NoTangent()
        return df, du, dt, dI, dp, dmode, dextrapolate, dcache_parameters
    end

    return A, QuadraticInterpolation_pullback
end

# Cotangent with respect to `A.u` for a linear interpolation evaluated at `t`,
# given the output cotangent `Δ`. The result has the shape of `A.u` and is zero
# everywhere except at the two entries `idx` and `idx + 1`, which receive
# `Δ * (1 - w)` and `Δ * w`, where `w` is the relative position of `t` between
# `A.t[idx]` and `A.t[idx + 1]`.
function u_tangent(A::LinearInterpolation, t, Δ)
    ū = zero(A.u)
    i = get_idx(A.t, t, A.idx_prev[])
    t_left = A.t[i]
    w = (t - t_left) / (A.t[i + 1] - t_left)
    ū[i] = Δ * (one(eltype(ū)) - w)
    ū[i + 1] = Δ * w
    return ū
end

# Cotangent with respect to `A.u` for a quadratic interpolation evaluated at
# `t`, given the output cotangent `Δ`. The result has the shape of `A.u` and is
# zero except at the three indices returned by `_quad_interp_indices`, each of
# which receives `Δ` times the quadratic Lagrange basis weight of its node.
function u_tangent(A::QuadraticInterpolation, t, Δ)
    ū = zero(A.u)
    i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[])
    t₀, t₁, t₂ = A.t[i₀], A.t[i₁], A.t[i₂]

    # Offsets of the evaluation point from each node.
    s₀ = t - t₀
    s₁ = t - t₁
    s₂ = t - t₂

    # Node spacings; the middle weight is negated because (t₁ - t₂) = -Δt₁.
    Δt₀ = t₁ - t₀
    Δt₁ = t₂ - t₁
    Δt₂ = t₂ - t₀

    ū[i₀] = Δ * s₁ * s₂ / (Δt₀ * Δt₂)
    ū[i₁] = -Δ * s₀ * s₂ / (Δt₀ * Δt₁)
    ū[i₂] = Δ * s₀ * s₁ / (Δt₂ * Δt₁)
    return ū
end

# Fallback for interpolation types without a dedicated `u_tangent` method:
# no tangent with respect to `u` is reported.
u_tangent(A, t, Δ) = NoTangent()

# Reverse rule for evaluating an interpolation at a scalar `t`.
#
# The primal output is `_interpolate(A, t)`. The pullback returns, in order:
#   * `NoTangent()` for `_interpolate` itself,
#   * a structural `Tangent` for `A` whose `u` field is `u_tangent(A, t, Δ)`,
#   * `derivative(A, t) * Δ` for `t`.
function ChainRulesCore.rrule(::typeof(_interpolate),
        A::Union{
            LinearInterpolation,
            QuadraticInterpolation,
            LagrangeInterpolation,
            AkimaInterpolation,
            BSplineInterpolation,
            BSplineApprox
        },
        t::Number)
    # Evaluated once up front so the pullback only has to scale it by `Δ`.
    deriv = derivative(A, t)

    # Single definition of the pullback: a second method with the same
    # signature would only overwrite the first one.
    function interpolate_pullback(Δ)
        return (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ)
    end

    return _interpolate(A, t), interpolate_pullback
end

Expand Down
5 changes: 2 additions & 3 deletions ext/DataInterpolationsOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ function Curvefit(u,
box = false,
lb = nothing,
ub = nothing;
extrapolate = false,
safetycopy = false)
u, t = munge_data(u, t, safetycopy)
extrapolate = false)
u, t = munge_data(u, t)
errfun(t, u, p) = sum(abs2.(u .- model(t, p)))
if box == false
mfit = optimize(p -> errfun(t, u, p), p0, alg)
Expand Down
28 changes: 14 additions & 14 deletions ext/DataInterpolationsRegularizationToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd)
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
wls::AbstractVector, wr::AbstractVector, d::Int = 2;
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
M = _mapping_matrix(t̂, t)
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = LA.diagm(sqrt.(wr))
Expand All @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false)
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2;
λ::Real = 1.0,
alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down Expand Up @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd,
extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
extrapolate::Bool = false)
u, t = munge_data(u, t)
N, N̂ = length(t), length(t̂)
M = _mapping_matrix(t̂, t)
Wls½ = Array{Float64}(LA.I, N, N)
Expand All @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
N, N̂ = length(t), length(t̂)
M = _mapping_matrix(t̂, t)
Wls½ = LA.diagm(sqrt.(wls))
Expand Down Expand Up @@ -172,8 +172,8 @@ A = RegularizationSmooth(
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down Expand Up @@ -202,8 +202,8 @@ A = RegularizationSmooth(
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
wls::AbstractVector, wr::AbstractVector, d::Int = 2;
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down Expand Up @@ -232,8 +232,8 @@ A = RegularizationSmooth(
"""
function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd,
extrapolate::Bool = false, safetycopy::Bool = true)
u, t = munge_data(u, t, safetycopy)
extrapolate::Bool = false)
u, t = munge_data(u, t)
t̂ = t
N = length(t)
M = Array{Float64}(LA.I, N, N)
Expand Down
19 changes: 6 additions & 13 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end
using LinearAlgebra, RecipesBase
using PrettyTables
using ForwardDiff
using ReadOnlyArrays
import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated,
bracketstrictlymontonic

Expand Down Expand Up @@ -90,12 +89,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError)
print(io, INTEGRAL_NOT_INVERTIBLE_ERROR)
end

# Fixed message printed whenever a `MustCopyError` is displayed.
const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data"

# Field-less exception type; all of its information is in `MUST_COPY_ERROR`.
struct MustCopyError <: Exception end

# Print the fixed message; the exception instance itself carries no data.
Base.showerror(io::IO, ::MustCopyError) = print(io, MUST_COPY_ERROR)

export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline,
BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation,
Expand Down Expand Up @@ -128,12 +121,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T}
Aitp,
extrapolate)
new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}(
readonly_wrap(u),
readonly_wrap(û),
readonly_wrap(t),
readonly_wrap(t̂),
readonly_wrap(oftype(u.parent, wls)),
readonly_wrap(oftype(u.parent, wr)),
u,
,
t,
,
wls,
wr,
d,
λ,
alg,
Expand Down
23 changes: 14 additions & 9 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ end

# Derivative of a linear interpolation at `t`.
# Returns the tuple `(slope, idx)`: `idx` is the index located by `get_idx`
# starting from the guess `iguess`, and `slope` is whatever
# `get_parameters(A, idx)` yields for that index.
function _derivative(A::LinearInterpolation, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first)
    # Go through `get_parameters` only; do not index `A.p` directly here.
    slope = get_parameters(A, idx)
    slope, idx
end

# Derivative of a quadratic interpolation at `t`.
# Returns the tuple `(du, i₀)`: `du` is the sum of three terms, each the
# product of one parameter from `get_parameters(A, i₀)` and a factor that is
# linear in `t`, and `i₀` is the first index from `_quad_interp_indices`.
function _derivative(A::QuadraticInterpolation, t::Number, iguess)
    i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
    # Go through `get_parameters` only; do not index `A.p` directly here.
    l₀, l₁, l₂ = get_parameters(A, i₀)
    du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂])
    du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂])
    du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁])
    return @views @. du₀ + du₁ + du₂, i₀
end

Expand Down Expand Up @@ -129,7 +131,7 @@ end
# QuadraticSpline Interpolation
# Derivative of a quadratic spline at `t`.
# Returns the tuple `(value, idx)` with
# `value = A.z[idx - 1] + 2σ * (t - A.t[idx - 1])`, where `σ` is obtained from
# `get_parameters(A, idx - 1)` and `idx` is the index located by `get_idx`.
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    # Go through `get_parameters` only; do not index `A.p` directly here.
    σ = get_parameters(A, idx - 1)
    A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx
end

Expand All @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
Δt₁ = t - A.t[idx]
Δt₂ = A.t[idx + 1] - t
dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1])
dC = A.p.c₁[idx]
dD = -A.p.c₂[idx]
c₁, c₂ = get_parameters(A, idx)
dC = c₁
dD = -c₂
dI + dC + dD, idx
end

Expand Down Expand Up @@ -193,7 +196,8 @@ function _derivative(
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = A.du[idx]
out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]))
c₁, c₂ = get_parameters(A, idx)
out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂))
out, idx
end

Expand All @@ -204,7 +208,8 @@ function _derivative(
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = A.du[idx] + A.ddu[idx] * Δt₀
c₁, c₂, c₃ = get_parameters(A, idx)
out += Δt₀^2 *
(3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx])
(3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃)
out, idx
end
Loading
Loading