From b15bc9bdfb5b6468259fcfe98270e3e811861f8f Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 22:38:26 +0530 Subject: [PATCH 1/7] test: add derivative and extrapolation tests for RegularizationSmooth and Curvefit --- test/derivative_tests.jl | 78 ++++++++++++++++++++++++++++--------- test/interpolation_tests.jl | 47 ++++++++++++---------- test/regularization.jl | 7 ++++ 3 files changed, 93 insertions(+), 39 deletions(-) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 7da2a247..8828b8a0 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -2,9 +2,13 @@ using DataInterpolations, Test using FiniteDifferences using DataInterpolations: derivative using Symbolics +using StableRNGs +using RegularizationTools +using Optim +using ForwardDiff -function test_derivatives(method, u, t, args...; name::String) - func = method(u, t, args...; extrapolate = true) +function test_derivatives(method, u, t; args = [], kwargs = [], name::String) + func = method(u, t, args...; kwargs..., extrapolate = true) trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) trange_exclude = filter(x -> !in(x, t), trange) @testset "$name" begin @@ -16,7 +20,7 @@ function test_derivatives(method, u, t, args...; name::String) end # Interpolation time points - for _t in t[2:end-1] + for _t in t[2:(end - 1)] fdiff = if func isa BSplineInterpolation || func isa BSplineApprox forward_fdm(5, 1; geom = true)(func, _t) else @@ -47,7 +51,7 @@ function test_derivatives(method, u, t, args...; name::String) end @testset "Linear Interpolation" begin - u = vcat(collect(1:5), 2*collect(6:10)) + u = vcat(collect(1:5), 2 * collect(6:10)) t = 1.0collect(1:10) test_derivatives(LinearInterpolation, u, t; name = "Linear Interpolation (Vector)") u = vcat(2.0collect(1:10)', 3.0collect(1:10)') @@ -63,8 +67,8 @@ end name = "Quadratic Interpolation (Vector)") test_derivatives(QuadraticInterpolation, u, - t, - :Backward; + t; + args = [:Backward], name = "Quadratic Interpolation (Vector), backward") u = [1.0 4.0 9.0 16.0; 1.0 4.0 9.0 16.0] test_derivatives(QuadraticInterpolation, @@ -139,28 +143,64 @@ end u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] test_derivatives(BSplineInterpolation, u, - t, - 2, - :Uniform, - :Uniform; + t; + args = [2, + :Uniform, + :Uniform], name = "BSpline Interpolation (Uniform, Uniform)") test_derivatives(BSplineInterpolation, u, - t, - 2, - :ArcLen, - :Average; + t; + args = [2, + :ArcLen, + :Average], name = "BSpline Interpolation (Arclen, Average)") test_derivatives(BSplineApprox, u, - t, - 3, - 4, - :Uniform, - :Uniform; + t; + args = [ + 3, + 4, + :Uniform, + :Uniform], name = "BSpline Approx (Uniform, Uniform)") end +@testset "RegularizationSmooth" begin + npts = 50 + xmin = 0.0 + xspan = 3 / 2 * π + x = collect(range(xmin, xmin + xspan, length = npts)) + rng = StableRNG(655) + x = x + xspan / npts * (rand(rng, npts) .- 0.5) + # select a subset randomly + idx = unique(rand(rng, collect(eachindex(x)), 20)) + t = x[unique(idx)] + npts = length(t) + ut = sin.(t) + stdev = 1e-1 * maximum(ut) + u = ut + stdev * randn(rng, npts) + # data must be ordered if t̂ is not provided + idx = sortperm(t) + tₒ = t[idx] + uₒ = u[idx] + A = RegularizationSmooth(uₒ, tₒ; alg = :fixed) + test_derivatives(RegularizationSmooth, + uₒ, + tₒ; + kwargs = [:alg => :fixed], + name = "RegularizationSmooth") +end + +@testset "Curvefit" begin + rng = StableRNG(12345) + model(x, p) = @. p[1] / (1 + exp(x - p[2])) + t = range(-10, stop = 10, length = 40) + u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t)) + p0 = [0.5, 0.5] + test_derivatives(Curvefit, u, t; args = [model, p0, LBFGS()], name = "Curvefit") +end + @testset "Symbolic derivatives" begin u = [0.0, 1.5, 0.0] t = [0.0, 0.5, 1.0] diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 5866daad..72b5657e 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -594,26 +594,33 @@ end end end -# Curvefit Interpolation -rng = StableRNG(12345) -model(x, p) = @. p[1] / (1 + exp(x - p[2])) -t = range(-10, stop = 10, length = 40) -u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t)) -p0 = [0.5, 0.5] - -A = Curvefit(u, t, model, p0, LBFGS()) - -ts = [-7.0, -2.0, 0.0, 2.5, 5.0] -vs = [ - 1.0013468217936277, - 0.9836755196317837, - 0.8833959853995836, - 0.3810348276782708, - 0.048062978598861855, -] -us = A.(ts) - -@test vs ≈ us +@testset "Curvefit" begin + # Curvefit Interpolation + rng = StableRNG(12345) + model(x, p) = @. p[1] / (1 + exp(x - p[2])) + t = range(-10, stop = 10, length = 40) + u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t)) + p0 = [0.5, 0.5] + + A = Curvefit(u, t, model, p0, LBFGS()) + + ts = [-7.0, -2.0, 0.0, 2.5, 5.0] + vs = [ + 1.0013468217936277, + 0.9836755196317837, + 0.8833959853995836, + 0.3810348276782708, + 0.048062978598861855, + ] + us = A.(ts) + @test vs ≈ us + + # Test extrapolation + A = Curvefit(u, t, model, p0, LBFGS(); extrapolate = true) + @test A(15.0) == model(15.0, A.pmin) + A = Curvefit(u, t, model, p0, LBFGS()) + @test_throws DataInterpolations.ExtrapolationError A(15.0) +end # missing values handling tests u = [1.0, 4.0, 9.0, 16.0, 25.0, missing, missing] diff --git a/test/regularization.jl b/test/regularization.jl index f07b4ad6..39d8e92c 100644 --- a/test/regularization.jl +++ b/test/regularization.jl @@ -177,3 +177,10 @@ end @test isapprox(A.û, ans, rtol = tolerance) @test isapprox(A.(t̂), ans, rtol = tolerance) end + +@testset "Extrapolation" begin + A = RegularizationSmooth(uₒ, tₒ; alg = :fixed, extrapolate = true) + @test A(10.0) == A.Aitp(10.0) + A = RegularizationSmooth(uₒ, tₒ; alg = :fixed) + @test_throws DataInterpolations.ExtrapolationError A(10.0) +end From dfba7a8a463c554d7f80dc6f32887f23e2a4b828 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 22:41:26 +0530 Subject: [PATCH 2/7] feat: add derivative for RegularizationSmooth and add extrapolate in the constructor --- ...ataInterpolationsRegularizationToolsExt.jl | 166 +++++++++++++----- 1 file changed, 118 insertions(+), 48 deletions(-) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index c7abda64..dae1f606 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -1,7 +1,8 @@ module DataInterpolationsRegularizationToolsExt using DataInterpolations -using DataInterpolations: munge_data, _interpolate, RegularizationSmooth, get_show +import DataInterpolations: munge_data, + _interpolate, RegularizationSmooth, get_show, derivative using LinearAlgebra isdefined(Base, :get_extension) ? (import RegularizationTools as RT) : @@ -30,17 +31,17 @@ const LA = LinearAlgebra """ # Arguments -- `u::Vector`: dependent data -- `t::Vector`: independent data +- `u::Vector`: dependent data. +- `t::Vector`: independent data. # Optional Arguments - `t̂::Vector`: t-values to use for the smooth curve (useful when data has missing values or is "scattered"); if not provided, then `t̂ = t`; must be monotonically - increasing + increasing. - `wls::{Vector,Symbol}`: weights to use with the least-squares fitting term; if set to `:midpoint`, then midpoint-rule integration weights are used for - _both_ `wls` and `wr` -- `wr::Vector`: weights to use with the roughness term + _both_ `wls` and `wr`. +- `wr::Vector`: weights to use with the roughness term. - `d::Int = 2`: derivative used to calculate roughness; e.g., when `d = 2`, the 2nd derivative (i.e. the curvature) of the data is used to calculate roughness. @@ -52,119 +53,171 @@ const LA = LinearAlgebra - `alg::Symbol = :gcv_svd`: algorithm for determining an optimal value for λ; the provided λ value is used directly if `alg = :fixed`; otherwise `alg = [:gcv_svd, :gcv_tr, :L_curve]` is passed to the - RegularizationTools solver + RegularizationTools solver. +- `extrapolate::Bool` = false: flag to allow extrapolating outside the range of the time points provided. ## Example Constructors Smoothing using all arguments ```julia -A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ=1.0, alg=:gcv_svd) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, - wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd) + wls::AbstractVector, wr::AbstractVector, d::Int = 2; + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp, extrapolate) end """ Direct smoothing, no `t̂` or weights ```julia -A = RegularizationSmooth(u, t, d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, d; λ=1.0, alg=:gcv_svd, extrapolate=false) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; - λ::Real = 1.0, - alg::Symbol = :gcv_svd) + λ::Real = 1.0, + alg::Symbol = :gcv_svd, extrapolate::Bool = false) u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) Wls½ = Array{Float64}(LA.I, N, N) Wr½ = Array{Float64}(LA.I, N - d, N - d) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, LA.diag(Wls½), LA.diag(Wr½), d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, + û, + t, + t̂, + LA.diag(Wls½), + LA.diag(Wr½), + d, + λ, + alg, + Aitp, + extrapolate) end """ `t̂` provided, no weights ```julia -A = RegularizationSmooth(u, t, t̂, d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, t̂, d; λ=1.0, alg=:gcv_svd, extrapolate=false) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, - d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd) + d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, LA.diag(Wls½), LA.diag(Wr½), d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, + û, + t, + t̂, + LA.diag(Wls½), + LA.diag(Wr½), + d, + λ, + alg, + Aitp, + extrapolate) end """ `t̂` and `wls` provided ```julia -A = RegularizationSmooth(u, t, t̂, wls, d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, t̂, wls, d; λ=1.0, alg=:gcv_svd, extrapolate=false) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, - wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd) + wls::AbstractVector, d::Int = 2; λ::Real = 1.0, + alg::Symbol = :gcv_svd, extrapolate::Bool = false) u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, wls, LA.diag(Wr½), d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, + û, + t, + t̂, + wls, + LA.diag(Wr½), + d, + λ, + alg, + Aitp, + extrapolate) end """ `wls` provided, no `t̂` ```julia -A = RegularizationSmooth(u, t, nothing, wls,d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, nothing, wls,d; λ=1.0, alg=:gcv_svd, extrapolate=false) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, - wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd) + wls::AbstractVector, d::Int = 2; λ::Real = 1.0, + alg::Symbol = :gcv_svd, extrapolate::Bool = false) u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = Array{Float64}(LA.I, N - d, N - d) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, wls, LA.diag(Wr½), d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, + û, + t, + t̂, + wls, + LA.diag(Wr½), + d, + λ, + alg, + Aitp, + extrapolate) end """ `wls` and `wr` provided, no `t̂` ```julia -A = RegularizationSmooth(u, t, nothing, wls, wr, d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, nothing, wls, wr, d; λ=1.0, alg=:gcv_svd, extrapolate=false) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, - wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd) + wls::AbstractVector, wr::AbstractVector, d::Int = 2; + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, wls, LA.diag(Wr½), d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, + û, + t, + t̂, + wls, + LA.diag(Wr½), + d, + λ, + alg, + Aitp, + extrapolate) end """ Keyword provided for `wls`, no `t̂` ```julia -A = RegularizationSmooth(u, t, nothing, :midpoint, d; λ=[1.0], alg=[:gcv_svd]) +A = RegularizationSmooth(u, t, nothing, :midpoint, d; λ=1.0, alg=:gcv_svd, extrapolate=false) ``` """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, - wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd) + wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, + extrapolate::Bool = false) u, t = munge_data(u, t) t̂ = t N = length(t) @@ -172,8 +225,18 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing wls, wr = _weighting_by_kw(t, d, wls) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) - û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg) - RegularizationSmooth{true}(u, û, t, t̂, LA.diag(Wls½), LA.diag(Wr½), d, λ, alg, Aitp) + û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate) + RegularizationSmooth{true}(u, + û, + t, + t̂, + LA.diag(Wls½), + LA.diag(Wr½), + d, + λ, + alg, + Aitp, + extrapolate) end # """ t̂ provided and keyword for wls _TBD_ """ # function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, @@ -181,7 +244,7 @@ end """ Solve for the smoothed dependent variables and create spline interpolator """ function _reg_smooth_solve(u::AbstractVector, t̂::AbstractVector, d::Int, M::AbstractMatrix, - Wls½::AbstractMatrix, Wr½::AbstractMatrix, λ::Real, alg::Symbol) + Wls½::AbstractMatrix, Wr½::AbstractMatrix, λ::Real, alg::Symbol, extrapolate::Bool) λ = float(λ) # `float` expected by RT D = _derivative_matrix(t̂, d) Ψ = RT.setupRegularizationProblem(Wls½ * M, Wr½ * D) @@ -198,7 +261,7 @@ function _reg_smooth_solve(u::AbstractVector, t̂::AbstractVector, d::Int, M::Ab û = result.x λ = result.λ end - Aitp = CubicSpline(û, t̂) + Aitp = CubicSpline(û, t̂; extrapolate) # It seems logical to use B-Spline of order d+1, but I am unsure if theory supports the # extra computational cost, JJS 12/25/21 #Aitp = BSplineInterpolation(û,t̂,d+1,:ArcLen,:Average) @@ -262,14 +325,21 @@ function _weighting_by_kw(t::AbstractVector, d::Int, wls::Symbol) end end -function DataInterpolations._interpolate(A::RegularizationSmooth{ - <:AbstractVector{<:Number}, - }, - t::Number) - DataInterpolations._interpolate(A.Aitp, t) +function _interpolate(A::RegularizationSmooth{ + <:AbstractVector{<:Number}, + }, + t::Number) + _interpolate(A.Aitp, t) end -function DataInterpolations.get_show(interp::RegularizationSmooth) +function derivative(A::RegularizationSmooth{ + <:AbstractVector{<:Number}, + }, + t::Number) + derivative(A.Aitp, t) +end + +function get_show(interp::RegularizationSmooth) return "RegularizationSmooth" * " with $(length(interp.t)) points, with regularization coefficient $(interp.λ)\n" end From c9219c466507b429413aec0c307ff05df6a04c68 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 22:43:14 +0530 Subject: [PATCH 3/7] feat: add derivative of Curvefit using ForwardDiff, add extrpolate in the constructor and remove CurvfitCache and move it to src/DataInterpolations.jl --- ext/DataInterpolationsOptimExt.jl | 75 +++++++++++-------------------- 1 file changed, 26 insertions(+), 49 deletions(-) diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5db14436..76ad8eb0 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -1,53 +1,22 @@ module DataInterpolationsOptimExt -if isdefined(Base, :get_extension) - using DataInterpolations: AbstractInterpolation, munge_data - import DataInterpolations: Curvefit, _interpolate, get_show - using Reexport -else - using ..DataInterpolations: AbstractInterpolation, munge_data - import ..DataInterpolations: Curvefit, _interpolate, get_show - using Reexport -end +using DataInterpolations +import DataInterpolations: munge_data, + Curvefit, CurvefitCache, _interpolate, get_show, derivative, ExtrapolationError -isdefined(Base, :get_extension) ? (@reexport using Optim) : (@reexport using ..Optim) +isdefined(Base, :get_extension) ? (using Optim, ForwardDiff) : +(using ..Optim, ..ForwardDiff) ### Curvefit -struct CurvefitCache{ - uType, - tType, - mType, - p0Type, - ubType, - lbType, - algType, - pminType, - FT, - T, -} <: AbstractInterpolation{FT, T} - u::uType - t::tType - m::mType # model type - p0::p0Type # intial params - ub::ubType # upper bound of params - lb::lbType # lower bound of params - alg::algType # alg to optimize cost function - pmin::pminType # optimized params - function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin) where {FT} - new{typeof(u), typeof(t), typeof(m), - typeof(p0), typeof(ub), typeof(lb), - typeof(alg), typeof(pmin), FT, eltype(u)}(u, - t, - m, - p0, - ub, - lb, - alg, - pmin) - end -end - -function Curvefit(u, t, model, p0, alg, box = false, lb = nothing, ub = nothing) +function Curvefit(u, + t, + model, + p0, + alg, + box = false, + lb = nothing, + ub = nothing; + extrapolate = false) u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false @@ -60,21 +29,29 @@ function Curvefit(u, t, model, p0, alg, box = false, lb = nothing, ub = nothing) mfit = optimize(od, lb, ub, p0, Fminbox(alg)) end pmin = Optim.minimizer(mfit) - CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin) + CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin, extrapolate) end # Curvefit function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}}, - t::Union{AbstractVector{<:Number}, Number}) + t::Union{AbstractVector{<:Number}, Number}) + ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && + throw(ExtrapolationError()) A.m(t, A.pmin) end function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}}, - t::Union{AbstractVector{<:Number}, Number}, - i) + t::Union{AbstractVector{<:Number}, Number}, + i) _interpolate(A, t), i end +function derivative(A::CurvefitCache{<:AbstractVector{<:Number}}, + t::Union{AbstractVector{<:Number}, Number}) + ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) + ForwardDiff.derivative(x -> A.m(x, A.pmin), t) +end + function get_show(interp::CurvefitCache) return "Curvefit" * " with $(length(interp.t)) points, using $(nameof(typeof(interp.alg)))\n" From 67ee41eb5f0661eda3f86cb76061c9ce37068f1b Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 22:45:41 +0530 Subject: [PATCH 4/7] refactor: add extrapolate in RegularizationSmooth, move CurvefitCache constructor, add ForwardDiff to requires for Curvefit --- src/DataInterpolations.jl | 69 ++++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 96f9ebab..1d59e0a1 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -61,7 +61,18 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT λ::T2 # regularization parameter alg::Symbol # how to determine λ: `:fixed`, `:gcv_svd`, `:gcv_tr`, `L_curve` Aitp::AbstractInterpolation{FT, T} - function RegularizationSmooth{FT}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp) where {FT} + extrapolate::Bool + function RegularizationSmooth{FT}(u, + û, + t, + t̂, + wls, + wr, + d, + λ, + alg, + Aitp, + extrapolate) where {FT} new{typeof(u), typeof(t), FT, eltype(u), typeof(λ)}(u, û, t, @@ -71,12 +82,57 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT d, λ, alg, - Aitp) + Aitp, + extrapolate) end end export RegularizationSmooth +# CurveFit +struct CurvefitCache{ + uType, + tType, + mType, + p0Type, + ubType, + lbType, + algType, + pminType, + FT, + T, +} <: AbstractInterpolation{FT, T} + u::uType + t::tType + m::mType # model type + p0::p0Type # intial params + ub::ubType # upper bound of params + lb::lbType # lower bound of params + alg::algType # alg to optimize cost function + pmin::pminType # optimized params + extrapolate::Bool + function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin, extrapolate) where {FT} + new{typeof(u), typeof(t), typeof(m), + typeof(p0), typeof(ub), typeof(lb), + typeof(alg), typeof(pmin), FT, eltype(u)}(u, + t, + m, + p0, + ub, + lb, + alg, + pmin, + extrapolate) + end +end + +# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt` +function Curvefit() + error("CurveFit requires loading Optim and ForwardDiff, e.g. `using Optim, ForwardDiff`") +end + +export Curvefit + @static if !isdefined(Base, :get_extension) using Requires end @@ -87,7 +143,9 @@ end include("../ext/DataInterpolationsChainRulesCoreExt.jl") end Requires.@require Optim="429524aa-4258-5aef-a3af-852621145aeb" begin - include("../ext/DataInterpolationsOptimExt.jl") + Requires.@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/DataInterpolationsOptimExt.jl") + end end Requires.@require RegularizationTools="29dad682-9a27-4bc3-9c72-016788665182" begin include("../ext/DataInterpolationsRegularizationToolsExt.jl") @@ -98,11 +156,6 @@ end end end -# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt` -Curvefit() = error("CurveFit requires loading Optim, e.g. `using Optim`") - -export Curvefit - # Deprecated April 2020 export ZeroSpline From 66ecc9c7804d40cf042e68f96a6b838a2dd6b1d6 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 22:46:15 +0530 Subject: [PATCH 5/7] chore: format files --- ext/DataInterpolationsChainRulesCoreExt.jl | 16 ++--- src/derivatives.jl | 4 +- src/integrals.jl | 8 +-- src/interpolation_caches.jl | 42 +++++------ src/interpolation_methods.jl | 4 +- src/interpolation_utils.jl | 6 +- src/plot_rec.jl | 84 +++++++++++----------- 7 files changed, 82 insertions(+), 82 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 9508f16c..59306c46 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -11,20 +11,20 @@ else end function ChainRulesCore.rrule(::typeof(_interpolate), - A::Union{ - LagrangeInterpolation, - AkimaInterpolation, - BSplineInterpolation, - BSplineApprox, - }, - t::Number) + A::Union{ + LagrangeInterpolation, + AkimaInterpolation, + BSplineInterpolation, + BSplineApprox, + }, + t::Number) deriv = derivative(A, t) interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) return _interpolate(A, t), interpolate_pullback end function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, - t::Number) + t::Number) return _interpolate(A, t), derivative(A, t) * Δt end diff --git a/src/derivatives.jl b/src/derivatives.jl index 518a9945..fb6a62c2 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -64,7 +64,7 @@ function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number) tmp += k end end - der += A.u[j]*tmp + der += A.u[j] * tmp end der end @@ -98,7 +98,7 @@ function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number) tmp += k end end - @. der += A.u[:, j]*tmp + @. der += A.u[:, j] * tmp end der end diff --git a/src/integrals.jl b/src/integrals.jl index 1c42a241..d8946656 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -24,8 +24,8 @@ end samples(A::LinearInterpolation{<:AbstractVector}) = (0, 1) function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, - idx::Number, - t::Number) + idx::Number, + t::Number) t1 = A.t[idx] t2 = A.t[idx + 1] u1 = A.u[idx] @@ -46,8 +46,8 @@ end samples(A::QuadraticInterpolation{<:AbstractVector}) = (0, 1) function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, - idx::Number, - t::Number) + idx::Number, + t::Number) A.mode == :Backward && idx > 1 && (idx -= 1) idx = min(length(A.t) - 2, idx) t1 = A.t[idx] diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 18e5cac6..f9955e9b 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -149,8 +149,8 @@ struct QuadraticSpline{uType, tType, tAType, dType, zType, FT, T} <: end function QuadraticSpline(u::uType, - t; - extrapolate = false) where {uType <: AbstractVector{<:Number}} + t; + extrapolate = false) where {uType <: AbstractVector{<:Number}} u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) @@ -199,8 +199,8 @@ struct CubicSpline{uType, tType, hType, zType, FT, T} <: AbstractInterpolation{F end function CubicSpline(u::uType, - t; - extrapolate = false) where {uType <: AbstractVector{<:Number}} + t; + extrapolate = false) where {uType <: AbstractVector{<:Number}} u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) @@ -250,14 +250,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, FT, T} <: knotVecType::Symbol extrapolate::Bool function BSplineInterpolation{FT}(u, - t, - d, - p, - k, - c, - pVecType, - knotVecType, - extrapolate) where {FT} + t, + d, + p, + k, + c, + pVecType, + knotVecType, + extrapolate) where {FT} new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u, t, d, @@ -348,15 +348,15 @@ struct BSplineApprox{uType, tType, pType, kType, cType, FT, T} <: knotVecType::Symbol extrapolate::Bool function BSplineApprox{FT}(u, - t, - d, - h, - p, - k, - c, - pVecType, - knotVecType, - extrapolate) where {FT} + t, + d, + h, + p, + k, + c, + pVecType, + knotVecType, + extrapolate) where {FT} new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u, t, d, diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 8b781a8e..328c9208 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -175,8 +175,8 @@ end # BSpline Curve Interpolation function _interpolate(A::BSplineInterpolation{<:AbstractVector{<:Number}}, - t::Number, - iguess) + t::Number, + iguess) t < A.t[1] && return A.u[1], 1 t > A.t[end] && return A.u[end], lastindex(t) # change t into param [0 1] diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 705749a8..3c157baf 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -82,9 +82,9 @@ this function would be the index returned by the previous call to `searchsorted` See [`sort!`](@ref) for an explanation of the keyword arguments `by`, `lt` and `rev`. """ function bracketstrictlymontonic(v::AbstractVector, - x, - guess::T, - o::Base.Order.Ordering)::NTuple{2, keytype(v)} where {T <: Integer} + x, + guess::T, + o::Base.Order.Ordering)::NTuple{2, keytype(v)} where {T <: Integer} bottom = firstindex(v) top = lastindex(v) if guess < bottom || guess > top diff --git a/src/plot_rec.jl b/src/plot_rec.jl index 307e56b2..ac8e3f45 100644 --- a/src/plot_rec.jl +++ b/src/plot_rec.jl @@ -32,11 +32,11 @@ end ######################################## @recipe function f(::Type{Val{:linear_interp}}, - x, - y, - z; - plotdensity = 10_000, - denseplot = true) + x, + y, + z; + plotdensity = 10_000, + denseplot = true) seriestype := :path label --> "Linear fit" @@ -54,11 +54,11 @@ end ######################################## @recipe function f(::Type{Val{:quadratic_interp}}, - x, - y, - z; - plotdensity = 10_000, - denseplot = true) + x, + y, + z; + plotdensity = 10_000, + denseplot = true) seriestype := :path label --> "Quadratic fit" @@ -76,11 +76,11 @@ end ######################################## @recipe function f(::Type{Val{:quadratic_spline}}, - x, - y, - z; - plotdensity = 10_000, - denseplot = true) + x, + y, + z; + plotdensity = 10_000, + denseplot = true) seriestype := :path label --> "Quadratic Spline" @@ -101,10 +101,10 @@ end ######################################## @recipe function f(::Type{Val{:lagrange_interp}}, - x, y, z; - n = length(x) - 1, - plotdensity = 10_000, - denseplot = true) + x, y, z; + n = length(x) - 1, + plotdensity = 10_000, + denseplot = true) seriestype := :path label --> "Lagrange Fit" @@ -126,11 +126,11 @@ end ######################################## @recipe function f(::Type{Val{:cubic_spline}}, - x, - y, - z; - plotdensity = 10_000, - denseplot = true) + x, + y, + z; + plotdensity = 10_000, + denseplot = true) seriestype := :path label --> "Cubic Spline" @@ -146,12 +146,12 @@ end end @recipe function f(::Type{Val{:bspline_interp}}, - x, y, z; - d = 5, - pVec = :ArcLen, - knotVec = :Average, - plotdensity = length(x) * 6, - denseplot = true) + x, y, z; + d = 5, + pVec = :ArcLen, + knotVec = :Average, + plotdensity = length(x) * 6, + denseplot = true) seriestype := :path label --> "B-Spline" @@ -176,13 +176,13 @@ end ######################################## @recipe function f(::Type{Val{:bspline_approx}}, - x, y, z; - d = 5, - h = length(x) - 1, - pVec = :ArcLen, - knotVec = :Average, - plotdensity = length(x) * 6, - denseplot = true) + x, y, z; + d = 5, + h = length(x) - 1, + pVec = :ArcLen, + knotVec = :Average, + plotdensity = length(x) * 6, + denseplot = true) seriestype := :path label --> "B-Spline" @@ -206,11 +206,11 @@ end ######################################## @recipe function f(::Type{Val{:akima}}, - x, - y, - z; - plotdensity = length(x) * 6, - denseplot = true) + x, + y, + z; + plotdensity = length(x) * 6, + denseplot = true) seriestype := :path label --> "Akima" From d4aa74d7e7078810877c1e3776bb96ccb07d0670 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 22:57:07 +0530 Subject: [PATCH 6/7] build: add ForwardDiff as weak deps --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index fea02a86..82ce0637 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -24,6 +25,7 @@ DataInterpolationsSymbolicsExt = "Symbolics" [compat] ChainRulesCore = "0.9.44, 0.10, 1" +ForwardDiff = "0.10" LinearAlgebra = "1.6" Optim = "0.19, 0.20, 0.21, 0.22, 1.0" PrettyTables = "2" From f6f86e04f301197b8b2db903c1d70a7427ea6e06 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <sathvikbhagavan@gmail.com> Date: Mon, 6 Nov 2023 23:20:05 +0530 Subject: [PATCH 7/7] build: add ForwardDiff dep in DataInterpolationsOptimExt extension --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 82ce0637..03946219 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] DataInterpolationsChainRulesCoreExt = "ChainRulesCore" -DataInterpolationsOptimExt = "Optim" +DataInterpolationsOptimExt = ["ForwardDiff", "Optim"] DataInterpolationsRegularizationToolsExt = "RegularizationTools" DataInterpolationsSymbolicsExt = "Symbolics"