From be0617936c8f3c5d4373cf379c555afa0bb932c8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Feb 2025 02:38:22 +0100 Subject: [PATCH 1/2] Fix type inference and performance problems of `munge_data` --- Project.toml | 4 ++- src/interpolation_utils.jl | 64 ++++++++++++++++++------------------- test/interpolation_tests.jl | 20 ++++++++++++ 3 files changed, 54 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index d509b76d..83f1cbc7 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ EnumX = "1.0.4" FindFirstFunctions = "1.3" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" +JET = "0.9.17" LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" @@ -53,6 +54,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Optim = "429524aa-4258-5aef-a3af-852621145aeb" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" @@ -64,4 +66,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"] +test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"] diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 37ca265c..aabc4074 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -104,51 +104,49 @@ function quadratic_spline_params(t::AbstractVector, sc::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) - return u, t -end - function munge_data(u::AbstractVector, t::AbstractVector) - Tu = Base.nonmissingtype(eltype(u)) - Tt = Base.nonmissingtype(eltype(t)) - @assert length(t) == length(u) - non_missing_indices = collect( - i for i in 1:length(t) - if !ismissing(u[i]) && !ismissing(t[i]) - ) + Tu = nonmissingtype(eltype(u)) + Tt = nonmissingtype(eltype(t)) + if Tu === eltype(u) && Tt === eltype(t) + return u, t + end - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) + @assert length(t) == length(u) + non_missing_mask = map((ui, ti) -> !ismissing(ui) && !ismissing(ti), u, t) + u = convert(AbstractVector{Tu}, u[non_missing_mask]) + t = convert(AbstractVector{Tt}, t[non_missing_mask]) return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector) - TU = Base.nonmissingtype(eltype(U)) - Tt = Base.nonmissingtype(eltype(t)) - @assert length(t) == size(U, 2) - non_missing_indices = collect( - i for i in 1:length(t) - if !any(ismissing, U[:, i]) && !ismissing(t[i]) - ) +function munge_data(U::AbstractMatrix, t::AbstractVector) + TU = nonmissingtype(eltype(U)) + Tt = nonmissingtype(eltype(t)) + if TU === eltype(U) && Tt === eltype(t) + return U, t + end - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) + @assert length(t) == size(U, 2) + non_missing_mask = map( + (uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachcol(U), t) + U = convert(AbstractMatrix{TU}, U[:, non_missing_mask]) + t = convert(AbstractVector{Tt}, t[non_missing_mask]) return U, t end function munge_data(U::AbstractArray{T, N}, t) where {T, N} - TU = Base.nonmissingtype(eltype(U)) - Tt = Base.nonmissingtype(eltype(t)) - @assert length(t) == size(U, ndims(U)) - ax = axes(U)[1:(end - 1)] - non_missing_indices = collect( - i for i in 1:length(t) - if !any(ismissing, U[ax..., i]) && !ismissing(t[i]) - ) - U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U)) - t = Tt.([t[i] for i in non_missing_indices]) + TU = nonmissingtype(eltype(U)) + Tt = nonmissingtype(eltype(t)) + if TU === eltype(U) && Tt === eltype(t) + return U, t + end + + @assert length(t) == size(U, N) + non_missing_mask = map( + (uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachslice(U; dims = N), t) + U = convert(AbstractArray{TU, N}, copy(selectdim(U, N, non_missing_mask))) + t = convert(AbstractVector{Tt}, t[non_missing_mask]) return U, t end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9081ff28..04a1dbd4 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -4,6 +4,7 @@ using StableRNGs using Optim, ForwardDiff using BenchmarkTools using Unitful +using JET function test_interpolation_type(T) @test T <: DataInterpolations.AbstractInterpolation @@ -920,3 +921,22 @@ f_cubic_spline = c -> square(CubicSpline, c) @test ForwardDiff.derivative(f_quadratic_spline, 4.0) ≈ 8.0 @test ForwardDiff.derivative(f_cubic_spline, 2.0) ≈ 4.0 @test ForwardDiff.derivative(f_cubic_spline, 4.0) ≈ 8.0 + +@testset "munge_data" begin + t0 = [0.1, 0.2, 0.3] + u0 = ["A", "B", "C"] + + for T in (String, Union{String, Missing}), dims in 1:3 + _u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims))) + + u, t = @inferred(DataInterpolations.munge_data(_u0, t0)) + @test u isa Array{String, dims} + @test t isa Vector{Float64} + if T === String + @test u === _u0 + @test t === t + end + + @test_call DataInterpolations.munge_data(_u0, t0) + end +end From f2a975ff25a1e9b6bf8adc5c41238b12b6d67b17 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Feb 2025 09:07:06 +0100 Subject: [PATCH 2/2] Remove JET tests --- Project.toml | 4 +--- test/interpolation_tests.jl | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 83f1cbc7..d509b76d 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,6 @@ EnumX = "1.0.4" FindFirstFunctions = "1.3" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" -JET = "0.9.17" LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" @@ -54,7 +53,6 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Optim = "429524aa-4258-5aef-a3af-852621145aeb" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" @@ -66,4 +64,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"] +test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"] diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 04a1dbd4..c14d67de 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -4,7 +4,6 @@ using StableRNGs using Optim, ForwardDiff using BenchmarkTools using Unitful -using JET function test_interpolation_type(T) @test T <: DataInterpolations.AbstractInterpolation @@ -925,6 +924,7 @@ f_cubic_spline = c -> square(CubicSpline, c) @testset "munge_data" begin t0 = [0.1, 0.2, 0.3] u0 = ["A", "B", "C"] + iszero_allocations(u, t) = iszero(@allocated(DataInterpolations.munge_data(u, t))) for T in (String, Union{String, Missing}), dims in 1:3 _u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims))) @@ -933,10 +933,9 @@ f_cubic_spline = c -> square(CubicSpline, c) @test u isa Array{String, dims} @test t isa Vector{Float64} if T === String + @test iszero_allocations(_u0, t0) @test u === _u0 @test t === t end - - @test_call DataInterpolations.munge_data(_u0, t0) end end