Skip to content

Commit

Permalink
Merge pull request #247 from sathvikbhagavan/sb/sym_second_order
Browse files Browse the repository at this point in the history
refactor!: remove indexing dispatches and add dispatch for higher order derivatives with Symbolics
  • Loading branch information
ChrisRackauckas authored May 6, 2024
2 parents b1f9851 + 40a6d96 commit 62749c6
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 97 deletions.
2 changes: 1 addition & 1 deletion ext/DataInterpolationsOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function Curvefit(u,
mfit = optimize(od, lb, ub, p0, Fminbox(alg))
end
pmin = Optim.minimizer(mfit)
CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin, extrapolate)
CurvefitCache(u, t, model, p0, ub, lb, alg, pmin, extrapolate)
end

# Curvefit
Expand Down
14 changes: 7 additions & 7 deletions ext/DataInterpolationsRegularizationToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Abstrac
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = LA.diagm(sqrt.(wr))
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp, extrapolate)
RegularizationSmooth(u, û, t, t̂, wls, wr, d, λ, alg, Aitp, extrapolate)
end
"""
Direct smoothing, no `t̂` or weights
Expand All @@ -94,7 +94,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2;
Wls½ = Array{Float64}(LA.I, N, N)
Wr½ = Array{Float64}(LA.I, N - d, N - d)
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u,
RegularizationSmooth(u,
û,
t,
t̂,
Expand All @@ -121,7 +121,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Abstrac
Wls½ = Array{Float64}(LA.I, N, N)
Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d)
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u,
RegularizationSmooth(u,
û,
t,
t̂,
Expand Down Expand Up @@ -149,7 +149,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Abstrac
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d)
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u,
RegularizationSmooth(u,
û,
t,
t̂,
Expand Down Expand Up @@ -179,7 +179,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = Array{Float64}(LA.I, N - d, N - d)
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u,
RegularizationSmooth(u,
û,
t,
t̂,
Expand Down Expand Up @@ -209,7 +209,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = LA.diagm(sqrt.(wr))
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u,
RegularizationSmooth(u,
û,
t,
t̂,
Expand Down Expand Up @@ -240,7 +240,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
Wls½ = LA.diagm(sqrt.(wls))
Wr½ = LA.diagm(sqrt.(wr))
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
RegularizationSmooth{true}(u,
RegularizationSmooth(u,
û,
t,
t̂,
Expand Down
4 changes: 4 additions & 0 deletions ext/DataInterpolationsSymbolicsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ function derivative(interp::AbstractInterpolation, t::Num, order = 1)
end
SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real

function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2})
Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1))
end

function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1})
Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1])))
end
Expand Down
33 changes: 9 additions & 24 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,7 @@ module DataInterpolations

### Interface Functionality

abstract type AbstractInterpolation{FT, T} <: AbstractVector{T} end

Base.size(A::AbstractInterpolation) = size(A.u)
Base.size(A::AbstractInterpolation{true}) = length(A.u) .+ size(A.t)
Base.getindex(A::AbstractInterpolation, i) = A.u[i]
function Base.getindex(A::AbstractInterpolation{true}, i)
i <= length(A.u) ? A.u[i] : A.t[i - length(A.u)]
end
Base.setindex!(A::AbstractInterpolation, x, i) = A.u[i] = x
function Base.setindex!(A::AbstractInterpolation{true}, x, i)
i <= length(A.u) ? (A.u[i] = x) : (A.t[i - length(A.u)] = x)
end
abstract type AbstractInterpolation{T} end

using LinearAlgebra, RecipesBase
using PrettyTables
Expand Down Expand Up @@ -67,7 +56,7 @@ export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,

# added for RegularizationSmooth, JJS 11/27/21
### Regularization data smoothing and interpolation
struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT, T}
struct RegularizationSmooth{uType, tType, T, T2} <: AbstractInterpolation{T}
u::uType
::uType
t::tType
Expand All @@ -77,9 +66,9 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
d::Int # derivative degree used to calculate the roughness
λ::T2 # regularization parameter
alg::Symbol # how to determine λ: `:fixed`, `:gcv_svd`, `:gcv_tr`, `L_curve`
Aitp::AbstractInterpolation{FT, T}
Aitp::AbstractInterpolation{T}
extrapolate::Bool
function RegularizationSmooth{FT}(u,
function RegularizationSmooth(u,
û,
t,
t̂,
Expand All @@ -89,8 +78,8 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
λ,
alg,
Aitp,
extrapolate) where {FT}
new{typeof(u), typeof(t), FT, eltype(u), typeof(λ)}(u,
extrapolate)
new{typeof(u), typeof(t), eltype(u), typeof(λ)}(u,
û,
t,
t̂,
Expand All @@ -116,9 +105,8 @@ struct CurvefitCache{
lbType,
algType,
pminType,
FT,
T
} <: AbstractInterpolation{FT, T}
} <: AbstractInterpolation{T}
u::uType
t::tType
m::mType # model type
Expand All @@ -128,10 +116,10 @@ struct CurvefitCache{
alg::algType # alg to optimize cost function
pmin::pminType # optimized params
extrapolate::Bool
function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin, extrapolate) where {FT}
function CurvefitCache(u, t, m, p0, ub, lb, alg, pmin, extrapolate)
new{typeof(u), typeof(t), typeof(m),
typeof(p0), typeof(ub), typeof(lb),
typeof(alg), typeof(pmin), FT, eltype(u)}(u,
typeof(alg), typeof(pmin), eltype(u)}(u,
t,
m,
p0,
Expand All @@ -150,7 +138,4 @@ end

export Curvefit

# Deprecated April 2020
export ZeroSpline

end # module
2 changes: 1 addition & 1 deletion src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number)
if A.t[idx2] == t2
idx2 -= 1
end
total = zero(eltype(A))
total = zero(eltype(A.u))
for idx in idx1:idx2
lt1 = idx == idx1 ? t1 : A.t[idx]
lt2 = idx == idx2 ? t2 : A.t[idx + 1]
Expand Down
Loading

0 comments on commit 62749c6

Please sign in to comment.