Skip to content

Commit

Permalink
Merge pull request #311 from AayushSabharwal/as/sarray-broadcast
Browse files Browse the repository at this point in the history
fix: rework broadcasting copyto!
  • Loading branch information
ChrisRackauckas authored Dec 22, 2023
2 parents f654568 + dc573f0 commit 71fa2db
Showing 1 changed file with 37 additions and 28 deletions.
65 changes: 37 additions & 28 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -658,40 +658,49 @@ Broadcast.broadcastable(x::AbstractVectorOfArray) = x
end)
end

@inline function Base.copyto!(dest::AbstractVectorOfArray,
bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
bc = Broadcast.flatten(bc)
N = narrays(bc)
@inbounds for i in 1:N
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
copyto!(dest[:, i], unpack_voa(bc, i))
else
unpacked = unpack_voa(bc, i)
dest[:, i] = unpacked.f(unpacked.args...)
end
end
dest
end

@inline function Base.copyto!(dest::AbstractVectorOfArray,
bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
bc = Broadcast.flatten(bc)
@inbounds for i in 1:length(dest.u)
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
copyto!(dest[:, i], unpack_voa(bc, i))
else
unpacked = unpack_voa(bc, i)
value = unpacked.f(unpacked.args...)
dest[:, i] = if value isa Number && dest[:, i] isa AbstractArray
fill(value, StaticArraysCore.similar_type(dest[:, i]))
for (type, N_expr) in [
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
]
@eval @inline function Base.copyto!(dest::AbstractVectorOfArray,
bc::$type)
bc = Broadcast.flatten(bc)
N = $N_expr
@inbounds for i in 1:N
if dest[:, i] isa AbstractArray
if ArrayInterface.ismutable(dest[:, i])
copyto!(dest[:, i], unpack_voa(bc, i))
else
unpacked = unpack_voa(bc, i)
arr_type = StaticArraysCore.similar_type(dest[:, i])
dest[:, i] = if length(unpacked) == 1
fill(copy(unpacked), arr_type)
else
arr_type(unpacked[j] for j in eachindex(unpacked))
end
end
else
value
dest[:, i] = copy(unpack_voa(bc, i))
end
end
dest
end
dest
end

# @inline function Base.copyto!(dest::AbstractVectorOfArray,
# bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
# bc = Broadcast.flatten(bc)
# @inbounds for i in 1:length(dest.u)
# if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
# copyto!(dest[:, i], unpack_voa(bc, i))
# else
# unpacked = unpack_voa(bc, i)
# dest[:, i] = StaticArraysCore.similar_type(dest[:, i])(unpacked[j] for j in eachindex(unpacked))
# end
# end
# dest
# end

## broadcasting utils

"""
Expand Down

0 comments on commit 71fa2db

Please sign in to comment.