Skip to content

Commit

Permalink
Merge pull request #412 from SciML/structarray_broadcast
Browse files Browse the repository at this point in the history
Fix StructArray broadcast in VectorOfArray
  • Loading branch information
ChrisRackauckas authored Nov 20, 2024
2 parents 9c8387c + dfdf6ca commit 0beb98e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -33,6 +34,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
RecursiveArrayToolsStructArraysExt = "StructArrays"
RecursiveArrayToolsTrackerExt = "Tracker"
RecursiveArrayToolsZygoteExt = "Zygote"

Expand Down
6 changes: 6 additions & 0 deletions ext/RecursiveArrayToolsStructArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module RecursiveArrayToolsStructArraysExt

import RecursiveArrayTools, StructArrays
RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u)

end
15 changes: 9 additions & 6 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,22 +849,25 @@ end

@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
bc = Broadcast.flatten(bc)

parent = find_VoA_parent(bc.args)

if parent isa AbstractVector
u = if parent isa AbstractVector
# this is the default behavior in v3.15.0
N = narrays(bc)
return VectorOfArray(map(1:N) do i
map(1:N) do i
copy(unpack_voa(bc, i))
end)
end
else # if parent isa AbstractArray
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
copy(unpack_voa(bc, i))
end)
end
end
VectorOfArray(rewrap(parent, u))
end

rewrap(::Array,u) = u
rewrap(parent, u) = convert(typeof(parent), u)

for (type, N_expr) in [
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
Expand Down
7 changes: 7 additions & 0 deletions test/copy_static_array_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,10 @@ a_voa = VectorOfArray(a)
a_voa .= 1.0
@test a_voa[1] == SVector(1.0, 1.0)
@test a_voa[2] == SVector(1.0, 1.0)

#Broadcast Copy of StructArray
x = StructArray{SVector{2, Float64}}((randn(2), randn(2)))
vx = VectorOfArray(x)
vx2 = copy(vx) .+ 1
ans = vx .+ vx2
@test ans.u isa StructArray

0 comments on commit 0beb98e

Please sign in to comment.