From 83c990ac98bef49f39107dc0310342e3522794a5 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 19 Nov 2024 21:36:38 -0100 Subject: [PATCH 1/2] Fix StructArray broadcast in VectorOfArray Fixes https://github.com/SciML/RecursiveArrayTools.jl/issues/410 This specializes so that if `u.u` is not a vector, it will convert the broadcast to fix that. I couldn't find a nice generic way to use `map` so the fallback is to build the vector and convert, which seems to not be a big performance issue. For StructArrays, `convert(typeof(x), Vector(x))` fails, and so this case is specialized. --- Project.toml | 2 ++ ext/RecursiveArrayToolsStructArraysExt.jl | 6 ++++++ src/vector_of_array.jl | 17 +++++++++++------ test/copy_static_array_test.jl | 7 +++++++ 4 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 ext/RecursiveArrayToolsStructArraysExt.jl diff --git a/Project.toml b/Project.toml index 10afd1ec..88ce5b46 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -33,6 +34,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] +RecursiveArrayToolsStructArraysExt = "StructArrays" RecursiveArrayToolsTrackerExt = "Tracker" RecursiveArrayToolsZygoteExt = "Zygote" diff --git a/ext/RecursiveArrayToolsStructArraysExt.jl b/ext/RecursiveArrayToolsStructArraysExt.jl new file mode 100644 index 00000000..b4a5d07c --- /dev/null +++ b/ext/RecursiveArrayToolsStructArraysExt.jl @@ -0,0 +1,6 @@ +module RecursiveArrayToolsStructArraysExt + +import RecursiveArrayTools, StructArrays +RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u) + +end \ No newline at end of file diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index fc2fce0b..78442274 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -849,28 +849,33 @@ end @inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}) bc = Broadcast.flatten(bc) - parent = find_VoA_parent(bc.args) - if parent isa AbstractVector + u = if parent isa AbstractVector # this is the default behavior in v3.15.0 N = narrays(bc) - return VectorOfArray(map(1:N) do i + map(1:N) do i copy(unpack_voa(bc, i)) - end) + end else # if parent isa AbstractArray - return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _) + map(enumerate(Iterators.product(axes(parent)...))) do (i, _) copy(unpack_voa(bc, i)) - end) + end end + VectorOfArray(rewrap(parent, u)) end +rewrap(::Array,u) = u +rewrap(parent, u) = convert(typeof(parent), u) + for (type, N_expr) in [ (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))), (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u))) ] @eval @inline function Base.copyto!(dest::AbstractVectorOfArray, bc::$type) + @show typeof(dest) + error() bc = Broadcast.flatten(bc) N = $N_expr @inbounds for i in 1:N diff --git a/test/copy_static_array_test.jl b/test/copy_static_array_test.jl index ffcfa52c..d3346593 100644 --- a/test/copy_static_array_test.jl +++ b/test/copy_static_array_test.jl @@ -114,3 +114,10 @@ a_voa = VectorOfArray(a) a_voa .= 1.0 @test a_voa[1] == SVector(1.0, 1.0) @test a_voa[2] == SVector(1.0, 1.0) + +#Broadcast Copy of StructArray +x = StructArray{SVector{2, Float64}}((randn(2), randn(2))) +vx = VectorOfArray(x) +vx2 = copy(vx) .+ 1 +ans = vx .+ vx2 +@test ans.u isa StructArray \ No newline at end of file From dfdf6caec6454619937a11e00d2efeab620cccce Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 20 Nov 2024 09:45:54 -0500 Subject: [PATCH 2/2] Update src/vector_of_array.jl --- src/vector_of_array.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 78442274..eef23eac 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -874,8 +874,6 @@ for (type, N_expr) in [ ] @eval @inline function Base.copyto!(dest::AbstractVectorOfArray, bc::$type) - @show typeof(dest) - error() bc = Broadcast.flatten(bc) N = $N_expr @inbounds for i in 1:N