From b15bc9bdfb5b6468259fcfe98270e3e811861f8f Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 22:38:26 +0530
Subject: [PATCH 1/7] test: add derivative and extrapolation tests for
 RegularizationSmooth and Curvefit

---
 test/derivative_tests.jl    | 78 ++++++++++++++++++++++++++++---------
 test/interpolation_tests.jl | 47 ++++++++++++----------
 test/regularization.jl      |  7 ++++
 3 files changed, 93 insertions(+), 39 deletions(-)

diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl
index 7da2a247..8828b8a0 100644
--- a/test/derivative_tests.jl
+++ b/test/derivative_tests.jl
@@ -2,9 +2,13 @@ using DataInterpolations, Test
 using FiniteDifferences
 using DataInterpolations: derivative
 using Symbolics
+using StableRNGs
+using RegularizationTools
+using Optim
+using ForwardDiff
 
-function test_derivatives(method, u, t, args...; name::String)
-    func = method(u, t, args...; extrapolate = true)
+function test_derivatives(method, u, t; args = [], kwargs = [], name::String)
+    func = method(u, t, args...; kwargs..., extrapolate = true)
     trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1))
     trange_exclude = filter(x -> !in(x, t), trange)
     @testset "$name" begin
@@ -16,7 +20,7 @@ function test_derivatives(method, u, t, args...; name::String)
         end
 
         # Interpolation time points
-        for _t in t[2:end-1]
+        for _t in t[2:(end - 1)]
             fdiff = if func isa BSplineInterpolation || func isa BSplineApprox
                 forward_fdm(5, 1; geom = true)(func, _t)
             else
@@ -47,7 +51,7 @@ function test_derivatives(method, u, t, args...; name::String)
 end
 
 @testset "Linear Interpolation" begin
-    u = vcat(collect(1:5), 2*collect(6:10))
+    u = vcat(collect(1:5), 2 * collect(6:10))
     t = 1.0collect(1:10)
     test_derivatives(LinearInterpolation, u, t; name = "Linear Interpolation (Vector)")
     u = vcat(2.0collect(1:10)', 3.0collect(1:10)')
@@ -63,8 +67,8 @@ end
         name = "Quadratic Interpolation (Vector)")
     test_derivatives(QuadraticInterpolation,
         u,
-        t,
-        :Backward;
+        t;
+        args = [:Backward],
         name = "Quadratic Interpolation (Vector), backward")
     u = [1.0 4.0 9.0 16.0; 1.0 4.0 9.0 16.0]
     test_derivatives(QuadraticInterpolation,
@@ -139,28 +143,64 @@ end
     u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22]
     test_derivatives(BSplineInterpolation,
         u,
-        t,
-        2,
-        :Uniform,
-        :Uniform;
+        t;
+        args = [2,
+            :Uniform,
+            :Uniform],
         name = "BSpline Interpolation (Uniform, Uniform)")
     test_derivatives(BSplineInterpolation,
         u,
-        t,
-        2,
-        :ArcLen,
-        :Average;
+        t;
+        args = [2,
+            :ArcLen,
+            :Average],
         name = "BSpline Interpolation (Arclen, Average)")
     test_derivatives(BSplineApprox,
         u,
-        t,
-        3,
-        4,
-        :Uniform,
-        :Uniform;
+        t;
+        args = [
+            3,
+            4,
+            :Uniform,
+            :Uniform],
         name = "BSpline Approx (Uniform, Uniform)")
 end
 
+@testset "RegularizationSmooth" begin
+    npts = 50
+    xmin = 0.0
+    xspan = 3 / 2 * π
+    x = collect(range(xmin, xmin + xspan, length = npts))
+    rng = StableRNG(655)
+    x = x + xspan / npts * (rand(rng, npts) .- 0.5)
+    # select a subset randomly
+    idx = unique(rand(rng, collect(eachindex(x)), 20))
+    t = x[unique(idx)]
+    npts = length(t)
+    ut = sin.(t)
+    stdev = 1e-1 * maximum(ut)
+    u = ut + stdev * randn(rng, npts)
+    # data must be ordered if t̂ is not provided
+    idx = sortperm(t)
+    tₒ = t[idx]
+    uₒ = u[idx]
+    A = RegularizationSmooth(uₒ, tₒ; alg = :fixed)
+    test_derivatives(RegularizationSmooth,
+        uₒ,
+        tₒ;
+        kwargs = [:alg => :fixed],
+        name = "RegularizationSmooth")
+end
+
+@testset "Curvefit" begin
+    rng = StableRNG(12345)
+    model(x, p) = @. p[1] / (1 + exp(x - p[2]))
+    t = range(-10, stop = 10, length = 40)
+    u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t))
+    p0 = [0.5, 0.5]
+    test_derivatives(Curvefit, u, t; args = [model, p0, LBFGS()], name = "Curvefit")
+end
+
 @testset "Symbolic derivatives" begin
     u = [0.0, 1.5, 0.0]
     t = [0.0, 0.5, 1.0]
diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl
index 5866daad..72b5657e 100644
--- a/test/interpolation_tests.jl
+++ b/test/interpolation_tests.jl
@@ -594,26 +594,33 @@ end
     end
 end
 
-# Curvefit Interpolation
-rng = StableRNG(12345)
-model(x, p) = @. p[1] / (1 + exp(x - p[2]))
-t = range(-10, stop = 10, length = 40)
-u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t))
-p0 = [0.5, 0.5]
-
-A = Curvefit(u, t, model, p0, LBFGS())
-
-ts = [-7.0, -2.0, 0.0, 2.5, 5.0]
-vs = [
-    1.0013468217936277,
-    0.9836755196317837,
-    0.8833959853995836,
-    0.3810348276782708,
-    0.048062978598861855,
-]
-us = A.(ts)
-
-@test vs ≈ us
+@testset "Curvefit" begin
+    # Curvefit Interpolation
+    rng = StableRNG(12345)
+    model(x, p) = @. p[1] / (1 + exp(x - p[2]))
+    t = range(-10, stop = 10, length = 40)
+    u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t))
+    p0 = [0.5, 0.5]
+
+    A = Curvefit(u, t, model, p0, LBFGS())
+
+    ts = [-7.0, -2.0, 0.0, 2.5, 5.0]
+    vs = [
+        1.0013468217936277,
+        0.9836755196317837,
+        0.8833959853995836,
+        0.3810348276782708,
+        0.048062978598861855,
+    ]
+    us = A.(ts)
+    @test vs ≈ us
+
+    # Test extrapolation
+    A = Curvefit(u, t, model, p0, LBFGS(); extrapolate = true)
+    @test A(15.0) == model(15.0, A.pmin)
+    A = Curvefit(u, t, model, p0, LBFGS())
+    @test_throws DataInterpolations.ExtrapolationError A(15.0)
+end
 
 # missing values handling tests
 u = [1.0, 4.0, 9.0, 16.0, 25.0, missing, missing]
diff --git a/test/regularization.jl b/test/regularization.jl
index f07b4ad6..39d8e92c 100644
--- a/test/regularization.jl
+++ b/test/regularization.jl
@@ -177,3 +177,10 @@ end
     @test isapprox(A.û, ans, rtol = tolerance)
     @test isapprox(A.(t̂), ans, rtol = tolerance)
 end
+
+@testset "Extrapolation" begin
+    A = RegularizationSmooth(uₒ, tₒ; alg = :fixed, extrapolate = true)
+    @test A(10.0) == A.Aitp(10.0)
+    A = RegularizationSmooth(uₒ, tₒ; alg = :fixed)
+    @test_throws DataInterpolations.ExtrapolationError A(10.0)
+end

From dfba7a8a463c554d7f80dc6f32887f23e2a4b828 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 22:41:26 +0530
Subject: [PATCH 2/7] feat: add derivative for RegularizationSmooth and add
 extrapolate in the constructor

---
 ...ataInterpolationsRegularizationToolsExt.jl | 166 +++++++++++++-----
 1 file changed, 118 insertions(+), 48 deletions(-)

diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl
index c7abda64..dae1f606 100644
--- a/ext/DataInterpolationsRegularizationToolsExt.jl
+++ b/ext/DataInterpolationsRegularizationToolsExt.jl
@@ -1,7 +1,8 @@
 module DataInterpolationsRegularizationToolsExt
 
 using DataInterpolations
-using DataInterpolations: munge_data, _interpolate, RegularizationSmooth, get_show
+import DataInterpolations: munge_data,
+    _interpolate, RegularizationSmooth, get_show, derivative
 using LinearAlgebra
 
 isdefined(Base, :get_extension) ? (import RegularizationTools as RT) :
