From e8df1bc3a667ed941858d44cc6ced50418663939 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 6 Nov 2023 22:45:41 +0530 Subject: [PATCH] refactor: add extrapolate in RegularizationSmooth, move CurvefitCache constructor, add ForwardDiff to requires for Curvefit --- src/DataInterpolations.jl | 69 ++++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 96f9ebab..1d59e0a1 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -61,7 +61,18 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT λ::T2 # regularization parameter alg::Symbol # how to determine λ: `:fixed`, `:gcv_svd`, `:gcv_tr`, `L_curve` Aitp::AbstractInterpolation{FT, T} - function RegularizationSmooth{FT}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp) where {FT} + extrapolate::Bool + function RegularizationSmooth{FT}(u, + û, + t, + t̂, + wls, + wr, + d, + λ, + alg, + Aitp, + extrapolate) where {FT} new{typeof(u), typeof(t), FT, eltype(u), typeof(λ)}(u, û, t, @@ -71,12 +82,57 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT d, λ, alg, - Aitp) + Aitp, + extrapolate) end end export RegularizationSmooth +# CurveFit +struct CurvefitCache{ + uType, + tType, + mType, + p0Type, + ubType, + lbType, + algType, + pminType, + FT, + T, +} <: AbstractInterpolation{FT, T} + u::uType + t::tType + m::mType # model type + p0::p0Type # intial params + ub::ubType # upper bound of params + lb::lbType # lower bound of params + alg::algType # alg to optimize cost function + pmin::pminType # optimized params + extrapolate::Bool + function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin, extrapolate) where {FT} + new{typeof(u), typeof(t), typeof(m), + typeof(p0), typeof(ub), typeof(lb), + typeof(alg), typeof(pmin), FT, eltype(u)}(u, + t, + m, + p0, + ub, + lb, + alg, + pmin, + extrapolate) + end +end + +# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt` +function Curvefit() + error("CurveFit requires loading Optim and ForwardDiff, e.g. `using Optim, ForwardDiff`") +end + +export Curvefit + @static if !isdefined(Base, :get_extension) using Requires end @@ -87,7 +143,9 @@ end include("../ext/DataInterpolationsChainRulesCoreExt.jl") end Requires.@require Optim="429524aa-4258-5aef-a3af-852621145aeb" begin - include("../ext/DataInterpolationsOptimExt.jl") + Requires.@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/DataInterpolationsOptimExt.jl") + end end Requires.@require RegularizationTools="29dad682-9a27-4bc3-9c72-016788665182" begin include("../ext/DataInterpolationsRegularizationToolsExt.jl") @@ -98,11 +156,6 @@ end end end -# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt` -Curvefit() = error("CurveFit requires loading Optim, e.g. `using Optim`") - -export Curvefit - # Deprecated April 2020 export ZeroSpline