Skip to content

Commit

Permalink
Resolve the review comments
Browse files Browse the repository at this point in the history
1. Update Project.toml.
2. test `backend`'s inferability.

Co-Authored-By: Pietro Vertechi <[email protected]>
  • Loading branch information
N5N3 and piever committed Nov 6, 2022
1 parent e711ebe commit 8c83220
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Adapt = "1, 2, 3"
DataAPI = "1"
StaticArraysCore = "1.3"
StaticArrays = "1.5.6"
GPUArraysCore = "~0.1.2"
GPUArraysCore = "0.1.2"
Tables = "1"
julia = "1.6"

Expand Down
8 changes: 5 additions & 3 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x
# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backs = map(GPUArraysCore.backend, fieldtypes(array_types(T)))
all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!")
return backs[1]
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end

end # module
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
bcmul2(a) = 2 .* a
a = StructArray(randn(ComplexF32, 10, 10))
sa = jl(a)
backend = StructArrays.GPUArraysCore.backend
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im)
@test collect(@inferred(bcabs(sa))) == bcabs(a)
@test @inferred(bcmul2(sa)) isa StructArray
@test (sa .+= 1) isa StructArray
Expand Down

0 comments on commit 8c83220

Please sign in to comment.