@@ -30,17 +31,17 @@ const LA = LinearAlgebra
 
 """
 # Arguments
-- `u::Vector`:  dependent data
-- `t::Vector`:  independent data
+- `u::Vector`:  dependent data.
+- `t::Vector`:  independent data.
 
 # Optional Arguments
 - `t̂::Vector`: t-values to use for the smooth curve (useful when data has missing values or
                is "scattered"); if not provided, then `t̂ = t`; must be monotonically
-               increasing
+               increasing.
 - `wls::{Vector,Symbol}`: weights to use with the least-squares fitting term; if set to
                           `:midpoint`, then midpoint-rule integration weights are used for
-                          _both_ `wls` and `wr`
-- `wr::Vector`: weights to use with the roughness term
+                          _both_ `wls` and `wr`.
+- `wr::Vector`: weights to use with the roughness term.
 - `d::Int = 2`: derivative used to calculate roughness; e.g., when `d = 2`, the 2nd
                 derivative (i.e. the curvature) of the data is used to calculate roughness.
 
@@ -52,119 +53,171 @@ const LA = LinearAlgebra
 - `alg::Symbol = :gcv_svd`: algorithm for determining an optimal value for λ; the provided λ
                             value is used directly if `alg = :fixed`; otherwise `alg =
                             [:gcv_svd, :gcv_tr, :L_curve]` is passed to the
-                            RegularizationTools solver
+                            RegularizationTools solver.
+-  `extrapolate::Bool` = false: flag to allow extrapolating outside the range of the time points provided.
 
 ## Example Constructors
 Smoothing using all arguments
 ```julia
-A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ=1.0, alg=:gcv_svd)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
-    wls::AbstractVector, wr::AbstractVector, d::Int = 2;
-    λ::Real = 1.0, alg::Symbol = :gcv_svd)
+        wls::AbstractVector, wr::AbstractVector, d::Int = 2;
+        λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
     u, t = munge_data(u, t)
     M = _mapping_matrix(t̂, t)
     Wls½ = LA.diagm(sqrt.(wls))
     Wr½ = LA.diagm(sqrt.(wr))
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp, extrapolate)
 end
 """
 Direct smoothing, no `t̂` or weights
 ```julia
-A = RegularizationSmooth(u, t, d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, d; λ=1.0, alg=:gcv_svd, extrapolate=false)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2;
-    λ::Real = 1.0,
-    alg::Symbol = :gcv_svd)
+        λ::Real = 1.0,
+        alg::Symbol = :gcv_svd, extrapolate::Bool = false)
     u, t = munge_data(u, t)
     t̂ = t
     N = length(t)
     M = Array{Float64}(LA.I, N, N)
     Wls½ = Array{Float64}(LA.I, N, N)
     Wr½ = Array{Float64}(LA.I, N - d, N - d)
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, LA.diag(Wls½), LA.diag(Wr½), d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u,
+        û,
+        t,
+        t̂,
+        LA.diag(Wls½),
+        LA.diag(Wr½),
+        d,
+        λ,
+        alg,
+        Aitp,
+        extrapolate)
 end
 """
 `t̂` provided, no weights
 ```julia
-A = RegularizationSmooth(u, t, t̂, d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, t̂, d; λ=1.0, alg=:gcv_svd, extrapolate=false)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
-    d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd)
+        d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
     u, t = munge_data(u, t)
     N, N̂ = length(t), length(t̂)
     M = _mapping_matrix(t̂, t)
     Wls½ = Array{Float64}(LA.I, N, N)
     Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d)
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, LA.diag(Wls½), LA.diag(Wr½), d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u,
+        û,
+        t,
+        t̂,
+        LA.diag(Wls½),
+        LA.diag(Wr½),
+        d,
+        λ,
+        alg,
+        Aitp,
+        extrapolate)
 end
 """
 `t̂` and `wls` provided
 ```julia
-A = RegularizationSmooth(u, t, t̂, wls, d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, t̂, wls, d; λ=1.0, alg=:gcv_svd, extrapolate=false)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
-    wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
-    alg::Symbol = :gcv_svd)
+        wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
+        alg::Symbol = :gcv_svd, extrapolate::Bool = false)
     u, t = munge_data(u, t)
     N, N̂ = length(t), length(t̂)
     M = _mapping_matrix(t̂, t)
     Wls½ = LA.diagm(sqrt.(wls))
     Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d)
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, wls, LA.diag(Wr½), d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u,
+        û,
+        t,
+        t̂,
+        wls,
+        LA.diag(Wr½),
+        d,
+        λ,
+        alg,
+        Aitp,
+        extrapolate)
 end
 """
 `wls` provided, no `t̂`
 ```julia
-A = RegularizationSmooth(u, t, nothing, wls,d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, nothing, wls,d; λ=1.0, alg=:gcv_svd, extrapolate=false)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
-    wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
-    alg::Symbol = :gcv_svd)
+        wls::AbstractVector, d::Int = 2; λ::Real = 1.0,
+        alg::Symbol = :gcv_svd, extrapolate::Bool = false)
     u, t = munge_data(u, t)
     t̂ = t
     N = length(t)
     M = Array{Float64}(LA.I, N, N)
     Wls½ = LA.diagm(sqrt.(wls))
     Wr½ = Array{Float64}(LA.I, N - d, N - d)
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, wls, LA.diag(Wr½), d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u,
+        û,
+        t,
+        t̂,
+        wls,
+        LA.diag(Wr½),
+        d,
+        λ,
+        alg,
+        Aitp,
+        extrapolate)
 end
 """
 `wls` and `wr` provided, no `t̂`
 ```julia
