From 282fb057937ad7f261f1ec83d11aad9b8647d066 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 23 Sep 2024 09:46:06 +0000 Subject: [PATCH 1/6] fix(LinearInterpolation): fix cache to avoid scalar indexing and generalize for high dim arrays --- src/parameter_caches.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 0701b3a2..bfd9f87e 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -19,8 +19,9 @@ function safe_diff(b, a::T) where {T} end function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} - Δu = if u isa AbstractMatrix - [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] + Δu = if u isa AbstractArray + ax = axes(u) + safe_diff.(u[ax[1:end-1]..., idx+1:idx+1] , u[ax[1:end-1]..., idx:idx]) else safe_diff(u[idx + 1], u[idx]) end From 29c5b286fddd2a91afa1f8f026eded188cc46e31 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 23 Sep 2024 09:46:47 +0000 Subject: [PATCH 2/6] feat(utils): add `munge_data` dispatch for higher dimensional arrays --- src/interpolation_utils.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 0df5e484..4037d810 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -94,6 +94,21 @@ function munge_data(U::StridedMatrix, t::AbstractVector) return U, t end +function munge_data(U::AbstractArray{T, N}, t) where {T, N} + TU = Base.nonmissingtype(eltype(U)) + Tt = Base.nonmissingtype(eltype(t)) + @assert length(t) == size(U, ndims(U)) + ax = axes(U)[1:end-1] + non_missing_indices = collect( + i for i in 1:length(t) + if !any(ismissing, U[ax..., i]) && !ismissing(t[i]) + ) + U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U)) + t = Tt.([t[i] for i in non_missing_indices]) + + return U, t +end + seems_linear(assume_linear_t::Bool, _) = assume_linear_t seems_linear(assume_linear_t::Number, t) = looks_linear(t; threshold = assume_linear_t) From 2ffef8d34304d8c13d6ba0dc0a1e5fd5db6affd0 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 23 Sep 2024 09:47:20 +0000 Subject: [PATCH 3/6] fix(LinearInterpolation): generalize interpolation for all sized arrays --- src/interpolation_methods.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 697f7375..26298b8f 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -33,11 +33,12 @@ function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, igues val end -function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) +function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess) idx = get_idx(A, t, iguess) Δt = t - A.t[idx] slope = get_parameters(A, idx) - return A.u[:, idx] + slope * Δt + ax = axes(A.u)[1:end-1] + return A.u[ax..., idx] + slope * Δt end # Quadratic Interpolation From 9530aa6f6f7c8820ec81dbb61529d09644657292 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 23 Sep 2024 10:40:09 +0000 Subject: [PATCH 4/6] fix(LinearInterpolation): better check for arrays --- src/parameter_caches.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index bfd9f87e..83f574b8 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -18,8 +18,8 @@ function safe_diff(b, a::T) where {T} b == a ? zero(T) : b - a end -function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} - Δu = if u isa AbstractArray +function linear_interpolation_parameters(u::AbstractArray{T, N}, t, idx) where {T, N} + Δu = if N > 1 ax = axes(u) safe_diff.(u[ax[1:end-1]..., idx+1:idx+1] , u[ax[1:end-1]..., idx:idx]) else From ce28c4b6c4ca6b6c764ea64e26cd6f7d88e2b145 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Tue, 24 Sep 2024 09:15:15 +0000 Subject: [PATCH 5/6] test(interpolation_tests): update tests for `Matrix` LinearInterpolation --- test/interpolation_tests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 3ebc1741..38c424de 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -43,11 +43,11 @@ end A = LinearInterpolation(u, t; extrapolate = true) for (_t, _u) in zip(t, eachcol(u)) - @test A(_t) == _u + @test A(_t) == reshape(_u, : , 1) end - @test A(0) == [0.0, 0.0] - @test A(5.5) == [11.0, 16.5] - @test A(11) == [22, 33] + @test A(0) == [0.0; 0.0;;] + @test A(5.5) == [11.0; 16.5;;] + @test A(11) == [22; 33;;] x = 1:10 y = 2:4 From 97529bb656ba64fe543964c4bdd66fa8b6f8d429 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Tue, 24 Sep 2024 10:24:11 +0000 Subject: [PATCH 6/6] chore(DataInterpolations): fix formatting --- src/interpolation_methods.jl | 2 +- src/interpolation_utils.jl | 4 ++-- src/parameter_caches.jl | 3 ++- test/interpolation_tests.jl | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 26298b8f..03adf558 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -37,7 +37,7 @@ function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess idx = get_idx(A, t, iguess) Δt = t - A.t[idx] slope = get_parameters(A, idx) - ax = axes(A.u)[1:end-1] + ax = axes(A.u)[1:(end - 1)] return A.u[ax..., idx] + slope * Δt end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 4037d810..f8fc7147 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -98,7 +98,7 @@ function munge_data(U::AbstractArray{T, N}, t) where {T, N} TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, ndims(U)) - ax = axes(U)[1:end-1] + ax = axes(U)[1:(end - 1)] non_missing_indices = collect( i for i in 1:length(t) if !any(ismissing, U[ax..., i]) && !ismissing(t[i]) @@ -107,7 +107,7 @@ function munge_data(U::AbstractArray{T, N}, t) where {T, N} t = Tt.([t[i] for i in non_missing_indices]) return U, t -end +end seems_linear(assume_linear_t::Bool, _) = assume_linear_t seems_linear(assume_linear_t::Number, t) = looks_linear(t; threshold = assume_linear_t) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 83f574b8..824b2e2a 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -21,7 +21,8 @@ end function linear_interpolation_parameters(u::AbstractArray{T, N}, t, idx) where {T, N} Δu = if N > 1 ax = axes(u) - safe_diff.(u[ax[1:end-1]..., idx+1:idx+1] , u[ax[1:end-1]..., idx:idx]) + safe_diff.( + u[ax[1:(end - 1)]..., (idx + 1):(idx + 1)], u[ax[1:(end - 1)]..., idx:idx]) else safe_diff(u[idx + 1], u[idx]) end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 38c424de..69e5c197 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -43,7 +43,7 @@ end A = LinearInterpolation(u, t; extrapolate = true) for (_t, _u) in zip(t, eachcol(u)) - @test A(_t) == reshape(_u, : , 1) + @test A(_t) == reshape(_u, :, 1) end @test A(0) == [0.0; 0.0;;] @test A(5.5) == [11.0; 16.5;;]