Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add derivatives for RegularizationSmooth and Curvefit #204

Merged
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
DataInterpolationsOptimExt = "Optim"
DataInterpolationsOptimExt = ["ForwardDiff", "Optim"]
DataInterpolationsRegularizationToolsExt = "RegularizationTools"
DataInterpolationsSymbolicsExt = "Symbolics"

[compat]
ChainRulesCore = "0.9.44, 0.10, 1"
ForwardDiff = "0.10"
LinearAlgebra = "1.6"
Optim = "0.19, 0.20, 0.21, 0.22, 1.0"
PrettyTables = "2"
Expand Down
16 changes: 8 additions & 8 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ else
end

function ChainRulesCore.rrule(::typeof(_interpolate),
A::Union{
LagrangeInterpolation,
AkimaInterpolation,
BSplineInterpolation,
BSplineApprox,
},
t::Number)
A::Union{
LagrangeInterpolation,
AkimaInterpolation,
BSplineInterpolation,
BSplineApprox,
},
t::Number)
deriv = derivative(A, t)
interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ)
return _interpolate(A, t), interpolate_pullback
end

function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation,
t::Number)
t::Number)
return _interpolate(A, t), derivative(A, t) * Δt
end

Expand Down
75 changes: 26 additions & 49 deletions ext/DataInterpolationsOptimExt.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,22 @@
module DataInterpolationsOptimExt

if isdefined(Base, :get_extension)
using DataInterpolations: AbstractInterpolation, munge_data
import DataInterpolations: Curvefit, _interpolate, get_show
using Reexport
else
using ..DataInterpolations: AbstractInterpolation, munge_data
import ..DataInterpolations: Curvefit, _interpolate, get_show
using Reexport
end
using DataInterpolations
import DataInterpolations: munge_data,
Curvefit, CurvefitCache, _interpolate, get_show, derivative, ExtrapolationError

isdefined(Base, :get_extension) ? (@reexport using Optim) : (@reexport using ..Optim)
isdefined(Base, :get_extension) ? (using Optim, ForwardDiff) :
(using ..Optim, ..ForwardDiff)

### Curvefit
struct CurvefitCache{
uType,
tType,
mType,
p0Type,
ubType,
lbType,
algType,
pminType,
FT,
T,
} <: AbstractInterpolation{FT, T}
u::uType
t::tType
m::mType # model type
p0::p0Type # intial params
ub::ubType # upper bound of params
lb::lbType # lower bound of params
alg::algType # alg to optimize cost function
pmin::pminType # optimized params
function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin) where {FT}
new{typeof(u), typeof(t), typeof(m),
typeof(p0), typeof(ub), typeof(lb),
typeof(alg), typeof(pmin), FT, eltype(u)}(u,
t,
m,
p0,
ub,
lb,
alg,
pmin)
end
end

function Curvefit(u, t, model, p0, alg, box = false, lb = nothing, ub = nothing)
function Curvefit(u,
t,
model,
p0,
alg,
box = false,
lb = nothing,
ub = nothing;
extrapolate = false)
u, t = munge_data(u, t)
errfun(t, u, p) = sum(abs2.(u .- model(t, p)))
if box == false
Expand All @@ -60,21 +29,29 @@ function Curvefit(u, t, model, p0, alg, box = false, lb = nothing, ub = nothing)
mfit = optimize(od, lb, ub, p0, Fminbox(alg))
end
pmin = Optim.minimizer(mfit)
CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin)
CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin, extrapolate)
end

# Curvefit
function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}},
t::Union{AbstractVector{<:Number}, Number})
t::Union{AbstractVector{<:Number}, Number})
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) &&
throw(ExtrapolationError())
A.m(t, A.pmin)
end

function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}},
t::Union{AbstractVector{<:Number}, Number},
i)
t::Union{AbstractVector{<:Number}, Number},
i)
_interpolate(A, t), i
end

function derivative(A::CurvefitCache{<:AbstractVector{<:Number}},
t::Union{AbstractVector{<:Number}, Number})
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
ForwardDiff.derivative(x -> A.m(x, A.pmin), t)
end

function get_show(interp::CurvefitCache)
return "Curvefit" *
" with $(length(interp.t)) points, using $(nameof(typeof(interp.alg)))\n"
Expand Down
Loading