-A = RegularizationSmooth(u, t, nothing, wls, wr, d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, nothing, wls, wr, d; λ=1.0, alg=:gcv_svd, extrapolate=false)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
-    wls::AbstractVector, wr::AbstractVector, d::Int = 2;
-    λ::Real = 1.0, alg::Symbol = :gcv_svd)
+        wls::AbstractVector, wr::AbstractVector, d::Int = 2;
+        λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false)
     u, t = munge_data(u, t)
     t̂ = t
     N = length(t)
     M = Array{Float64}(LA.I, N, N)
     Wls½ = LA.diagm(sqrt.(wls))
     Wr½ = LA.diagm(sqrt.(wr))
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, wls, LA.diag(Wr½), d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u,
+        û,
+        t,
+        t̂,
+        wls,
+        LA.diag(Wr½),
+        d,
+        λ,
+        alg,
+        Aitp,
+        extrapolate)
 end
 """
 Keyword provided for `wls`, no `t̂`
 ```julia
-A = RegularizationSmooth(u, t, nothing, :midpoint, d; λ=[1.0], alg=[:gcv_svd])
+A = RegularizationSmooth(u, t, nothing, :midpoint, d; λ=1.0, alg=:gcv_svd, extrapolate=false)
 ```
 """
 function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing,
-    wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd)
+        wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd,
+        extrapolate::Bool = false)
     u, t = munge_data(u, t)
     t̂ = t
     N = length(t)
@@ -172,8 +225,18 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
     wls, wr = _weighting_by_kw(t, d, wls)
     Wls½ = LA.diagm(sqrt.(wls))
     Wr½ = LA.diagm(sqrt.(wr))
-    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg)
-    RegularizationSmooth{true}(u, û, t, t̂, LA.diag(Wls½), LA.diag(Wr½), d, λ, alg, Aitp)
+    û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
+    RegularizationSmooth{true}(u,
+        û,
+        t,
+        t̂,
+        LA.diag(Wls½),
+        LA.diag(Wr½),
+        d,
+        λ,
+        alg,
+        Aitp,
+        extrapolate)
 end
 # """ t̂ provided and keyword for wls  _TBD_ """
 # function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector,
@@ -181,7 +244,7 @@ end
 
 """ Solve for the smoothed dependent variables and create spline interpolator """
 function _reg_smooth_solve(u::AbstractVector, t̂::AbstractVector, d::Int, M::AbstractMatrix,
-    Wls½::AbstractMatrix, Wr½::AbstractMatrix, λ::Real, alg::Symbol)
+        Wls½::AbstractMatrix, Wr½::AbstractMatrix, λ::Real, alg::Symbol, extrapolate::Bool)
     λ = float(λ) # `float` expected by RT
     D = _derivative_matrix(t̂, d)
     Ψ = RT.setupRegularizationProblem(Wls½ * M, Wr½ * D)
