diff --git a/Project.toml b/Project.toml index 10afd1ec..88ce5b46 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -33,6 +34,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] +RecursiveArrayToolsStructArraysExt = "StructArrays" RecursiveArrayToolsTrackerExt = "Tracker" RecursiveArrayToolsZygoteExt = "Zygote" diff --git a/ext/RecursiveArrayToolsStructArraysExt.jl b/ext/RecursiveArrayToolsStructArraysExt.jl new file mode 100644 index 00000000..b4a5d07c --- /dev/null +++ b/ext/RecursiveArrayToolsStructArraysExt.jl @@ -0,0 +1,6 @@ +module RecursiveArrayToolsStructArraysExt + +import RecursiveArrayTools, StructArrays +RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u) + +end \ No newline at end of file diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index fc2fce0b..eef23eac 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -849,22 +849,25 @@ end @inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}) bc = Broadcast.flatten(bc) - parent = find_VoA_parent(bc.args) - if parent isa AbstractVector + u = if parent isa AbstractVector # this is the default behavior in v3.15.0 N = narrays(bc) - return VectorOfArray(map(1:N) do i + map(1:N) do i copy(unpack_voa(bc, i)) - end) + end else # if parent isa AbstractArray - return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _) + map(enumerate(Iterators.product(axes(parent)...))) do (i, _) copy(unpack_voa(bc, i)) - end) + end end + VectorOfArray(rewrap(parent, u)) end +rewrap(::Array,u) = u +rewrap(parent, u) = convert(typeof(parent), u) + for (type, N_expr) in [ (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))), (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u))) diff --git a/test/copy_static_array_test.jl b/test/copy_static_array_test.jl index ffcfa52c..d3346593 100644 --- a/test/copy_static_array_test.jl +++ b/test/copy_static_array_test.jl @@ -114,3 +114,10 @@ a_voa = VectorOfArray(a) a_voa .= 1.0 @test a_voa[1] == SVector(1.0, 1.0) @test a_voa[2] == SVector(1.0, 1.0) + +#Broadcast Copy of StructArray +x = StructArray{SVector{2, Float64}}((randn(2), randn(2))) +vx = VectorOfArray(x) +vx2 = copy(vx) .+ 1 +ans = vx .+ vx2 +@test ans.u isa StructArray \ No newline at end of file