Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add show methods #195

Merged
merged 4 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "4.4.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -24,6 +25,7 @@ DataInterpolationsSymbolicsExt = "Symbolics"
[compat]
ChainRulesCore = "0.9.44, 0.10, 1"
Optim = "0.19, 0.20, 0.21, 0.22, 1.0"
PrettyTables = "2"
QuadGK = "2.9.1"
RecipesBase = "0.8, 1.0"
RecursiveArrayTools = "2"
Expand Down
9 changes: 7 additions & 2 deletions ext/DataInterpolationsOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ module DataInterpolationsOptimExt

if isdefined(Base, :get_extension)
using DataInterpolations: AbstractInterpolation, munge_data
import DataInterpolations: Curvefit, _interpolate
import DataInterpolations: Curvefit, _interpolate, get_show
using Reexport
else
using ..DataInterpolations: AbstractInterpolation, munge_data
import ..DataInterpolations: Curvefit, _interpolate
import ..DataInterpolations: Curvefit, _interpolate, get_show
using Reexport
end

Expand Down Expand Up @@ -75,4 +75,9 @@ function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}},
_interpolate(A, t), i
end

function get_show(interp::CurvefitCache)
return "Curvefit" *
" with $(length(interp.t)) points, using $(nameof(typeof(interp.alg)))\n"
end

end # module
7 changes: 6 additions & 1 deletion ext/DataInterpolationsRegularizationToolsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DataInterpolationsRegularizationToolsExt

using DataInterpolations
using DataInterpolations: munge_data, _interpolate, RegularizationSmooth
using DataInterpolations: munge_data, _interpolate, RegularizationSmooth, get_show
using LinearAlgebra

isdefined(Base, :get_extension) ? (import RegularizationTools as RT) :
Expand Down Expand Up @@ -269,4 +269,9 @@ function DataInterpolations._interpolate(A::RegularizationSmooth{
DataInterpolations._interpolate(A.Aitp, t)
end

function DataInterpolations.get_show(interp::RegularizationSmooth)
return "RegularizationSmooth" *
" with $(length(interp.t)) points, with regularization coefficient $(interp.λ)\n"
end

end # module
2 changes: 2 additions & 0 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ function Base.setindex!(A::AbstractInterpolation{true}, x, i)
end

using LinearAlgebra, RecursiveArrayTools, RecipesBase
using PrettyTables

include("interpolation_caches.jl")
include("interpolation_utils.jl")
Expand All @@ -24,6 +25,7 @@ include("plot_rec.jl")
include("derivatives.jl")
include("integrals.jl")
include("online.jl")
include("show.jl")

(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)
(interp::AbstractInterpolation)(t::Number, i::Integer) = _interpolate(interp, t, i)
Expand Down
63 changes: 63 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
###################### Generic Dispatches ######################

function Base.show(io::IO, mime::MIME"text/plain", interp::AbstractInterpolation)
print(io, get_show(interp))
header = ["time", get_names(interp.u)...]
data = hcat(interp.t, get_data(interp.u))
pretty_table(io, data; header = header, vcrop_mode = :middle)
end

function get_show(interp::AbstractInterpolation)
return string(nameof(typeof(interp))) * " with $(length(interp.t)) points\n"
end

function get_data(u::AbstractVector)
return u
end

function get_data(u::AbstractVector{<:AbstractVector})
return reduce(hcat, u)'
end

function get_data(u::AbstractMatrix)
return u'
end

function get_names(u::AbstractVector)
return ["u"]
end

function get_names(u::AbstractVector{<:AbstractVector})
return ["u$i" for i in eachindex(first(u))]
end

function get_names(u::AbstractMatrix)
return ["u$i" for i in axes(u, 1)]
end

###################### Specific Dispatches ######################

function get_show(interp::QuadraticInterpolation)
return string(nameof(typeof(interp))) *
" with $(length(interp.t)) points, $(interp.mode) mode\n"
end

function get_show(interp::LagrangeInterpolation)
return string(nameof(typeof(interp))) *
" with $(length(interp.t)) points, with order $(interp.n)\n"
end

function get_show(interp::ConstantInterpolation)
return string(nameof(typeof(interp))) *
" with $(length(interp.t)) points, in $(interp.dir) direction\n"
end

function get_show(interp::BSplineInterpolation)
return string(nameof(typeof(interp))) *
" with $(length(interp.t)) points, with degree $(interp.d)\n"
end

function get_show(interp::BSplineApprox)
return string(nameof(typeof(interp))) *
" with $(length(interp.t)) points, with degree $(interp.d), number of control points $(interp.h)\n"
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ using DataInterpolations, Test
@testset "Regularization Smoothing" begin
include("regularization.jl")
end
@testset "Show methods" begin
include("show.jl")
end
end
83 changes: 83 additions & 0 deletions test/show.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using Optim, StableRNGs
using RegularizationTools

t = [1.0, 2.0, 3.0, 4.0, 5.0]
x = [1.0, 2.0, 3.0, 4.0, 5.0]

@testset "Generic Cases" begin
function test_show_line(A)
@testset "$(nameof(typeof(A)))" begin
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"$(nameof(typeof(A))) with $(length(A.t)) points\n")
end
end
methods = [
LinearInterpolation(x, t),
AkimaInterpolation(x, t),
QuadraticSpline(x, t),
CubicSpline(x, t),
]
test_show_line.(methods)
end

@testset "Specific Cases" begin
@testset "QuadraticInterpolation" begin
A = QuadraticInterpolation(x, t)
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"QuadraticInterpolation with 5 points, Forward mode\n")
end
@testset "LagrangeInterpolation" begin
A = LagrangeInterpolation(x, t)
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"LagrangeInterpolation with 5 points, with order 4\n")
end
@testset "ConstantInterpolation" begin
A = ConstantInterpolation(x, t)
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"ConstantInterpolation with 5 points, in left direction\n")
end
@testset "BSplineInterpolation" begin
A = BSplineInterpolation(x, t, 3, :Uniform, :Uniform)
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"BSplineInterpolation with 5 points, with degree 3\n")
end
@testset "BSplineApprox" begin
A = BSplineApprox(x, t, 2, 4, :Uniform, :Uniform)
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"BSplineApprox with 5 points, with degree 2, number of control points 4\n")
end
end

@testset "CurveFit" begin
rng = StableRNG(12345)
model(x, p) = @. p[1] / (1 + exp(x - p[2]))
t = range(-10, stop = 10, length = 40)
u = model(t, [1.0, 2.0]) + 0.01 * randn(rng, length(t))
p0 = [0.5, 0.5]
A = Curvefit(u, t, model, p0, LBFGS())
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"Curvefit with 40 points, using LBFGS\n")
end

@testset "RegularizationSmooth" begin
npts = 50
xmin = 0.0
xspan = 3 / 2 * π
x = collect(range(xmin, xmin + xspan, length = npts))
rng = StableRNG(655)
x = x + xspan / npts * (rand(rng, npts) .- 0.5)
# select a subset randomly
idx = unique(rand(rng, collect(eachindex(x)), 20))
t = x[unique(idx)]
npts = length(t)
ut = sin.(t)
stdev = 1e-1 * maximum(ut)
u = ut + stdev * randn(rng, npts)
# data must be ordered if t̂ is not provided
idx = sortperm(t)
tₒ = t[idx]
uₒ = u[idx]
A = RegularizationSmooth(uₒ, tₒ; alg = :fixed)
@test startswith(sprint(io -> show(io, MIME"text/plain"(), A)),
"RegularizationSmooth with 15 points, with regularization coefficient 1.0\n")
end