@@ -198,7 +261,7 @@ function _reg_smooth_solve(u::AbstractVector, t̂::AbstractVector, d::Int, M::Ab
         û = result.x
         λ = result.λ
     end
-    Aitp = CubicSpline(û, t̂)
+    Aitp = CubicSpline(û, t̂; extrapolate)
     # It seems logical to use B-Spline of order d+1, but I am unsure if theory supports the
     # extra computational cost, JJS 12/25/21
     #Aitp = BSplineInterpolation(û,t̂,d+1,:ArcLen,:Average)
@@ -262,14 +325,21 @@ function _weighting_by_kw(t::AbstractVector, d::Int, wls::Symbol)
     end
 end
 
-function DataInterpolations._interpolate(A::RegularizationSmooth{
-        <:AbstractVector{<:Number},
-    },
-    t::Number)
-    DataInterpolations._interpolate(A.Aitp, t)
+function _interpolate(A::RegularizationSmooth{
+            <:AbstractVector{<:Number},
+        },
+        t::Number)
+    _interpolate(A.Aitp, t)
 end
 
-function DataInterpolations.get_show(interp::RegularizationSmooth)
+function derivative(A::RegularizationSmooth{
+            <:AbstractVector{<:Number},
+        },
+        t::Number)
+    derivative(A.Aitp, t)
+end
+
+function get_show(interp::RegularizationSmooth)
     return "RegularizationSmooth" *
            " with $(length(interp.t)) points, with regularization coefficient $(interp.λ)\n"
 end

From c9219c466507b429413aec0c307ff05df6a04c68 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 22:43:14 +0530
Subject: [PATCH 3/7] feat: add derivative of Curvefit using ForwardDiff, add
 extrpolate in the constructor and remove CurvfitCache and move it to
 src/DataInterpolations.jl

---
 ext/DataInterpolationsOptimExt.jl | 75 +++++++++++--------------------
 1 file changed, 26 insertions(+), 49 deletions(-)

diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl
index 5db14436..76ad8eb0 100644
--- a/ext/DataInterpolationsOptimExt.jl
+++ b/ext/DataInterpolationsOptimExt.jl
@@ -1,53 +1,22 @@
 module DataInterpolationsOptimExt
 
-if isdefined(Base, :get_extension)
-    using DataInterpolations: AbstractInterpolation, munge_data
-    import DataInterpolations: Curvefit, _interpolate, get_show
-    using Reexport
-else
-    using ..DataInterpolations: AbstractInterpolation, munge_data
-    import ..DataInterpolations: Curvefit, _interpolate, get_show
-    using Reexport
-end
+using DataInterpolations
+import DataInterpolations: munge_data,
+    Curvefit, CurvefitCache, _interpolate, get_show, derivative, ExtrapolationError
 
-isdefined(Base, :get_extension) ? (@reexport using Optim) : (@reexport using ..Optim)
+isdefined(Base, :get_extension) ? (using Optim, ForwardDiff) :
+(using ..Optim, ..ForwardDiff)
 
 ### Curvefit
-struct CurvefitCache{
-    uType,
-    tType,
-    mType,
-    p0Type,
-    ubType,
-    lbType,
-    algType,
-    pminType,
-    FT,
-    T,
-} <: AbstractInterpolation{FT, T}
-    u::uType
-    t::tType
-    m::mType        # model type
-    p0::p0Type      # intial params
-    ub::ubType      # upper bound of params
-    lb::lbType      # lower bound of params
-    alg::algType    # alg to optimize cost function
-    pmin::pminType  # optimized params
-    function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin) where {FT}
-        new{typeof(u), typeof(t), typeof(m),
-            typeof(p0), typeof(ub), typeof(lb),
-            typeof(alg), typeof(pmin), FT, eltype(u)}(u,
-            t,
-            m,
-            p0,
-            ub,
-            lb,
-            alg,
-            pmin)
-    end
-end
-
-function Curvefit(u, t, model, p0, alg, box = false, lb = nothing, ub = nothing)
+function Curvefit(u,
+        t,
+        model,
+        p0,
+        alg,
+        box = false,
+        lb = nothing,
+        ub = nothing;
+        extrapolate = false)
     u, t = munge_data(u, t)
     errfun(t, u, p) = sum(abs2.(u .- model(t, p)))
     if box == false
@@ -60,21 +29,29 @@ function Curvefit(u, t, model, p0, alg, box = false, lb = nothing, ub = nothing)
         mfit = optimize(od, lb, ub, p0, Fminbox(alg))
     end
     pmin = Optim.minimizer(mfit)
-    CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin)
+    CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin, extrapolate)
 end
 
 # Curvefit
 function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}},
-    t::Union{AbstractVector{<:Number}, Number})
+        t::Union{AbstractVector{<:Number}, Number})
+    ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) &&
+        throw(ExtrapolationError())
     A.m(t, A.pmin)
 end
 
 function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}},
-    t::Union{AbstractVector{<:Number}, Number},
-    i)
+        t::Union{AbstractVector{<:Number}, Number},
+        i)
     _interpolate(A, t), i
 end
 
+function derivative(A::CurvefitCache{<:AbstractVector{<:Number}},
+        t::Union{AbstractVector{<:Number}, Number})
+    ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
+    ForwardDiff.derivative(x -> A.m(x, A.pmin), t)
+end
+
 function get_show(interp::CurvefitCache)
     return "Curvefit" *
            " with $(length(interp.t)) points, using $(nameof(typeof(interp.alg)))\n"

From 67ee41eb5f0661eda3f86cb76061c9ce37068f1b Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 22:45:41 +0530
Subject: [PATCH 4/7] refactor: add extrapolate in RegularizationSmooth, move
 CurvefitCache constructor, add ForwardDiff to requires for Curvefit

---
 src/DataInterpolations.jl | 69 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 61 insertions(+), 8 deletions(-)

diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl
index 96f9ebab..1d59e0a1 100644
--- a/src/DataInterpolations.jl
+++ b/src/DataInterpolations.jl
@@ -61,7 +61,18 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
     λ::T2        # regularization parameter
     alg::Symbol  # how to determine λ: `:fixed`, `:gcv_svd`, `:gcv_tr`, `L_curve`
     Aitp::AbstractInterpolation{FT, T}
