Skip to content

Commit

Permalink
Use StaticArrayStyle instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Mar 20, 2022
1 parent a92bc41 commit f8c988e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
26 changes: 21 additions & 5 deletions src/staticarrays_support.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle
using StaticArrays: StaticArrays, StaticArray, FieldArray, tuple_prod, StaticArrayStyle
import StaticArrays: Size
import Base.Broadcast: instantiate

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
Expand Down Expand Up @@ -28,10 +30,24 @@ end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

@static if isdefined(StaticArrays, :static_combine_axes)
# StaticArrayStyle has no similar defined.
# Convert to `DefaultArrayStyle` to return a sized (Struct)Array.
# TODO: return a StaticArray?
function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N}
bc′ = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc)
# Convert to `StaticArrayStyle` to return a StaticArray instead.
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
bc′ = convert(Broadcasted{StaticArrayStyle{M}}, bc)
return copy(bc′)
end
function instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
bc′ = instantiate(convert(Broadcasted{StaticArrayStyle{M}}, bc))
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
end
function Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing)
return StaticArrays.static_combine_axes(bc.args...)
end
Size(::Type{SA}) where {SA<:StructArray} = Size(fieldtype(fieldtype(SA, 1), 1))
StaticArrays.isstatic(::SA) where {SA<:StructArray} = cst(SA) isa StaticArrayStyle
function StaticArrays.similar_type(::Type{SA}, ::Type{T}, s::Size{S}) where {SA<:StructArray, T, S}
return StaticArrays.similar_type(fieldtype(fieldtype(SA, 1), 1), T, s)
end
end
25 changes: 17 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -977,26 +977,35 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
@test @inferred(broadcast(el -> el.a, v)) == ["s1", "s2"]

# ambiguity check (can we do this better?)
function _test(a, b, c)
function _test(a, b, c, T = StructArray)
if a isa StructArray || b isa StructArray || c isa StructArray
d = @inferred a .+ b .- c
@test d == collect(a) .+ collect(b) .- collect(c)
@test d isa StructArray
@test d isa T
end
end
testset = (StructArray([1;2+im]),
testset = Any[StructArray([1;2+im]),
StructArray([1 2+im]),
1:2,
(1,2),
(@SArray [1 2]),
StructArray(@SArray [1 1+2im]))
(@SArray [1 2])]
for aa in testset, bb in testset, cc in testset
_test(aa, bb, cc)
end
if isdefined(StaticArrays, :static_combine_axes)
testset = Any[StructArray(@SArray [1 1+2im]), (1,2), StructArray(@SArray [1;1+2im])]
for aa in testset, bb in testset, cc in testset
_test(aa, bb, cc, StaticArray)
end
end
end

a = @SArray randn(3,3);
b = StructArray{ComplexF64}((a,a))
@test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix
function struct_static_allocated_test()
s = StructArray{ComplexF64}((SVector(1., 2., 3.), SVector(0., 0., 0.)))
return broadcast(log, s)
end
if isdefined(StaticArrays, :static_combine_axes)
@test (@allocated struct_static_allocated_test()) === 0
end

@testset "staticarrays" begin
Expand Down

0 comments on commit f8c988e

Please sign in to comment.