Skip to content

Commit

Permalink
refactor: add extrapolate in RegularizationSmooth, move CurvefitCache…
Browse files Browse the repository at this point in the history
… constructor, add ForwardDiff to requires for Curvefit
  • Loading branch information
sathvikbhagavan committed Nov 6, 2023
1 parent 4468cda commit e8df1bc
Showing 1 changed file with 61 additions and 8 deletions.
69 changes: 61 additions & 8 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,18 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
λ::T2 # regularization parameter
alg::Symbol # how to determine λ: `:fixed`, `:gcv_svd`, `:gcv_tr`, `L_curve`
Aitp::AbstractInterpolation{FT, T}
function RegularizationSmooth{FT}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp) where {FT}
extrapolate::Bool
function RegularizationSmooth{FT}(u,
û,
t,
t̂,
wls,
wr,
d,
λ,
alg,
Aitp,
extrapolate) where {FT}
new{typeof(u), typeof(t), FT, eltype(u), typeof(λ)}(u,
û,
t,
Expand All @@ -71,12 +82,57 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
d,
λ,
alg,
Aitp)
Aitp,
extrapolate)
end
end

export RegularizationSmooth

# CurveFit
struct CurvefitCache{
uType,
tType,
mType,
p0Type,
ubType,
lbType,
algType,
pminType,
FT,
T,
} <: AbstractInterpolation{FT, T}
u::uType
t::tType
m::mType # model type
p0::p0Type # intial params
ub::ubType # upper bound of params
lb::lbType # lower bound of params
alg::algType # alg to optimize cost function
pmin::pminType # optimized params
extrapolate::Bool
function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(m),
typeof(p0), typeof(ub), typeof(lb),
typeof(alg), typeof(pmin), FT, eltype(u)}(u,
t,
m,
p0,
ub,
lb,
alg,
pmin,
extrapolate)
end
end

# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt`
function Curvefit()
error("CurveFit requires loading Optim and ForwardDiff, e.g. `using Optim, ForwardDiff`")
end

export Curvefit

@static if !isdefined(Base, :get_extension)
using Requires
end
Expand All @@ -87,7 +143,9 @@ end
include("../ext/DataInterpolationsChainRulesCoreExt.jl")
end
Requires.@require Optim="429524aa-4258-5aef-a3af-852621145aeb" begin
include("../ext/DataInterpolationsOptimExt.jl")
Requires.@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/DataInterpolationsOptimExt.jl")
end
end
Requires.@require RegularizationTools="29dad682-9a27-4bc3-9c72-016788665182" begin
include("../ext/DataInterpolationsRegularizationToolsExt.jl")
Expand All @@ -98,11 +156,6 @@ end
end
end

# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt`
Curvefit() = error("CurveFit requires loading Optim, e.g. `using Optim`")

export Curvefit

# Deprecated April 2020
export ZeroSpline

Expand Down

0 comments on commit e8df1bc

Please sign in to comment.