diff --git a/Project.toml b/Project.toml index 6a3fb7cb..2d1ad771 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "4.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -24,6 +25,7 @@ DataInterpolationsSymbolicsExt = "Symbolics" [compat] ChainRulesCore = "0.9.44, 0.10, 1" Optim = "0.19, 0.20, 0.21, 0.22, 1.0" +PrettyTables = "2" QuadGK = "2.9.1" RecipesBase = "0.8, 1.0" RecursiveArrayTools = "2" diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 93f32d70..5db14436 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -2,11 +2,11 @@ module DataInterpolationsOptimExt if isdefined(Base, :get_extension) using DataInterpolations: AbstractInterpolation, munge_data - import DataInterpolations: Curvefit, _interpolate + import DataInterpolations: Curvefit, _interpolate, get_show using Reexport else using ..DataInterpolations: AbstractInterpolation, munge_data - import ..DataInterpolations: Curvefit, _interpolate + import ..DataInterpolations: Curvefit, _interpolate, get_show using Reexport end @@ -75,4 +75,9 @@ function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}}, _interpolate(A, t), i end +function get_show(interp::CurvefitCache) + return "Curvefit" * + " with $(length(interp.t)) points, using $(nameof(typeof(interp.alg)))\n" +end + end # module diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index fe5e7d75..c7abda64 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -1,7 +1,7 @@ module DataInterpolationsRegularizationToolsExt using DataInterpolations -using DataInterpolations: munge_data, _interpolate, RegularizationSmooth +using DataInterpolations: munge_data, _interpolate, RegularizationSmooth, get_show using LinearAlgebra isdefined(Base, :get_extension) ? (import RegularizationTools as RT) : @@ -269,4 +269,9 @@ function DataInterpolations._interpolate(A::RegularizationSmooth{ DataInterpolations._interpolate(A.Aitp, t) end +function DataInterpolations.get_show(interp::RegularizationSmooth) + return "RegularizationSmooth" * + " with $(length(interp.t)) points, with regularization coefficient $(interp.λ)\n" +end + end # module diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index b5f36a35..96f9ebab 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -16,6 +16,7 @@ function Base.setindex!(A::AbstractInterpolation{true}, x, i) end using LinearAlgebra, RecursiveArrayTools, RecipesBase +using PrettyTables include("interpolation_caches.jl") include("interpolation_utils.jl") @@ -24,6 +25,7 @@ include("plot_rec.jl") include("derivatives.jl") include("integrals.jl") include("online.jl") +include("show.jl") (interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t) (interp::AbstractInterpolation)(t::Number, i::Integer) = _interpolate(interp, t, i) diff --git a/src/show.jl b/src/show.jl new file mode 100644 index 00000000..7ce1e006 --- /dev/null +++ b/src/show.jl @@ -0,0 +1,63 @@ +###################### Generic Dispatches ###################### + +function Base.show(io::IO, mime::MIME"text/plain", interp::AbstractInterpolation) + print(io, get_show(interp)) + header = ["time", get_names(interp.u)...] + data = hcat(interp.t, get_data(interp.u)) + pretty_table(io, data; header = header, vcrop_mode = :middle) +end + +function get_show(interp::AbstractInterpolation) + return string(nameof(typeof(interp))) * " with $(length(interp.t)) points\n" +end + +function get_data(u::AbstractVector) + return u +end + +function get_data(u::AbstractVector{<:AbstractVector}) + return reduce(hcat, u)' +end + +function get_data(u::AbstractMatrix) + return u' +end + +function get_names(u::AbstractVector) + return ["u"] +end + +function get_names(u::AbstractVector{<:AbstractVector}) + return ["u$i" for i in eachindex(first(u))] +end + +function get_names(u::AbstractMatrix) + return ["u$i" for i in axes(u, 1)] +end + +###################### Specific Dispatches ###################### + +function get_show(interp::QuadraticInterpolation) + return string(nameof(typeof(interp))) * + " with $(length(interp.t)) points, $(interp.mode) mode\n" +end + +function get_show(interp::LagrangeInterpolation) + return string(nameof(typeof(interp))) * + " with $(length(interp.t)) points, with order $(interp.n)\n" +end + +function get_show(interp::ConstantInterpolation) + return string(nameof(typeof(interp))) * + " with $(length(interp.t)) points, in $(interp.dir) direction\n" +end + +function get_show(interp::BSplineInterpolation) + return string(nameof(typeof(interp))) * + " with $(length(interp.t)) points, with degree $(interp.d)\n" +end + +function get_show(interp::BSplineApprox) + return string(nameof(typeof(interp))) * + " with $(length(interp.t)) points, with degree $(interp.d), number of control points $(interp.h)\n" +end diff --git a/test/runtests.jl b/test/runtests.jl index eebeda2a..39447ed6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,4 +19,7 @@ using DataInterpolations, Test @testset "Regularization Smoothing" begin include("regularization.jl") end + @testset "Show methods" begin + include("show.jl") + end end diff --git a/test/show.jl b/test/show.jl new file mode 100644 index 00000000..2dcbfb96 --- /dev/null +++ b/test/show.jl @@ -0,0 +1,83 @@ +using Optim, StableRNGs +using RegularizationTools + +t = [1.0, 2.0, 3.0, 4.0, 5.0] +x = [1.0, 2.0, 3.0, 4.0, 5.0] + +@testset "Generic Cases" begin + function test_show_line(A) + @testset "$(nameof(typeof(A)))" begin + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "$(nameof(typeof(A))) with $(length(A.t)) points\n") + end + end + methods = [ + LinearInterpolation(x, t), + AkimaInterpolation(x, t), + QuadraticSpline(x, t), + CubicSpline(x, t), + ] + test_show_line.(methods) +end + +@testset "Specific Cases" begin + @testset "QuadraticInterpolation" begin + A = QuadraticInterpolation(x, t) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "QuadraticInterpolation with 5 points, Forward mode\n") + end + @testset "LagrangeInterpolation" begin + A = LagrangeInterpolation(x, t) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "LagrangeInterpolation with 5 points, with order 4\n") + end + @testset "ConstantInterpolation" begin + A = ConstantInterpolation(x, t) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "ConstantInterpolation with 5 points, in left direction\n") + end + @testset "BSplineInterpolation" begin + A = BSplineInterpolation(x, t, 3, :Uniform, :Uniform) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "BSplineInterpolation with 5 points, with degree 3\n") + end + @testset "BSplineApprox" begin + A = BSplineApprox(x, t, 2, 4, :Uniform, :Uniform) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "BSplineApprox with 5 points, with degree 2, number of control points 4\n") + end +end + +@testset "CurveFit" begin + rng = StableRNG(12345) + model(x, p) = @. p[1] / (1 + exp(x - p[2])) + t = range(-10, stop = 10, length = 40) + u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t)) + p0 = [0.5, 0.5] + A = Curvefit(u, t, model, p0, LBFGS()) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "Curvefit with 40 points, using LBFGS\n") +end + +@testset "RegularizationSmooth" begin + npts = 50 + xmin = 0.0 + xspan = 3 / 2 * π + x = collect(range(xmin, xmin + xspan, length = npts)) + rng = StableRNG(655) + x = x + xspan / npts * (rand(rng, npts) .- 0.5) + # select a subset randomly + idx = unique(rand(rng, collect(eachindex(x)), 20)) + t = x[unique(idx)] + npts = length(t) + ut = sin.(t) + stdev = 1e-1 * maximum(ut) + u = ut + stdev * randn(rng, npts) + # data must be ordered if t̂ is not provided + idx = sortperm(t) + tₒ = t[idx] + uₒ = u[idx] + A = RegularizationSmooth(uₒ, tₒ; alg = :fixed) + @test startswith(sprint(io -> show(io, MIME"text/plain"(), A)), + "RegularizationSmooth with 15 points, with regularization coefficient 1.0\n") +end