Skip to content

Commit

Permalink
Fix type instability
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace committed Feb 6, 2024
1 parent 567b070 commit d5d0279
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,20 @@ function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real}
end

function flatten(::Type{T}, x::Tuple) where {T<:Real}
x_vecs_and_backs = map(val -> flatten(T, val), x)
x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
lengths = map(length, x_vecs)
sz = _cumsum(lengths)
vec1, back1 = flatten(T, first(x))
vec2, back2 = flatten(T, Base.tail(x))
l1 = length(vec1)
l2 = length(vec2)
function unflatten_to_Tuple(v::Vector{T})
map(x_backs, lengths, sz) do x_back, l, s
return x_back(v[(s - l + 1):s])
end
return (back1(v[1:l1]), back2(v[l1+1:l1+l2]))
end
return reduce(vcat, x_vecs), unflatten_to_Tuple
return vcat(vec1, vec2), unflatten_to_Tuple
end

function flatten(::Type{T}, x::Tuple{}) where {T<:Real}
v = T[]
unflatten_to_empty_Tuple(::Vector{T}) = x
return v, unflatten_to_empty_Tuple
end

function flatten(::Type{T}, x::NamedTuple) where {T<:Real}
Expand Down

0 comments on commit d5d0279

Please sign in to comment.