diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 37ca265c..aabc4074 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -104,51 +104,49 @@ function quadratic_spline_params(t::AbstractVector, sc::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) - return u, t -end - function munge_data(u::AbstractVector, t::AbstractVector) - Tu = Base.nonmissingtype(eltype(u)) - Tt = Base.nonmissingtype(eltype(t)) - @assert length(t) == length(u) - non_missing_indices = collect( - i for i in 1:length(t) - if !ismissing(u[i]) && !ismissing(t[i]) - ) + Tu = nonmissingtype(eltype(u)) + Tt = nonmissingtype(eltype(t)) + if Tu === eltype(u) && Tt === eltype(t) + return u, t + end - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) + @assert length(t) == length(u) + non_missing_mask = map((ui, ti) -> !ismissing(ui) && !ismissing(ti), u, t) + u = convert(AbstractVector{Tu}, u[non_missing_mask]) + t = convert(AbstractVector{Tt}, t[non_missing_mask]) return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector) - TU = Base.nonmissingtype(eltype(U)) - Tt = Base.nonmissingtype(eltype(t)) - @assert length(t) == size(U, 2) - non_missing_indices = collect( - i for i in 1:length(t) - if !any(ismissing, U[:, i]) && !ismissing(t[i]) - ) +function munge_data(U::AbstractMatrix, t::AbstractVector) + TU = nonmissingtype(eltype(U)) + Tt = nonmissingtype(eltype(t)) + if TU === eltype(U) && Tt === eltype(t) + return U, t + end - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) + @assert length(t) == size(U, 2) + non_missing_mask = map( + (uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachcol(U), t) + U = convert(AbstractMatrix{TU}, U[:, non_missing_mask]) + t = convert(AbstractVector{Tt}, t[non_missing_mask]) return U, t end function munge_data(U::AbstractArray{T, N}, t) where {T, N} - TU = Base.nonmissingtype(eltype(U)) - Tt = Base.nonmissingtype(eltype(t)) - @assert length(t) == size(U, ndims(U)) - ax = axes(U)[1:(end - 1)] - non_missing_indices = collect( - i for i in 1:length(t) - if !any(ismissing, U[ax..., i]) && !ismissing(t[i]) - ) - U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U)) - t = Tt.([t[i] for i in non_missing_indices]) + TU = nonmissingtype(eltype(U)) + Tt = nonmissingtype(eltype(t)) + if TU === eltype(U) && Tt === eltype(t) + return U, t + end + + @assert length(t) == size(U, N) + non_missing_mask = map( + (uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachslice(U; dims = N), t) + U = convert(AbstractArray{TU, N}, copy(selectdim(U, N, non_missing_mask))) + t = convert(AbstractVector{Tt}, t[non_missing_mask]) return U, t end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9081ff28..c14d67de 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -920,3 +920,22 @@ f_cubic_spline = c -> square(CubicSpline, c) @test ForwardDiff.derivative(f_quadratic_spline, 4.0) ≈ 8.0 @test ForwardDiff.derivative(f_cubic_spline, 2.0) ≈ 4.0 @test ForwardDiff.derivative(f_cubic_spline, 4.0) ≈ 8.0 + +@testset "munge_data" begin + t0 = [0.1, 0.2, 0.3] + u0 = ["A", "B", "C"] + iszero_allocations(u, t) = iszero(@allocated(DataInterpolations.munge_data(u, t))) + + for T in (String, Union{String, Missing}), dims in 1:3 + _u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims))) + + u, t = @inferred(DataInterpolations.munge_data(_u0, t0)) + @test u isa Array{String, dims} + @test t isa Vector{Float64} + if T === String + @test iszero_allocations(_u0, t0) + @test u === _u0 + @test t === t + end + end +end