Skip to content

Commit

Permalink
Merge pull request #406 from huiyuxie/fix
Browse files Browse the repository at this point in the history
Fix broadcast failure for `VectorOfArray` with `SVector{1}`
  • Loading branch information
ChrisRackauckas authored Oct 29, 2024
2 parents 6947538 + baa2af9 commit 3819dee
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,9 @@ for (type, N_expr) in [
else
unpacked = unpack_voa(bc, i)
arr_type = StaticArraysCore.similar_type(dest[:, i])
dest[:, i] = if length(unpacked) == 1
dest[:, i] = if length(unpacked) == 1 && length(dest[:, i]) == 1
arr_type(unpacked[1])
elseif length(unpacked) == 1
fill(copy(unpacked), arr_type)
else
arr_type(unpacked[j] for j in eachindex(unpacked))
Expand Down
32 changes: 32 additions & 0 deletions test/copy_static_array_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,35 @@ b = recursivecopy(a)
@test a[1] == b[1]
a[1] *= 2
@test a[1] != b[1]

# Broadcasting when SVector{N} where N = 1
a = [SVector(0.0) for _ in 1:2]
a_voa = VectorOfArray(a)
b_voa = copy(a_voa)
a_voa[1] = SVector(1.0)
a_voa[2] = SVector(1.0)
@. b_voa = a_voa
@test b_voa[1] == a_voa[1]
@test b_voa[2] == a_voa[2]

a = [SVector(0.0) for _ in 1:2]
a_voa = VectorOfArray(a)
a_voa .= 1.0
@test a_voa[1] == SVector(1.0)
@test a_voa[2] == SVector(1.0)

# Broadcasting when SVector{N} where N > 1
a = [SVector(0.0, 0.0) for _ in 1:2]
a_voa = VectorOfArray(a)
b_voa = copy(a_voa)
a_voa[1] = SVector(1.0, 1.0)
a_voa[2] = SVector(1.0, 1.0)
@. b_voa = a_voa
@test b_voa[1] == a_voa[1]
@test b_voa[2] == a_voa[2]

a = [SVector(0.0, 0.0) for _ in 1:2]
a_voa = VectorOfArray(a)
a_voa .= 1.0
@test a_voa[1] == SVector(1.0, 1.0)
@test a_voa[2] == SVector(1.0, 1.0)

0 comments on commit 3819dee

Please sign in to comment.