From 9b9d8b2471fb9908023267d4e28cd17d6b015da2 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 9 Feb 2022 22:18:52 +0800 Subject: [PATCH] Make `StructArrayStyle` track inputs dimension fix #185 --- src/structarray.jl | 10 +++++++--- test/runtests.jl | 18 +++++++++++++++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 650a6b44..28abdafe 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -445,7 +445,11 @@ end # broadcast import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle -struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end +struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end +# If `S` also track input's dimensionality, we'd better also update it. +StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N} = + StructArrayStyle{typeof(S(Val(N))),N}() +StructArrayStyle{S,M}(::Val{N}) where {M,S,N} = StructArrayStyle{S,N}() @inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = combine_style_types(BroadcastStyle(A), args...) @@ -455,9 +459,9 @@ combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) -BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}() +BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}() -Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = +Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) # for aliasing analysis during broadcast diff --git a/test/runtests.jl b/test/runtests.jl index 1fe7be57..20412dc5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -926,8 +926,24 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El # used inside of broadcast but we also test it here explicitly @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) - s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2)))) + s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2)))) @test_throws MethodError s .+ s + + # test for dimensionality track + @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} + @test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} + @test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} + @test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} + + a = StructArray([1;2+im]) + b = StructArray([1;;2+im]) + @test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) + + # issue #185 + A = StructArray(randn(ComplexF64, 3, 3)) + B = randn(ComplexF64, 3, 3) + c = StructArray(randn(ComplexF64, 3)) + @test (A .= B .* c) === A end @testset "staticarrays" begin