From d5d0279597ce97df623c14bf8e93807313d4afa9 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 11:23:06 +0100 Subject: [PATCH 1/6] Fix type instability --- src/flatten.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/flatten.jl b/src/flatten.jl index 1414cb0..68b6c83 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -79,16 +79,20 @@ function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real} end function flatten(::Type{T}, x::Tuple) where {T<:Real} - x_vecs_and_backs = map(val -> flatten(T, val), x) - x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) - lengths = map(length, x_vecs) - sz = _cumsum(lengths) + vec1, back1 = flatten(T, first(x)) + vec2, back2 = flatten(T, Base.tail(x)) + l1 = length(vec1) + l2 = length(vec2) function unflatten_to_Tuple(v::Vector{T}) - map(x_backs, lengths, sz) do x_back, l, s - return x_back(v[(s - l + 1):s]) - end + return (back1(v[1:l1]), back2(v[l1+1:l1+l2])) end - return reduce(vcat, x_vecs), unflatten_to_Tuple + return vcat(vec1, vec2), unflatten_to_Tuple +end + +function flatten(::Type{T}, x::Tuple{}) where {T<:Real} + v = T[] + unflatten_to_empty_Tuple(::Vector{T}) = x + return v, unflatten_to_empty_Tuple end function flatten(::Type{T}, x::NamedTuple) where {T<:Real} From 4eb267be19fe6b2837e5f5a60a68e0301da2777f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 11:31:00 +0100 Subject: [PATCH 2/6] Fix implementation --- src/flatten.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flatten.jl b/src/flatten.jl index 68b6c83..3d66fc2 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -84,7 +84,7 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real} l1 = length(vec1) l2 = length(vec2) function unflatten_to_Tuple(v::Vector{T}) - return (back1(v[1:l1]), back2(v[l1+1:l1+l2])) + return (back1(v[1:l1]), back2(v[l1+1:l1+l2])...) end return vcat(vec1, vec2), unflatten_to_Tuple end From b58f524f2869760ea01378e42de7518776697959 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 11:33:28 +0100 Subject: [PATCH 3/6] Add regression test --- test/flatten.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/flatten.jl b/test/flatten.jl index 6b15e42..6490062 100644 --- a/test/flatten.jl +++ b/test/flatten.jl @@ -39,6 +39,13 @@ test_flatten_interface((1.0, 2.0); check_inferred=tuple_infers) test_flatten_interface((1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers) + + # Prevent regression of PR #67 + @testset "Type stability of unflatten" begin + θ = (1., ((2., 3.), 4.)) + x, unflatten = flatten(θ) + @test (@inferred unflatten(x)) == θ + end end @testset "NamedTuple" begin From bbc883ae6959ee95e95211ef0790750d350225c6 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:56:10 +0100 Subject: [PATCH 4/6] Satisfy reviewdog Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/flatten.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/flatten.jl b/test/flatten.jl index 6490062..975cc5e 100644 --- a/test/flatten.jl +++ b/test/flatten.jl @@ -42,7 +42,7 @@ # Prevent regression of PR #67 @testset "Type stability of unflatten" begin - θ = (1., ((2., 3.), 4.)) + θ = (1.0, ((2.0, 3.0), 4.0)) x, unflatten = flatten(θ) @test (@inferred unflatten(x)) == θ end From 7461979a9372741f6921a995f1af34fa3d021c89 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 6 Feb 2024 12:01:22 +0100 Subject: [PATCH 5/6] Satisfy reviewdog Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/flatten.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flatten.jl b/src/flatten.jl index 3d66fc2..08c8f23 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -84,7 +84,7 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real} l1 = length(vec1) l2 = length(vec2) function unflatten_to_Tuple(v::Vector{T}) - return (back1(v[1:l1]), back2(v[l1+1:l1+l2])...) + return (back1(v[1:l1]), back2(v[(l1 + 1):(l1 + l2)])...) end return vcat(vec1, vec2), unflatten_to_Tuple end From f5bdf67cda150d9dd59e6b849760e794dfc7e9d1 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 18:18:01 +0100 Subject: [PATCH 6/6] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 103aae3..715a381 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterHandling" uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5" authors = ["Invenia Technical Computing Corporation"] -version = "0.4.8" +version = "0.4.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"