From 4044bd53a032c67ab9b01cc79d46517f5e8e91ac Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 20 Sep 2023 09:05:02 -0700 Subject: [PATCH] Attempt to fix Broadcast.broadcast_shape inference --- src/blockbroadcast.jl | 2 ++ src/tuple_tools.jl | 51 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/test_blockbroadcast.jl | 6 +++++ test/test_tuple_tools.jl | 35 +++++++++++++++++++++++++ 5 files changed, 95 insertions(+) create mode 100644 src/tuple_tools.jl create mode 100644 test/test_tuple_tools.jl diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl index 2a48ad8d..8d86228a 100644 --- a/src/blockbroadcast.jl +++ b/src/blockbroadcast.jl @@ -29,7 +29,9 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle( # sortedunion can assume inputs are already sorted so this could be improved +include("tuple_tools.jl") sortedunion(a,b) = sort!(union(a,b)) +sortedunion(a::Tuple, b::Tuple) = tuple_sort(tuple_union(a,b)) sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b))) sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b)) combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b))) diff --git a/src/tuple_tools.jl b/src/tuple_tools.jl new file mode 100644 index 00000000..e272274a --- /dev/null +++ b/src/tuple_tools.jl @@ -0,0 +1,51 @@ +##### +##### From TupleTools.jl +##### +function _split(t::Tuple) + N = length(t) + M = N >> 1 + return ntuple(i -> t[i], M), ntuple(i -> t[i + M], N - M) +end +function _merge(t1::Tuple, t2::Tuple, lt, by, rev) + if lt(by(first(t1)), by(first(t2))) != rev + return (first(t1), _merge(tail(t1), t2, lt, by, rev)...) + else + return (first(t2), _merge(t1, tail(t2), lt, by, rev)...) + end +end +_merge(::Tuple{}, t2::Tuple, lt, by, rev) = t2 +_merge(t1::Tuple, ::Tuple{}, lt, by, rev) = t1 +_merge(::Tuple{}, ::Tuple{}, lt, by, rev) = () + +tuple_sort(t::Tuple; lt=isless, by=identity, rev::Bool=false) = _tuple_sort(t, lt, by, rev) +@inline function _tuple_sort(t::Tuple, lt=isless, by=identity, rev::Bool=false) + t1, t2 = _split(t) + t1s = _tuple_sort(t1, lt, by, rev) + t2s = _tuple_sort(t2, lt, by, rev) + return _merge(t1s, t2s, lt, by, rev) +end +_tuple_sort(t::Tuple{Any}, lt=isless, by=identity, rev::Bool=false) = t +_tuple_sort(t::Tuple{}, lt=isless, by=identity, rev::Bool=false) = t + + +###### +###### tuple_union +###### + +struct DistinctElems{T<:Tuple} + elems::T +end +tuple_union(a::Tuple, b::Tuple) = distinct_elems(DistinctElems(()), a..., b...).elems + +distinct_elems(x::DistinctElems) = x + +distinct_elems(x::DistinctElems, r1) = + r1 in x.elems ? x : DistinctElems((x.elems..., r1)) + +function distinct_elems(x::DistinctElems, r1, remaining...) + return if r1 in x.elems + distinct_elems(x, remaining...) + else + distinct_elems(DistinctElems((x.elems..., r1)), remaining...) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8c7b6dc8..dedacaf9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,3 +29,4 @@ include("test_blockproduct.jl") include("test_blockreduce.jl") include("test_blockdeque.jl") include("test_blockcholesky.jl") +include("test_tuple_tools.jl") diff --git a/test/test_blockbroadcast.jl b/test/test_blockbroadcast.jl index 0d58aa3d..4d861a62 100644 --- a/test/test_blockbroadcast.jl +++ b/test/test_blockbroadcast.jl @@ -182,6 +182,12 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal u = BlockArray(randn(5), [2,3]); @inferred(copyto!(similar(u), Base.broadcasted(exp, u))) @test exp.(u) == exp.(Vector(u)) + + shape1 = (BlockArrays._BlockedUnitRange((2,)),); + shape2 = (BlockArrays._BlockedUnitRange((2,)),); + @show Base.Broadcast.broadcast_shape(shape1, shape2) + @inferred Base.Broadcast.broadcast_shape(shape1, shape2) + @code_warntype Base.Broadcast.broadcast_shape(shape1, shape2) end @testset "adjtrans" begin diff --git a/test/test_tuple_tools.jl b/test/test_tuple_tools.jl new file mode 100644 index 00000000..c6920d12 --- /dev/null +++ b/test/test_tuple_tools.jl @@ -0,0 +1,35 @@ +using BlockArrays +using Test +using Random + +@testset "Tuple Tools" begin + @testset "tuple_sort" begin + n = 10 + p = randperm(n) + t = (p...,) + @test @inferred(BlockArrays.tuple_sort((1,))) == (1,) + @test @inferred(BlockArrays.tuple_sort(())) == () + @inferred(BlockArrays.tuple_sort(t; rev=true)) == (sort(p; rev=true)...,) + @test @inferred(BlockArrays.tuple_sort(t; rev=false)) == (sort(p; rev=false)...,) + @test BlockArrays.tuple_sort((2, 1, 3.0)) === (1, 2, 3.0) + + shape1 = (BlockArrays._BlockedUnitRange((2,)),); + shape2 = (BlockArrays._BlockedUnitRange((2,)),); + bl1 = BlockArrays.blocklasts(shape1[1]) + bl2 = BlockArrays.blocklasts(shape2[1]) + # @show BlockArrays.tuple_union(bl1,bl2) + @test BlockArrays.sortedunion(bl1, bl2) == (2,) + end + + # from Base + @testset "tuple_union" begin + for S in (identity,) + s = BlockArrays.tuple_union(S((1,2)), S((3,4))) + @test s == S((1,2,3,4)) + s = BlockArrays.tuple_union(S((5,6,7,8)), S((7,8,9))) + @test s == S((5,6,7,8,9)) + s = BlockArrays.tuple_union(S((1,3,5,7)), (2,3,4,5)) + @test s == S((1,3,5,7,2,4)) + end + end +end