-    function RegularizationSmooth{FT}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp) where {FT}
+    extrapolate::Bool
+    function RegularizationSmooth{FT}(u,
+            û,
+            t,
+            t̂,
+            wls,
+            wr,
+            d,
+            λ,
+            alg,
+            Aitp,
+            extrapolate) where {FT}
         new{typeof(u), typeof(t), FT, eltype(u), typeof(λ)}(u,
             û,
             t,
@@ -71,12 +82,57 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
             d,
             λ,
             alg,
-            Aitp)
+            Aitp,
+            extrapolate)
     end
 end
 
 export RegularizationSmooth
 
+# CurveFit
+struct CurvefitCache{
+    uType,
+    tType,
+    mType,
+    p0Type,
+    ubType,
+    lbType,
+    algType,
+    pminType,
+    FT,
+    T,
+} <: AbstractInterpolation{FT, T}
+    u::uType
+    t::tType
+    m::mType        # model type
+    p0::p0Type      # intial params
+    ub::ubType      # upper bound of params
+    lb::lbType      # lower bound of params
+    alg::algType    # alg to optimize cost function
+    pmin::pminType  # optimized params
+    extrapolate::Bool
+    function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin, extrapolate) where {FT}
+        new{typeof(u), typeof(t), typeof(m),
+            typeof(p0), typeof(ub), typeof(lb),
+            typeof(alg), typeof(pmin), FT, eltype(u)}(u,
+            t,
+            m,
+            p0,
+            ub,
+            lb,
+            alg,
+            pmin,
+            extrapolate)
+    end
+end
+
+# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt`
+function Curvefit()
+    error("CurveFit requires loading Optim and ForwardDiff, e.g. `using Optim, ForwardDiff`")
+end
+
+export Curvefit
+
 @static if !isdefined(Base, :get_extension)
     using Requires
 end
@@ -87,7 +143,9 @@ end
             include("../ext/DataInterpolationsChainRulesCoreExt.jl")
         end
         Requires.@require Optim="429524aa-4258-5aef-a3af-852621145aeb" begin
-            include("../ext/DataInterpolationsOptimExt.jl")
+            Requires.@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
+                include("../ext/DataInterpolationsOptimExt.jl")
+            end
         end
         Requires.@require RegularizationTools="29dad682-9a27-4bc3-9c72-016788665182" begin
             include("../ext/DataInterpolationsRegularizationToolsExt.jl")
@@ -98,11 +156,6 @@ end
     end
 end
 
-# Define an empty function, so that it can be extended via `DataInterpolationsOptimExt`
-Curvefit() = error("CurveFit requires loading Optim, e.g. `using Optim`")
-
-export Curvefit
-
 # Deprecated April 2020
 export ZeroSpline
 

From 66ecc9c7804d40cf042e68f96a6b838a2dd6b1d6 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 22:46:15 +0530
Subject: [PATCH 5/7] chore: format files

---
 ext/DataInterpolationsChainRulesCoreExt.jl | 16 ++---
 src/derivatives.jl                         |  4 +-
 src/integrals.jl                           |  8 +--
 src/interpolation_caches.jl                | 42 +++++------
 src/interpolation_methods.jl               |  4 +-
 src/interpolation_utils.jl                 |  6 +-
 src/plot_rec.jl                            | 84 +++++++++++-----------
 7 files changed, 82 insertions(+), 82 deletions(-)

diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl
index 9508f16c..59306c46 100644
--- a/ext/DataInterpolationsChainRulesCoreExt.jl
+++ b/ext/DataInterpolationsChainRulesCoreExt.jl
@@ -11,20 +11,20 @@ else
 end
 
 function ChainRulesCore.rrule(::typeof(_interpolate),
-    A::Union{
-        LagrangeInterpolation,
-        AkimaInterpolation,
-        BSplineInterpolation,
-        BSplineApprox,
-    },
-    t::Number)
+        A::Union{
+            LagrangeInterpolation,
+            AkimaInterpolation,
+            BSplineInterpolation,
+            BSplineApprox,
+        },
+        t::Number)
     deriv = derivative(A, t)
     interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ)
     return _interpolate(A, t), interpolate_pullback
 end
 
 function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation,
-    t::Number)
+        t::Number)
     return _interpolate(A, t), derivative(A, t) * Δt
 end
 
diff --git a/src/derivatives.jl b/src/derivatives.jl
index 518a9945..fb6a62c2 100644
--- a/src/derivatives.jl
+++ b/src/derivatives.jl
@@ -64,7 +64,7 @@ function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
                 tmp += k
             end
         end
