From f5bcb43c8547770c700efafc66f3135fe7e3dc37 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 19 Oct 2023 17:34:49 +0530 Subject: [PATCH 1/4] Fix corner cases and short-circuit mapreduce --- Project.toml | 2 +- src/FillArrays.jl | 10 ---------- src/fillbroadcast.jl | 41 ++++++++++++++++++++++++++++++++--------- test/runtests.jl | 29 ++++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 13264cc3..f4ee73e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.7.0" +version = "1.7.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 7ca7b78e..23aa6279 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -550,16 +550,6 @@ for SMT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal) end end - -######### -# maximum/minimum -######### - -for op in (:maximum, :minimum) - @eval $op(x::AbstractFill) = getindex_value(x) -end - - ######### # Cumsum ######### diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 6cf11284..1a0f8674 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -33,22 +33,45 @@ end ### mapreduce function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) - fval = f(getindex_value(A)) - out = fval - for _ in 2:length(A) - out = op(out, fval) + if length(A) == 0 + return Base.mapreduce_empty_iter(f, op, A, Base.HasEltype()) + end + val = getindex_value(A) + fval = f(val) + out = Base.mapreduce_first(f, op, val) + if op(out, fval) != out + for _ in 2:length(A) + out = op(out, fval) + end end out end +identityel(f, ::Union{typeof(+), typeof(Base.add_sum)}, A) = zero(f(zero(eltype(A)))) +identityel(f, ::Union{typeof(*), typeof(Base.mul_prod)}, A) = one(f(one(eltype(A)))) +identityel(f, ::typeof(&), A) = true +identityel(f, ::typeof(|), A) = false +identityel(f, ::Any, @nospecialize(A)) = throw(ArgumentError("reducing over an empty collection is not allowed")) +function mapreducedim_empty(f, op, A) + z = identityel(f, op, A) + op(z, z) +end + function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) - fval = f(getindex_value(A)) red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) - out = fval - for _ in 2:red - out = op(out, fval) + if red == 0 + out = mapreducedim_empty(f, op, A) + else + val = getindex_value(A) + out = Base.mapreduce_first(f, op, val) + fval = f(val) + if op(out, fval) != out + for _ in 2:red + out = op(out, fval) + end + end end - Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) + Fill(out, ntuple(d -> d in dims ? axes(A,ndims(A)+1) : axes(A,d), ndims(A))) end function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 50d42c54..97dcc936 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1090,7 +1090,7 @@ end @test Zeros(S, 10) .* (T(1):T(10)) ≡ Zeros(U, 10) @test_throws DimensionMismatch Zeros(S, 10) .* (T(1):T(11)) end - end + end end end @@ -1115,6 +1115,33 @@ end end @testset "mapreduce" begin + @testset "corner cases with small arrays" begin + @test_throws Exception mapreduce(identity, max, Fill(2,0)) + @test_throws Exception mapreduce(identity, max, Fill(2,0), dims=1) + @testset for op in (+, *) + @test mapreduce(identity, op, Fill(2,0)) == mapreduce(identity, op, fill(2,0)) + @test mapreduce(identity, op, Fill(2,0), dims=1) == mapreduce(identity, op, fill(2,0), dims=1) + end + @testset for op in (&, |) + @test mapreduce(identity, op, Fill(true,0)) == mapreduce(identity, op, fill(true,0)) + @test mapreduce(identity, op, Fill(true,0), dims=1) == mapreduce(identity, op, fill(true,0), dims=1) + end + @testset for op in (max, +, *) + @test mapreduce(identity, op, Fill(2,0), dims=2) == mapreduce(identity, op, fill(2,0), dims=2) + @test mapreduce(identity, op, Fill(2,0,1), dims=2) == mapreduce(identity, op, fill(2,0,1), dims=2) + @test mapreduce(identity, op, Fill(2,1)) == mapreduce(identity, op, fill(2,1)) + @test mapreduce(identity, op, Fill(2,1), dims=1) == mapreduce(identity, op, fill(2,1), dims=1) + @test mapreduce(identity, op, Fill(2,1), dims=2) == mapreduce(identity, op, fill(2,1), dims=2) + @test mapreduce(identity, op, Fill(2,1,1), dims=1) == mapreduce(identity, op, fill(2,1,1), dims=1) + @test mapreduce(identity, op, Fill(2,1,1), dims=2) == mapreduce(identity, op, fill(2,1,1), dims=2) + @test mapreduce(identity, op, Fill(2,2)) == mapreduce(identity, op, fill(2,2)) + @test mapreduce(identity, op, Fill(2,2), dims=1) == mapreduce(identity, op, fill(2,2), dims=1) + @test mapreduce(identity, op, Fill(2,2), dims=2) == mapreduce(identity, op, fill(2,2), dims=2) + @test mapreduce(identity, op, Fill(2,2,2), dims=1) == mapreduce(identity, op, fill(2,2,2), dims=1) + @test mapreduce(identity, op, Fill(2,2,2), dims=2) == mapreduce(identity, op, fill(2,2,2), dims=2) + end + end + x = rand(3, 4) y = fill(1.0, 3, 4) Y = Fill(1.0, 3, 4) From 715476545a1e4c545b5823caa8714836d1bb8bf2 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 19 Oct 2023 20:50:13 +0530 Subject: [PATCH 2/4] reduce with init --- src/fillbroadcast.jl | 40 +++++++++++++++++++++++++++-- test/runtests.jl | 60 ++++++++++++++++++++++++++++---------------- 2 files changed, 77 insertions(+), 23 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 1a0f8674..36c4ac3d 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -37,8 +37,23 @@ function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Col return Base.mapreduce_empty_iter(f, op, A, Base.HasEltype()) end val = getindex_value(A) - fval = f(val) out = Base.mapreduce_first(f, op, val) + fval = f(val) + if op(out, fval) != out + for _ in 2:length(A) + out = op(out, fval) + end + end + out +end + +function Base._mapreduce_dim(f, op, init, A::AbstractFill, ::Colon) + if length(A) == 0 + return init + end + val = getindex_value(A) + fval = f(val) + out = op(init, fval) if op(out, fval) != out for _ in 2:length(A) out = op(out, fval) @@ -57,6 +72,10 @@ function mapreducedim_empty(f, op, A) op(z, z) end +function reduced_indices(A, dims) + ntuple(d -> d in dims ? axes(A,ndims(A)+1) : axes(A,d), ndims(A)) +end + function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) if red == 0 @@ -71,7 +90,24 @@ function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) end end end - Fill(out, ntuple(d -> d in dims ? axes(A,ndims(A)+1) : axes(A,d), ndims(A))) + Fill(out, reduced_indices(A, dims)) +end + +function Base._mapreduce_dim(f, op, init, A::AbstractFill, dims) + red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) + if red == 0 + out = init + else + val = getindex_value(A) + fval = f(val) + out = op(init, fval) + if op(out, fval) != out + for _ in 2:red + out = op(out, fval) + end + end + end + Fill(out, reduced_indices(A, dims)) end function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 97dcc936..eb4c52a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1116,29 +1116,47 @@ end @testset "mapreduce" begin @testset "corner cases with small arrays" begin - @test_throws Exception mapreduce(identity, max, Fill(2,0)) - @test_throws Exception mapreduce(identity, max, Fill(2,0), dims=1) - @testset for op in (+, *) - @test mapreduce(identity, op, Fill(2,0)) == mapreduce(identity, op, fill(2,0)) - @test mapreduce(identity, op, Fill(2,0), dims=1) == mapreduce(identity, op, fill(2,0), dims=1) + @test_throws Exception reduce(max, Fill(2,0)) + @test_throws Exception reduce(max, Fill(2,0), dims=1) + @testset for (op, init) in ((+, 0), (*, 1)) + @test reduce(op, Fill(2,0)) == reduce(op, fill(2,0)) + @test reduce(op, Fill(2,0); init) == reduce(op, fill(2,0); init) + @test reduce(op, Fill(2,0), dims=1) == reduce(op, fill(2,0), dims=1) + @test reduce(op, Fill(2,0); init, dims=1) == reduce(op, fill(2,0); init, dims=1) end - @testset for op in (&, |) - @test mapreduce(identity, op, Fill(true,0)) == mapreduce(identity, op, fill(true,0)) - @test mapreduce(identity, op, Fill(true,0), dims=1) == mapreduce(identity, op, fill(true,0), dims=1) + @testset for (op, init) in ((&, true), (|, false)) + @test reduce(op, Fill(true,0)) == reduce(op, fill(true,0)) + @test reduce(op, Fill(true,0); init) == reduce(op, fill(true,0); init) + @test reduce(op, Fill(true,0), dims=1) == reduce(op, fill(true,0), dims=1) + @test reduce(op, Fill(true,0); init, dims=1) == reduce(op, fill(true,0); init, dims=1) end - @testset for op in (max, +, *) - @test mapreduce(identity, op, Fill(2,0), dims=2) == mapreduce(identity, op, fill(2,0), dims=2) - @test mapreduce(identity, op, Fill(2,0,1), dims=2) == mapreduce(identity, op, fill(2,0,1), dims=2) - @test mapreduce(identity, op, Fill(2,1)) == mapreduce(identity, op, fill(2,1)) - @test mapreduce(identity, op, Fill(2,1), dims=1) == mapreduce(identity, op, fill(2,1), dims=1) - @test mapreduce(identity, op, Fill(2,1), dims=2) == mapreduce(identity, op, fill(2,1), dims=2) - @test mapreduce(identity, op, Fill(2,1,1), dims=1) == mapreduce(identity, op, fill(2,1,1), dims=1) - @test mapreduce(identity, op, Fill(2,1,1), dims=2) == mapreduce(identity, op, fill(2,1,1), dims=2) - @test mapreduce(identity, op, Fill(2,2)) == mapreduce(identity, op, fill(2,2)) - @test mapreduce(identity, op, Fill(2,2), dims=1) == mapreduce(identity, op, fill(2,2), dims=1) - @test mapreduce(identity, op, Fill(2,2), dims=2) == mapreduce(identity, op, fill(2,2), dims=2) - @test mapreduce(identity, op, Fill(2,2,2), dims=1) == mapreduce(identity, op, fill(2,2,2), dims=1) - @test mapreduce(identity, op, Fill(2,2,2), dims=2) == mapreduce(identity, op, fill(2,2,2), dims=2) + @test reduce(vcat, Fill(2,0), init=1) == reduce(vcat, fill(2,0), init=1) + @test reduce(vcat, Fill(2,1), init=1) == reduce(vcat, fill(2,1), init=1) + @testset for (op, init) in ((max, 0), (+, 0), (*, 1)) + @test reduce(op, Fill(2,0), dims=2) == reduce(op, fill(2,0), dims=2) + @test reduce(op, Fill(2,0,1), dims=2) == reduce(op, fill(2,0,1), dims=2) + @test reduce(op, Fill(2,0); init, dims=2) == reduce(op, fill(2,0); init, dims=2) + @test reduce(op, Fill(2,0,1); init, dims=2) == reduce(op, fill(2,0,1); init, dims=2) + + @test reduce(op, Fill(2,1)) == reduce(op, fill(2,1)) + @test reduce(op, Fill(2,1), dims=1) == reduce(op, fill(2,1), dims=1) + @test reduce(op, Fill(2,1), dims=2) == reduce(op, fill(2,1), dims=2) + @test reduce(op, Fill(2,1,1), dims=1) == reduce(op, fill(2,1,1), dims=1) + @test reduce(op, Fill(2,1,1), dims=2) == reduce(op, fill(2,1,1), dims=2) + @test reduce(op, Fill(2,1); init, dims=1) == reduce(op, fill(2,1); init, dims=1) + @test reduce(op, Fill(2,1); init, dims=2) == reduce(op, fill(2,1); init, dims=2) + @test reduce(op, Fill(2,1,1); init, dims=1) == reduce(op, fill(2,1,1); init, dims=1) + @test reduce(op, Fill(2,1,1); init, dims=2) == reduce(op, fill(2,1,1); init, dims=2) + + @test reduce(op, Fill(2,2)) == reduce(op, fill(2,2)) + @test reduce(op, Fill(2,2), dims=1) == reduce(op, fill(2,2), dims=1) + @test reduce(op, Fill(2,2), dims=2) == reduce(op, fill(2,2), dims=2) + @test reduce(op, Fill(2,2,2), dims=1) == reduce(op, fill(2,2,2), dims=1) + @test reduce(op, Fill(2,2,2), dims=2) == reduce(op, fill(2,2,2), dims=2) + @test reduce(op, Fill(2,2); init, dims=1) == reduce(op, fill(2,2); init, dims=1) + @test reduce(op, Fill(2,2); init, dims=2) == reduce(op, fill(2,2); init, dims=2) + @test reduce(op, Fill(2,2,2); init, dims=1) == reduce(op, fill(2,2,2); init, dims=1) + @test reduce(op, Fill(2,2,2); init, dims=2) == reduce(op, fill(2,2,2); init, dims=2) end end From 6480b932f2936fe75801122e7c0ef21882863e26 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 20 Oct 2023 00:26:58 +0530 Subject: [PATCH 3/4] Test function --- test/runtests.jl | 72 +++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index eb4c52a3..208b8772 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1118,45 +1118,49 @@ end @testset "corner cases with small arrays" begin @test_throws Exception reduce(max, Fill(2,0)) @test_throws Exception reduce(max, Fill(2,0), dims=1) + function testreduce(op, A; kw...) + B = Array(A) + @test reduce(op, A; kw...) == reduce(op, B; kw...) + end @testset for (op, init) in ((+, 0), (*, 1)) - @test reduce(op, Fill(2,0)) == reduce(op, fill(2,0)) - @test reduce(op, Fill(2,0); init) == reduce(op, fill(2,0); init) - @test reduce(op, Fill(2,0), dims=1) == reduce(op, fill(2,0), dims=1) - @test reduce(op, Fill(2,0); init, dims=1) == reduce(op, fill(2,0); init, dims=1) + testreduce(op, Fill(2,0)) + testreduce(op, Fill(2,0); init) + testreduce(op, Fill(2,0), dims=1) + testreduce(op, Fill(2,0); init, dims=1) end @testset for (op, init) in ((&, true), (|, false)) - @test reduce(op, Fill(true,0)) == reduce(op, fill(true,0)) - @test reduce(op, Fill(true,0); init) == reduce(op, fill(true,0); init) - @test reduce(op, Fill(true,0), dims=1) == reduce(op, fill(true,0), dims=1) - @test reduce(op, Fill(true,0); init, dims=1) == reduce(op, fill(true,0); init, dims=1) + testreduce(op, Fill(true,0)) + testreduce(op, Fill(true,0); init) + testreduce(op, Fill(true,0), dims=1) + testreduce(op, Fill(true,0); init, dims=1) end - @test reduce(vcat, Fill(2,0), init=1) == reduce(vcat, fill(2,0), init=1) - @test reduce(vcat, Fill(2,1), init=1) == reduce(vcat, fill(2,1), init=1) + testreduce(vcat, Fill(2,0), init=Int[]) + testreduce(vcat, Fill(2,1), init=Int[]) @testset for (op, init) in ((max, 0), (+, 0), (*, 1)) - @test reduce(op, Fill(2,0), dims=2) == reduce(op, fill(2,0), dims=2) - @test reduce(op, Fill(2,0,1), dims=2) == reduce(op, fill(2,0,1), dims=2) - @test reduce(op, Fill(2,0); init, dims=2) == reduce(op, fill(2,0); init, dims=2) - @test reduce(op, Fill(2,0,1); init, dims=2) == reduce(op, fill(2,0,1); init, dims=2) - - @test reduce(op, Fill(2,1)) == reduce(op, fill(2,1)) - @test reduce(op, Fill(2,1), dims=1) == reduce(op, fill(2,1), dims=1) - @test reduce(op, Fill(2,1), dims=2) == reduce(op, fill(2,1), dims=2) - @test reduce(op, Fill(2,1,1), dims=1) == reduce(op, fill(2,1,1), dims=1) - @test reduce(op, Fill(2,1,1), dims=2) == reduce(op, fill(2,1,1), dims=2) - @test reduce(op, Fill(2,1); init, dims=1) == reduce(op, fill(2,1); init, dims=1) - @test reduce(op, Fill(2,1); init, dims=2) == reduce(op, fill(2,1); init, dims=2) - @test reduce(op, Fill(2,1,1); init, dims=1) == reduce(op, fill(2,1,1); init, dims=1) - @test reduce(op, Fill(2,1,1); init, dims=2) == reduce(op, fill(2,1,1); init, dims=2) - - @test reduce(op, Fill(2,2)) == reduce(op, fill(2,2)) - @test reduce(op, Fill(2,2), dims=1) == reduce(op, fill(2,2), dims=1) - @test reduce(op, Fill(2,2), dims=2) == reduce(op, fill(2,2), dims=2) - @test reduce(op, Fill(2,2,2), dims=1) == reduce(op, fill(2,2,2), dims=1) - @test reduce(op, Fill(2,2,2), dims=2) == reduce(op, fill(2,2,2), dims=2) - @test reduce(op, Fill(2,2); init, dims=1) == reduce(op, fill(2,2); init, dims=1) - @test reduce(op, Fill(2,2); init, dims=2) == reduce(op, fill(2,2); init, dims=2) - @test reduce(op, Fill(2,2,2); init, dims=1) == reduce(op, fill(2,2,2); init, dims=1) - @test reduce(op, Fill(2,2,2); init, dims=2) == reduce(op, fill(2,2,2); init, dims=2) + testreduce(op, Fill(2,0), dims=2) + testreduce(op, Fill(2,0,1), dims=2) + testreduce(op, Fill(2,0); init, dims=2) + testreduce(op, Fill(2,0,1); init, dims=2) + + testreduce(op, Fill(2,1)) + testreduce(op, Fill(2,1), dims=1) + testreduce(op, Fill(2,1), dims=2) + testreduce(op, Fill(2,1,1), dims=1) + testreduce(op, Fill(2,1,1), dims=2) + testreduce(op, Fill(2,1); init, dims=1) + testreduce(op, Fill(2,1); init, dims=2) + testreduce(op, Fill(2,1,1); init, dims=1) + testreduce(op, Fill(2,1,1); init, dims=2) + + testreduce(op, Fill(2,2)) + testreduce(op, Fill(2,2), dims=1) + testreduce(op, Fill(2,2), dims=2) + testreduce(op, Fill(2,2,2), dims=1) + testreduce(op, Fill(2,2,2), dims=2) + testreduce(op, Fill(2,2); init, dims=1) + testreduce(op, Fill(2,2); init, dims=2) + testreduce(op, Fill(2,2,2); init, dims=1) + testreduce(op, Fill(2,2,2); init, dims=2) end end From 831933f3cdc8778a06d67ba3ff87ff42cc42e439 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 20 Oct 2023 00:53:08 +0530 Subject: [PATCH 4/4] Tests with non-trivial mapping function --- test/runtests.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 208b8772..2bacc8c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1120,7 +1120,14 @@ end @test_throws Exception reduce(max, Fill(2,0), dims=1) function testreduce(op, A; kw...) B = Array(A) - @test reduce(op, A; kw...) == reduce(op, B; kw...) + F = reduce(op, A; kw...) + @test F == reduce(op, B; kw...) + if haskey(kw, :dims) + @test F isa Fill + end + if !isempty(A) + @test mapreduce(x->x^2, op, A; kw...) == mapreduce(x->x^2, op, B; kw...) + end end @testset for (op, init) in ((+, 0), (*, 1)) testreduce(op, Fill(2,0))