From 4d8130e60b500014c36419d94530e7cf66958996 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Dec 2023 13:08:14 +0530 Subject: [PATCH] feat: add ability to set VectorOfArray with Array using broadcast --- src/vector_of_array.jl | 15 ++++++++++++++- test/interface_tests.jl | 8 ++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 42c24089..ae5901fb 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -663,7 +663,20 @@ end bc = Broadcast.flatten(bc) N = narrays(bc) @inbounds for i in 1:N - if dest[:, i] isa AbstractArray && !isa(dest[:, i], StaticArraysCore.SArray) + if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) + copyto!(dest[:, i], unpack_voa(bc, i)) + else + dest[:, i] = copy(unpack_voa(bc, i)) + end + end + dest +end + +@inline function Base.copyto!(dest::AbstractVectorOfArray, + bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}) + bc = Broadcast.flatten(bc) + @inbounds for i in 1:length(dest.u) + if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) copyto!(dest[:, i], unpack_voa(bc, i)) else dest[:, i] = copy(unpack_voa(bc, i)) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 4f79c3e6..ad3fec43 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -125,3 +125,11 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) z .= x .+ y @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) + +yy = [2.0 1.0; 2.0 1.0] +zz = x .+ yy +@test zz == [4.0 2.0; 4.0 2.0] + +z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) +z .= zz +@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])