Skip to content

Commit

Permalink
Merge pull request #308 from sathvikbhagavan/sb/inference
Browse files Browse the repository at this point in the history
refactor: collect parameters such that they are type stable
  • Loading branch information
ChrisRackauckas authored Jul 23, 2024
2 parents 2f758f8 + 826f0a7 commit d1cc354
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/online.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function append!(
append!(A.t.parent, t)
parameters = quadratic_interpolation_parameters.(
Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2))
l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...)))
l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters))))
append!(A.p.l₀, l₀)
append!(A.p.l₁, l₁)
append!(A.p.l₂, l₂)
Expand Down
8 changes: 4 additions & 4 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
function QuadraticParameterCache(u, t)
parameters = quadratic_interpolation_parameters.(
Ref(u), Ref(t), 1:(length(t) - 2))
l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...)))
l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters))))
return QuadraticParameterCache(l₀, l₁, l₂)
end

Expand Down Expand Up @@ -81,7 +81,7 @@ end
function CubicSplineParameterCache(u, h, z)
parameters = cubic_spline_parameters.(
Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1))
c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...)))
c₁, c₂ = collect.(eachrow(stack(collect.(parameters))))
return CubicSplineParameterCache(c₁, c₂)
end

Expand All @@ -99,7 +99,7 @@ end
function CubicHermiteParameterCache(du, u, t)
parameters = cubic_hermite_spline_parameters.(
Ref(du), Ref(u), Ref(t), 1:(length(t) - 1))
c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...)))
c₁, c₂ = collect.(eachrow(stack(collect.(parameters))))
return CubicHermiteParameterCache(c₁, c₂)
end

Expand All @@ -123,7 +123,7 @@ end
function QuinticHermiteParameterCache(ddu, du, u, t)
parameters = quintic_hermite_spline_parameters.(
Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1))
c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...)))
c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters))))
return QuinticHermiteParameterCache(c₁, c₂, c₃)
end

Expand Down
61 changes: 46 additions & 15 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,52 @@
using DataInterpolations
u = 2.0collect(1:10)
t = 1.0collect(1:10)
@inferred LinearInterpolation(u, t)
A = LinearInterpolation(u, t)
using Symbolics

for i in 1:10
@test u[i] == A.u[i]
end
@testset "Interface" begin
u = 2.0collect(1:10)
t = 1.0collect(1:10)
A = LinearInterpolation(u, t)

for i in 1:10
@test u[i] == A.u[i]
end

for i in 1:10
@test t[i] == A.t[i]
for i in 1:10
@test t[i] == A.t[i]
end
end

using Symbolics
u = 2.0collect(1:10)
t = 1.0collect(1:10)
A = LinearInterpolation(u, t)
@testset "Symbolics" begin
u = 2.0collect(1:10)
t = 1.0collect(1:10)
A = LinearInterpolation(u, t)
@variables t x(t)
substitute(A(t), Dict(t => x))
end

@variables t x(t)
substitute(A(t), Dict(t => x))
@testset "Type Inference" begin
u = 2.0collect(1:10)
t = 1.0collect(1:10)
methods = [
ConstantInterpolation, LinearInterpolation,
QuadraticInterpolation, LagrangeInterpolation,
QuadraticSpline, CubicSpline, AkimaInterpolation
]
@testset "$method" for method in methods
@inferred method(u, t)
end
@testset "BSplineInterpolation" begin
@inferred BSplineInterpolation(u, t, 3, :Uniform, :Uniform)
@inferred BSplineInterpolation(u, t, 3, :ArcLen, :Average)
end
@testset "BSplineApprox" begin
@inferred BSplineApprox(u, t, 3, 5, :Uniform, :Uniform)
@inferred BSplineApprox(u, t, 3, 5, :ArcLen, :Average)
end
du = ones(10)
ddu = zeros(10)
@testset "Hermite Splines" begin
@inferred CubicHermiteSpline(du, u, t)
@inferred PCHIPInterpolation(u, t)
@inferred QuinticHermiteSpline(ddu, du, u, t)
end
end
3 changes: 2 additions & 1 deletion test/online_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ u2 = [1.0, 2.0, 1.0]
ts_append = 1.0:0.5:6.0
ts_push = 1.0:0.5:4.0

for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation]
@testset "$method" for method in [
LinearInterpolation, QuadraticInterpolation, ConstantInterpolation]
func1 = method(u1, t1)
append!(func1, u2, t2)
func2 = method(vcat(u1, u2), vcat(t1, t2))
Expand Down

0 comments on commit d1cc354

Please sign in to comment.