diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index b05c646b..6a3573a9 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -658,40 +658,49 @@ Broadcast.broadcastable(x::AbstractVectorOfArray) = x end) end -@inline function Base.copyto!(dest::AbstractVectorOfArray, - bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}) - bc = Broadcast.flatten(bc) - N = narrays(bc) - @inbounds for i in 1:N - if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) - copyto!(dest[:, i], unpack_voa(bc, i)) - else - unpacked = unpack_voa(bc, i) - dest[:, i] = unpacked.f(unpacked.args...) - end - end - dest -end - -@inline function Base.copyto!(dest::AbstractVectorOfArray, - bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}) - bc = Broadcast.flatten(bc) - @inbounds for i in 1:length(dest.u) - if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) - copyto!(dest[:, i], unpack_voa(bc, i)) - else - unpacked = unpack_voa(bc, i) - value = unpacked.f(unpacked.args...) - dest[:, i] = if value isa Number && dest[:, i] isa AbstractArray - fill(value, StaticArraysCore.similar_type(dest[:, i])) +for (type, N_expr) in [ + (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))), + (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u))) + ] + @eval @inline function Base.copyto!(dest::AbstractVectorOfArray, + bc::$type) + bc = Broadcast.flatten(bc) + N = $N_expr + @inbounds for i in 1:N + if dest[:, i] isa AbstractArray + if ArrayInterface.ismutable(dest[:, i]) + copyto!(dest[:, i], unpack_voa(bc, i)) + else + unpacked = unpack_voa(bc, i) + arr_type = StaticArraysCore.similar_type(dest[:, i]) + dest[:, i] = if length(unpacked) == 1 + fill(copy(unpacked), arr_type) + else + arr_type(unpacked[j] for j in eachindex(unpacked)) + end + end else - value + dest[:, i] = copy(unpack_voa(bc, i)) end end + dest end - dest end +# @inline function Base.copyto!(dest::AbstractVectorOfArray, +# bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}) +# bc = Broadcast.flatten(bc) +# @inbounds for i in 1:length(dest.u) +# if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) +# copyto!(dest[:, i], unpack_voa(bc, i)) +# else +# unpacked = unpack_voa(bc, i) +# dest[:, i] = StaticArraysCore.similar_type(dest[:, i])(unpacked[j] for j in eachindex(unpacked)) +# end +# end +# dest +# end + ## broadcasting utils """