diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 1fdab27d..3214cbed 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, - show, view, in, mapreduce, one, reverse, promote_op, promote_rule + show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -762,6 +762,43 @@ Base.@propagate_inbounds function view(A::AbstractFill, I::Vararg{Real}) fillsimilar(A) end +# repeat + +_first(t::Tuple) = t[1] +_first(t::Tuple{}) = 1 + +_maybetail(t::Tuple) = Base.tail(t) +_maybetail(t::Tuple{}) = t + +_match_size(sz::Tuple{}, inner::Tuple{}, outer::Tuple{}) = () +function _match_size(sz::Tuple, inner::Tuple, outer::Tuple) + t1 = (_first(sz), _first(inner), _first(outer)) + t2 = _match_size(_maybetail(sz), _maybetail(inner), _maybetail(outer)) + (t1, t2...) +end + +function _repeat_size(sz::Tuple, inner::Tuple, outer::Tuple) + t = _match_size(sz, inner, outer) + map(*, getindex.(t, 1), getindex.(t, 2), getindex.(t, 3)) +end + +function _repeat(A; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) + Base.require_one_based_indexing(A) + length(inner) >= ndims(A) || + throw(ArgumentError("number of inner repetitions $(length(inner)) cannot be "* + "less than number of dimensions of input array $(ndims(A))")) + length(outer) >= ndims(A) || + throw(ArgumentError("number of outer repetitions $(length(outer)) cannot be "* + "less than number of dimensions of input array $(ndims(A))")) + sz = _repeat_size(size(A), Tuple(inner), Tuple(outer)) + fillsimilar(A, sz) +end + +repeat(A::AbstractFill, count::Integer...) = _repeat(A, outer=count) +function repeat(A::AbstractFill; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) + _repeat(A, inner=inner, outer=outer) +end + include("oneelement.jl") end # module diff --git a/test/runtests.jl b/test/runtests.jl index 6fb33b33..4c93823d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1706,3 +1706,109 @@ end end end end + +@testset "repeat" begin + @testset "0D" begin + @test repeat(Zeros()) isa Zeros + @test repeat(Zeros()) == repeat(zeros()) + @test repeat(Ones()) isa Ones + @test repeat(Ones()) == repeat(ones()) + @test repeat(Fill(3)) isa Fill + @test repeat(Fill(3)) == repeat(fill(3)) + + @test repeat(Zeros(), inner=(), outer=()) isa Zeros + @test repeat(Zeros(), inner=(), outer=()) == repeat(zeros(), inner=(), outer=()) + @test repeat(Ones(), inner=(), outer=()) isa Ones + @test repeat(Ones(), inner=(), outer=()) == repeat(ones(), inner=(), outer=()) + @test repeat(Fill(4), inner=(), outer=()) isa Fill + @test repeat(Fill(4), inner=(), outer=()) == repeat(fill(4), inner=(), outer=()) + + @test repeat(Zeros{Bool}(), 2, 3) isa Zeros{Bool} + @test repeat(Zeros{Bool}(), 2, 3) == repeat(zeros(Bool), 2, 3) + @test repeat(Ones{Bool}(), 2, 3) isa Ones{Bool} + @test repeat(Ones{Bool}(), 2, 3) == repeat(ones(Bool), 2, 3) + @test repeat(Fill(false), 2, 3) isa Fill + @test repeat(Fill(false), 2, 3) == repeat(fill(false), 2, 3) + + @test repeat(Zeros(), inner=(2,2), outer=5) isa Zeros + @test repeat(Zeros(), inner=(2,2), outer=5) == repeat(zeros(), inner=(2,2), outer=5) + @test repeat(Ones(), inner=(2,2), outer=5) isa Ones + @test repeat(Ones(), inner=(2,2), outer=5) == repeat(ones(), inner=(2,2), outer=5) + @test repeat(Fill(2), inner=(2,2), outer=5) isa Fill + @test repeat(Fill(2), inner=(2,2), outer=5) == repeat(fill(2), inner=(2,2), outer=5) + + @test repeat(Zeros(), inner=(2,2), outer=(2,3)) isa Zeros + @test repeat(Zeros(), inner=(2,2), outer=(2,3)) == repeat(zeros(), inner=(2,2), outer=(2,3)) + @test repeat(Ones(), inner=(2,2), outer=(2,3)) isa Ones + @test repeat(Ones(), inner=(2,2), outer=(2,3)) == repeat(ones(), inner=(2,2), outer=(2,3)) + @test repeat(Fill("a"), inner=(2,2), outer=(2,3)) isa Fill + @test repeat(Fill("a"), inner=(2,2), outer=(2,3)) == repeat(fill("a"), inner=(2,2), outer=(2,3)) + end + @testset "1D" begin + @test repeat(Zeros(2), 2, 3) isa Zeros + @test repeat(Zeros(2), 2, 3) == repeat(zeros(2), 2, 3) + @test repeat(Ones(2), 2, 3) isa Ones + @test repeat(Ones(2), 2, 3) == repeat(ones(2), 2, 3) + @test repeat(Fill(2,3), 2, 3) isa Fill + @test repeat(Fill(2,3), 2, 3) == repeat(fill(2,3), 2, 3) + + @test repeat(Zeros(2), inner=2, outer=4) isa Zeros + @test repeat(Zeros(2), inner=2, outer=4) == repeat(zeros(2), inner=2, outer=4) + @test repeat(Ones(2), inner=2, outer=4) isa Ones + @test repeat(Ones(2), inner=2, outer=4) == repeat(ones(2), inner=2, outer=4) + @test repeat(Fill(2,3), inner=2, outer=4) isa Fill + @test repeat(Fill(2,3), inner=2, outer=4) == repeat(fill(2,3), inner=2, outer=4) + + @test repeat(Zeros(2), inner=2, outer=(2,3)) isa Zeros + @test repeat(Zeros(2), inner=2, outer=(2,3)) == repeat(zeros(2), inner=2, outer=(2,3)) + @test repeat(Ones(2), inner=2, outer=(2,3)) isa Ones + @test repeat(Ones(2), inner=2, outer=(2,3)) == repeat(ones(2), inner=2, outer=(2,3)) + @test repeat(Fill("b",3), inner=2, outer=(2,3)) isa Fill + @test repeat(Fill("b",3), inner=2, outer=(2,3)) == repeat(fill("b",3), inner=2, outer=(2,3)) + + @test repeat(Zeros(Int, 2), inner=(2,), outer=(2,3)) isa Zeros + @test repeat(Zeros(Int, 2), inner=(2,), outer=(2,3)) == repeat(zeros(Int, 2), inner=(2,), outer=(2,3)) + @test repeat(Ones(Int, 2), inner=(2,), outer=(2,3)) isa Ones + @test repeat(Ones(Int, 2), inner=(2,), outer=(2,3)) == repeat(ones(Int, 2), inner=(2,), outer=(2,3)) + @test repeat(Fill(2,3), inner=(2,), outer=(2,3)) isa Fill + @test repeat(Fill(2,3), inner=(2,), outer=(2,3)) == repeat(fill(2,3), inner=(2,), outer=(2,3)) + + @test repeat(Zeros(2), inner=(2,2,1,4), outer=(2,3)) isa Zeros + @test repeat(Zeros(2), inner=(2,2,1,4), outer=(2,3)) == repeat(zeros(2), inner=(2,2,1,4), outer=(2,3)) + @test repeat(Ones(2), inner=(2,2,1,4), outer=(2,3)) isa Ones + @test repeat(Ones(2), inner=(2,2,1,4), outer=(2,3)) == repeat(ones(2), inner=(2,2,1,4), outer=(2,3)) + @test repeat(Fill(2,3), inner=(2,2,1,4), outer=(2,3)) isa Fill + @test repeat(Fill(2,3), inner=(2,2,1,4), outer=(2,3)) == repeat(fill(2,3), inner=(2,2,1,4), outer=(2,3)) + + @test_throws ArgumentError repeat(Fill(2,3), inner=()) + @test_throws ArgumentError repeat(Fill(2,3), outer=()) + end + + @testset "2D" begin + @test repeat(Zeros(2,3), 2, 3) isa Zeros + @test repeat(Zeros(2,3), 2, 3) == repeat(zeros(2,3), 2, 3) + @test repeat(Ones(2,3), 2, 3) isa Ones + @test repeat(Ones(2,3), 2, 3) == repeat(ones(2,3), 2, 3) + @test repeat(Fill(2,3,4), 2, 3) isa Fill + @test repeat(Fill(2,3,4), 2, 3) == repeat(fill(2,3,4), 2, 3) + + @test repeat(Zeros(2,3), inner=(1,2), outer=(4,2)) isa Zeros + @test repeat(Zeros(2,3), inner=(1,2), outer=(4,2)) == repeat(zeros(2,3), inner=(1,2), outer=(4,2)) + @test repeat(Ones(2,3), inner=(1,2), outer=(4,2)) isa Ones + @test repeat(Ones(2,3), inner=(1,2), outer=(4,2)) == repeat(ones(2,3), inner=(1,2), outer=(4,2)) + @test repeat(Fill(2,3,4), inner=(1,2), outer=(4,2)) isa Fill + @test repeat(Fill(2,3,4), inner=(1,2), outer=(4,2)) == repeat(fill(2,3,4), inner=(1,2), outer=(4,2)) + + @test repeat(Zeros(2,3), inner=(2,2,1,4), outer=(2,1,3)) isa Zeros + @test repeat(Zeros(2,3), inner=(2,2,1,4), outer=(2,1,3)) == repeat(zeros(2,3), inner=(2,2,1,4), outer=(2,1,3)) + @test repeat(Ones(2,3), inner=(2,2,1,4), outer=(2,1,3)) isa Ones + @test repeat(Ones(2,3), inner=(2,2,1,4), outer=(2,1,3)) == repeat(ones(2,3), inner=(2,2,1,4), outer=(2,1,3)) + @test repeat(Fill(2,3,4), inner=(2,2,1,4), outer=(2,1,3)) isa Fill + @test repeat(Fill(2,3,4), inner=(2,2,1,4), outer=(2,1,3)) == repeat(fill(2,3,4), inner=(2,2,1,4), outer=(2,1,3)) + + @test_throws ArgumentError repeat(Fill(2,3,4), inner=()) + @test_throws ArgumentError repeat(Fill(2,3,4), outer=()) + @test_throws ArgumentError repeat(Fill(2,3,4), inner=(1,)) + @test_throws ArgumentError repeat(Fill(2,3,4), outer=(1,)) + end +end