diff --git a/src/flatten.jl b/src/flatten.jl index 1414cb0..68b6c83 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -79,16 +79,20 @@ function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real} end function flatten(::Type{T}, x::Tuple) where {T<:Real} - x_vecs_and_backs = map(val -> flatten(T, val), x) - x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) - lengths = map(length, x_vecs) - sz = _cumsum(lengths) + vec1, back1 = flatten(T, first(x)) + vec2, back2 = flatten(T, Base.tail(x)) + l1 = length(vec1) + l2 = length(vec2) function unflatten_to_Tuple(v::Vector{T}) - map(x_backs, lengths, sz) do x_back, l, s - return x_back(v[(s - l + 1):s]) - end + return (back1(v[1:l1]), back2(v[l1+1:l1+l2])) end - return reduce(vcat, x_vecs), unflatten_to_Tuple + return vcat(vec1, vec2), unflatten_to_Tuple +end + +function flatten(::Type{T}, x::Tuple{}) where {T<:Real} + v = T[] + unflatten_to_empty_Tuple(::Vector{T}) = x + return v, unflatten_to_empty_Tuple end function flatten(::Type{T}, x::NamedTuple) where {T<:Real}