diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 59306c46..34e27841 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -2,11 +2,13 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, - LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox + LagrangeInterpolation, AkimaInterpolation, + BSplineInterpolation, BSplineApprox using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, - LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox + LagrangeInterpolation, AkimaInterpolation, + BSplineInterpolation, BSplineApprox using ..ChainRulesCore end @@ -15,7 +17,7 @@ function ChainRulesCore.rrule(::typeof(_interpolate), LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, - BSplineApprox, + BSplineApprox }, t::Number) deriv = derivative(A, t) diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index f2643216..230d0af5 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -2,8 +2,9 @@ module DataInterpolationsOptimExt using DataInterpolations import DataInterpolations: munge_data, - Curvefit, CurvefitCache, _interpolate, get_show, derivative, ExtrapolationError, - integral, IntegralNotFoundError + Curvefit, CurvefitCache, _interpolate, get_show, derivative, + ExtrapolationError, + integral, IntegralNotFoundError isdefined(Base, :get_extension) ? (using Optim, ForwardDiff) : (using ..Optim, ..ForwardDiff) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index bb34f574..28e8c47b 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -2,7 +2,8 @@ module DataInterpolationsRegularizationToolsExt using DataInterpolations import DataInterpolations: munge_data, - _interpolate, RegularizationSmooth, get_show, derivative, integral + _interpolate, RegularizationSmooth, get_show, derivative, + integral using LinearAlgebra isdefined(Base, :get_extension) ? (import RegularizationTools as RT) : @@ -243,7 +244,8 @@ end # wls::Symbol, d::Int=2; λ::Real=1.0, alg::Symbol=:gcv_svd) """ Solve for the smoothed dependent variables and create spline interpolator """ -function _reg_smooth_solve(u::AbstractVector, t̂::AbstractVector, d::Int, M::AbstractMatrix, +function _reg_smooth_solve( + u::AbstractVector, t̂::AbstractVector, d::Int, M::AbstractMatrix, Wls½::AbstractMatrix, Wr½::AbstractMatrix, λ::Real, alg::Symbol, extrapolate::Bool) λ = float(λ) # `float` expected by RT D = _derivative_matrix(t̂, d) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 9f23e29e..7e998519 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -51,8 +51,8 @@ function Base.showerror(io::IO, e::IntegralNotFoundError) end export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, - AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, - BSplineInterpolation, BSplineApprox + AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, + BSplineInterpolation, BSplineApprox # added for RegularizationSmooth, JJS 11/27/21 ### Regularization data smoothing and interpolation @@ -106,7 +106,7 @@ struct CurvefitCache{ algType, pminType, FT, - T, + T } <: AbstractInterpolation{FT, T} u::uType t::tType @@ -139,7 +139,6 @@ end export Curvefit - # Deprecated April 2020 export ZeroSpline diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index f9955e9b..c28dba90 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -173,8 +173,9 @@ function QuadraticSpline(u::uType, t; extrapolate = false) where {uType <: Abstr d_tmp = ones(eltype(t), s) du = zeros(eltype(t), s - 1) tA = Tridiagonal(dl, d_tmp, du) - d_ = map(i -> i == 1 ? zeros(eltype(t), size(u[1])) : - 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), + d_ = map( + i -> i == 1 ? zeros(eltype(t), size(u[1])) : + 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) @@ -213,8 +214,9 @@ function CubicSpline(u::uType, typed_zero = zero(6(u[begin + 2] - u[begin + 1]) / h[begin + 2] - 6(u[begin + 1] - u[begin]) / h[begin + 1]) - d = map(i -> i == 1 || i == n + 1 ? typed_zero : - 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], + d = map( + i -> i == 1 || i == n + 1 ? typed_zero : + 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d CubicSpline{true}(u, t, h[1:(n + 1)], z, extrapolate) @@ -228,8 +230,9 @@ function CubicSpline(u::uType, t; extrapolate = false) where {uType <: AbstractV d_tmp = 2 .* (h[1:(n + 1)] .+ h[2:(n + 2)]) du = h[2:(n + 1)] tA = Tridiagonal(dl, d_tmp, du) - d_ = map(i -> i == 1 || i == n + 1 ? zeros(eltype(t), size(u[1])) : - 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], + d_ = map( + i -> i == 1 || i == n + 1 ? zeros(eltype(t), size(u[1])) : + 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 3c157baf..dbbeca84 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -41,8 +41,9 @@ function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) - non_missing_indices = collect(i for i in 1:length(t) - if !ismissing(u[i]) && !ismissing(t[i])) + non_missing_indices = collect(i + for i in 1:length(t) + if !ismissing(u[i]) && !ismissing(t[i])) newu = Tu.([u[i] for i in non_missing_indices]) newt = Tt.([t[i] for i in non_missing_indices]) @@ -53,8 +54,9 @@ function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) - non_missing_indices = collect(i for i in 1:length(t) - if !any(ismissing, U[:, i]) && !ismissing(t[i])) + non_missing_indices = collect(i + for i in 1:length(t) + if !any(ismissing, U[:, i]) && !ismissing(t[i])) newUs = [TU.(U[:, i]) for i in non_missing_indices] newt = Tt.([t[i] for i in non_missing_indices]) diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9fbf153c..ebfe1947 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -360,7 +360,7 @@ end @testset "Matrix case" for u in [ [1.0 2.0 0.0 1.0; 1.0 2.0 0.0 1.0], - ["B" "C" "A" "B"; "B" "C" "A" "B"], + ["B" "C" "A" "B"; "B" "C" "A" "B"] ] A = ConstantInterpolation(u, t, dir = :right; extrapolate = true) @test A(0.5) == u[:, 1] @@ -609,7 +609,7 @@ end 0.9836755196317837, 0.8833959853995836, 0.3810348276782708, - 0.048062978598861855, + 0.048062978598861855 ] us = A.(ts) @test vs ≈ us diff --git a/test/regularization.jl b/test/regularization.jl index 6e083fa0..565becfd 100644 --- a/test/regularization.jl +++ b/test/regularization.jl @@ -56,7 +56,7 @@ tolerance = 1e-3 0.1392761732695201, -0.3312498167413961, -0.6673268474631847, - -0.9370342562716745, + -0.9370342562716745 ] @test isapprox(A.û, ans, rtol = tolerance) @test isapprox(A.(tₒ), ans, rtol = tolerance) @@ -81,7 +81,7 @@ end 0.05281575111822609, -0.5333542714497277, -0.8406745098604134, - -1.0983391396173634, + -1.0983391396173634 ] @test isapprox(A.û, ans, rtol = tolerance) @test isapprox(A.(tₒ), ans, rtol = tolerance) @@ -102,7 +102,7 @@ end 0.09102085384961625, -0.5640882848240228, -0.810519277110118, - -1.1159124134900906, + -1.1159124134900906 ] @test isapprox(A.û, ans, rtol = tolerance) @test isapprox(A.(tₒ), ans, rtol = tolerance) @@ -126,7 +126,7 @@ end 0.04756024028728636, -0.5301034620974782, -0.8408107101140526, - -1.1058428573417736, + -1.1058428573417736 ] @test isapprox(A.û, ans, rtol = tolerance) @test isapprox(A.(tₒ), ans, rtol = tolerance) @@ -173,7 +173,7 @@ end -0.676806367664006, -0.8587832527770329, -1.0443430843364814, - -1.2309001260104093, + -1.2309001260104093 ] @test isapprox(A.û, ans, rtol = tolerance) @test isapprox(A.(t̂), ans, rtol = tolerance) diff --git a/test/show.jl b/test/show.jl index 47bee3e7..45eef321 100644 --- a/test/show.jl +++ b/test/show.jl @@ -16,7 +16,7 @@ x = [1.0, 2.0, 3.0, 4.0, 5.0] LinearInterpolation(x, t), AkimaInterpolation(x, t), QuadraticSpline(x, t), - CubicSpline(x, t), + CubicSpline(x, t) ] test_show_line.(methods) end