-        der += A.u[j]*tmp
+        der += A.u[j] * tmp
     end
     der
 end
@@ -98,7 +98,7 @@ function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
                 tmp += k
             end
         end
-        @. der += A.u[:, j]*tmp
+        @. der += A.u[:, j] * tmp
     end
     der
 end
diff --git a/src/integrals.jl b/src/integrals.jl
index 1c42a241..d8946656 100644
--- a/src/integrals.jl
+++ b/src/integrals.jl
@@ -24,8 +24,8 @@ end
 
 samples(A::LinearInterpolation{<:AbstractVector}) = (0, 1)
 function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}},
-    idx::Number,
-    t::Number)
+        idx::Number,
+        t::Number)
     t1 = A.t[idx]
     t2 = A.t[idx + 1]
     u1 = A.u[idx]
@@ -46,8 +46,8 @@ end
 
 samples(A::QuadraticInterpolation{<:AbstractVector}) = (0, 1)
 function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}},
-    idx::Number,
-    t::Number)
+        idx::Number,
+        t::Number)
     A.mode == :Backward && idx > 1 && (idx -= 1)
     idx = min(length(A.t) - 2, idx)
     t1 = A.t[idx]
diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl
index 18e5cac6..f9955e9b 100644
--- a/src/interpolation_caches.jl
+++ b/src/interpolation_caches.jl
@@ -149,8 +149,8 @@ struct QuadraticSpline{uType, tType, tAType, dType, zType, FT, T} <:
 end
 
 function QuadraticSpline(u::uType,
-    t;
-    extrapolate = false) where {uType <: AbstractVector{<:Number}}
+        t;
+        extrapolate = false) where {uType <: AbstractVector{<:Number}}
     u, t = munge_data(u, t)
     s = length(t)
     dl = ones(eltype(t), s - 1)
@@ -199,8 +199,8 @@ struct CubicSpline{uType, tType, hType, zType, FT, T} <: AbstractInterpolation{F
 end
 
 function CubicSpline(u::uType,
-    t;
-    extrapolate = false) where {uType <: AbstractVector{<:Number}}
+        t;
+        extrapolate = false) where {uType <: AbstractVector{<:Number}}
     u, t = munge_data(u, t)
     n = length(t) - 1
     h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
@@ -250,14 +250,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, FT, T} <:
     knotVecType::Symbol
     extrapolate::Bool
     function BSplineInterpolation{FT}(u,
-        t,
-        d,
-        p,
-        k,
-        c,
-        pVecType,
-        knotVecType,
-        extrapolate) where {FT}
+            t,
+            d,
+            p,
+            k,
+            c,
+            pVecType,
+            knotVecType,
+            extrapolate) where {FT}
         new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u,
             t,
             d,
@@ -348,15 +348,15 @@ struct BSplineApprox{uType, tType, pType, kType, cType, FT, T} <:
     knotVecType::Symbol
     extrapolate::Bool
     function BSplineApprox{FT}(u,
-        t,
-        d,
-        h,
-        p,
-        k,
-        c,
-        pVecType,
-        knotVecType,
-        extrapolate) where {FT}
+            t,
+            d,
+            h,
+            p,
+            k,
+            c,
+            pVecType,
+            knotVecType,
+            extrapolate) where {FT}
         new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u,
             t,
             d,
diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl
index 8b781a8e..328c9208 100644
--- a/src/interpolation_methods.jl
+++ b/src/interpolation_methods.jl
@@ -175,8 +175,8 @@ end
 
 # BSpline Curve Interpolation
 function _interpolate(A::BSplineInterpolation{<:AbstractVector{<:Number}},
-    t::Number,
-    iguess)
+        t::Number,
+        iguess)
     t < A.t[1] && return A.u[1], 1
     t > A.t[end] && return A.u[end], lastindex(t)
     # change t into param [0 1]
diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl
index 705749a8..3c157baf 100644
--- a/src/interpolation_utils.jl
+++ b/src/interpolation_utils.jl
@@ -82,9 +82,9 @@ this function would be the index returned by the previous call to `searchsorted`
 See [`sort!`](@ref) for an explanation of the keyword arguments `by`, `lt` and `rev`.
 """
 function bracketstrictlymontonic(v::AbstractVector,
-    x,
-    guess::T,
-    o::Base.Order.Ordering)::NTuple{2, keytype(v)} where {T <: Integer}
+        x,
+        guess::T,
+        o::Base.Order.Ordering)::NTuple{2, keytype(v)} where {T <: Integer}
     bottom = firstindex(v)
     top = lastindex(v)
     if guess < bottom || guess > top
diff --git a/src/plot_rec.jl b/src/plot_rec.jl
index 307e56b2..ac8e3f45 100644
--- a/src/plot_rec.jl
+++ b/src/plot_rec.jl
@@ -32,11 +32,11 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:linear_interp}},
-    x,
-    y,
-    z;
-    plotdensity = 10_000,
-    denseplot = true)
+        x,
+        y,
+        z;
+        plotdensity = 10_000,
+        denseplot = true)
     seriestype := :path
 
     label --> "Linear fit"
@@ -54,11 +54,11 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:quadratic_interp}},
-    x,
-    y,
-    z;
-    plotdensity = 10_000,
-    denseplot = true)
+        x,
+        y,
+        z;
+        plotdensity = 10_000,
+        denseplot = true)
     seriestype := :path
 
     label --> "Quadratic fit"
@@ -76,11 +76,11 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:quadratic_spline}},
-    x,
-    y,
-    z;
-    plotdensity = 10_000,
-    denseplot = true)
+        x,
+        y,
+        z;
+        plotdensity = 10_000,
+        denseplot = true)
     seriestype := :path
 
     label --> "Quadratic Spline"
@@ -101,10 +101,10 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:lagrange_interp}},
-    x, y, z;
-    n = length(x) - 1,
-    plotdensity = 10_000,
-    denseplot = true)
+        x, y, z;
+        n = length(x) - 1,
+        plotdensity = 10_000,
+        denseplot = true)
     seriestype := :path
 
     label --> "Lagrange Fit"
@@ -126,11 +126,11 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:cubic_spline}},
-    x,
-    y,
-    z;
-    plotdensity = 10_000,
-    denseplot = true)
+        x,
+        y,
+        z;
+        plotdensity = 10_000,
+        denseplot = true)
     seriestype := :path
 
     label --> "Cubic Spline"
@@ -146,12 +146,12 @@ end
 end
 
 @recipe function f(::Type{Val{:bspline_interp}},
-    x, y, z;
-    d = 5,
-    pVec = :ArcLen,
-    knotVec = :Average,
-    plotdensity = length(x) * 6,
-    denseplot = true)
+        x, y, z;
+        d = 5,
+        pVec = :ArcLen,
+        knotVec = :Average,
+        plotdensity = length(x) * 6,
+        denseplot = true)
     seriestype := :path
 
     label --> "B-Spline"
@@ -176,13 +176,13 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:bspline_approx}},
-    x, y, z;
-    d = 5,
-    h = length(x) - 1,
-    pVec = :ArcLen,
-    knotVec = :Average,
-    plotdensity = length(x) * 6,
-    denseplot = true)
+        x, y, z;
+        d = 5,
+        h = length(x) - 1,
+        pVec = :ArcLen,
+        knotVec = :Average,
+        plotdensity = length(x) * 6,
+        denseplot = true)
     seriestype := :path
 
     label --> "B-Spline"
@@ -206,11 +206,11 @@ end
 ########################################
 
 @recipe function f(::Type{Val{:akima}},
-    x,
-    y,
-    z;
-    plotdensity = length(x) * 6,
-    denseplot = true)
+        x,
+        y,
+        z;
+        plotdensity = length(x) * 6,
+        denseplot = true)
     seriestype := :path
 
     label --> "Akima"

From d4aa74d7e7078810877c1e3776bb96ccb07d0670 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 22:57:07 +0530
Subject: [PATCH 6/7] build: add ForwardDiff as weak deps

---
 Project.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/Project.toml b/Project.toml
index fea02a86..82ce0637 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,6 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 
 [weakdeps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
 Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -24,6 +25,7 @@ DataInterpolationsSymbolicsExt = "Symbolics"
 
 [compat]
 ChainRulesCore = "0.9.44, 0.10, 1"
+ForwardDiff = "0.10"
 LinearAlgebra = "1.6"
 Optim = "0.19, 0.20, 0.21, 0.22, 1.0"
 PrettyTables = "2"

From f6f86e04f301197b8b2db903c1d70a7427ea6e06 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan <sathvikbhagavan@gmail.com>
Date: Mon, 6 Nov 2023 23:20:05 +0530
Subject: [PATCH 7/7] build: add ForwardDiff dep in DataInterpolationsOptimExt
 extension

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 82ce0637..03946219 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,7 +19,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
 
 [extensions]
 DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
-DataInterpolationsOptimExt = "Optim"
+DataInterpolationsOptimExt = ["ForwardDiff", "Optim"]
 DataInterpolationsRegularizationToolsExt = "RegularizationTools"
 DataInterpolationsSymbolicsExt = "Symbolics"