diff --git a/src/utils.jl b/src/utils.jl index d9c1585a..658e8418 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -174,16 +174,18 @@ vecvecapply(f::Base.Callable, v) Calls `f` on each element of a vecvec `v`. """ -function vecvecapply(f, v) +function vecvecapply(f, v::AbstractArray{<:AbstractArray}) sol = Vector{eltype(eltype(v))}() for i in eachindex(v) - for j in eachindex(v[:, i]) - push!(sol, v[:, i][j]) + for j in eachindex(v[i]) + push!(sol, v[i][j]) end end f(sol) end +vecvecapply(f, v::AbstractVectorOfArray) = vecvecapply(f, v.u) + function vecvecapply(f, v::Array{T}) where {T <: Number} f(v) end diff --git a/test/utils_test.jl b/test/utils_test.jl index 52f24bd6..8bda7442 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -14,6 +14,13 @@ A = [[1 2; 3 4], [1 3; 4 6], [5 6; 7 8]] A = zeros(5, 5) @test recursive_unitless_eltype(A) == Float64 +@test vecvecapply(x -> abs.(x), -1) == 1 +@test vecvecapply(x -> abs.(x), [-1, -2, 3, -4]) == [1, 2, 3, 4] +v = [[-1 2; 3 -4], [5 -6; -7 -8]] +vv = [1, 3, 2, 4, 5, 7, 6, 8] +@test vecvecapply(x -> abs.(x), v) == vv +@test vecvecapply(x -> abs.(x), VectorOfArray(v)) == vv + using Unitful A = zeros(5, 5) * 1u"kg" @test recursive_unitless_eltype(A) == Float64