diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 241ac3fc..adfb5be1 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -31,35 +31,25 @@ reverse(A::AbstractFill; dims=:) = A ## Algebraic identities - -function mult_fill(a::AbstractFill, b::AbstractFillMatrix) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1), axes(b, 2))) -end - -function mult_fill(a::AbstractFill, b::AbstractFillVector) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1),)) +function _mult_fill(a::AbstractFill, b::AbstractFill, ax, ::Type{Fill}) + val = getindex_value(a)*getindex_value(b)*size(a,2) + return Fill(val, ax) end -function mult_ones(a::AbstractVector, b::AbstractMatrix) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2))) +function _mult_fill(a, b, ax, ::Type{OnesZeros}) where {OnesZeros} + ElType = promote_type(eltype(a), eltype(b)) + return OnesZeros{ElType}(ax) end -function mult_zeros(a, b::AbstractMatrix) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2))) -end -function mult_zeros(a, b::AbstractVector) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1),)) +function mult_fill(a, b, T::Type = Fill) + Base.require_one_based_indexing(a, b) + size(a, 2) ≠ size(b, 1) && + throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) + ax_result = (axes(a, 1), axes(b)[2:end]...) + _mult_fill(a, b, ax_result, T) end +mult_zeros(a, b) = mult_fill(a, b, Zeros) +mult_ones(a, b) = mult_fill(a, b, Ones) *(a::AbstractFillVector, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) @@ -74,9 +64,9 @@ end *(a::ZerosVector, b::AbstractFillMatrix) = mult_zeros(a, b) *(a::ZerosMatrix, b::AbstractFillMatrix) = mult_zeros(a, b) *(a::ZerosMatrix, b::AbstractFillVector) = mult_zeros(a, b) -*(a::AbstractFillVector, b::ZerosMatrix) = mult_zeros(a,b) -*(a::AbstractFillMatrix, b::ZerosMatrix) = mult_zeros(a,b) -*(a::AbstractFillMatrix, b::ZerosVector) = mult_zeros(a,b) +*(a::AbstractFillVector, b::ZerosMatrix) = mult_zeros(a, b) +*(a::AbstractFillMatrix, b::ZerosMatrix) = mult_zeros(a, b) +*(a::AbstractFillMatrix, b::ZerosVector) = mult_zeros(a, b) *(a::ZerosVector, b::AbstractMatrix) = mult_zeros(a, b) *(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b) @@ -93,11 +83,11 @@ end *(a::Diagonal, b::ZerosMatrix) = mult_zeros(a, b) function *(a::Diagonal, b::AbstractFillMatrix) size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) - a.diag .* b # use special broadcast + parent(a) .* b # use special broadcast end function *(a::AbstractFillMatrix, b::Diagonal) size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) - a .* permutedims(b.diag) # use special broadcast + a .* permutedims(parent(b)) # use special broadcast end @noinline function check_matmul_sizes(y::AbstractVector, A::AbstractMatrix, x::AbstractVector) diff --git a/test/runtests.jl b/test/runtests.jl index bd136843..e81cc6a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1249,6 +1249,12 @@ end @test_throws DimensionMismatch Fill(2,10)*Fill(3,2,12) @test_throws DimensionMismatch Fill(2,3,10)*Fill(3,2,12) + f = Fill(1, (Base.IdentityUnitRange(1:3), Base.IdentityUnitRange(1:3))) + @test f * f === Fill(size(f,2), axes(f)) + + f = Fill(2, (Base.IdentityUnitRange(2:3), Base.IdentityUnitRange(2:3))) + @test_throws ArgumentError f * f + @test Ones(10)*Fill(3,1,12) ≡ Fill(3.0,10,12) @test Ones(10,3)*Fill(3,3,12) ≡ Fill(9.0,10,12) @test Ones(10,3)*Fill(3,3) ≡ Fill(9.0,10) @@ -1272,6 +1278,12 @@ end @test Zeros(10)*Zeros(1,12) ≡ Zeros(10,12) @test Zeros(3,10)*Zeros(10,12) ≡ Zeros(3,12) @test Zeros(3,10)*Zeros(10) ≡ Zeros(3) + + f = Zeros((Base.IdentityUnitRange(1:4), Base.IdentityUnitRange(1:4))) + @test f * f === f + + f = Zeros((Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(3:4))) + @test_throws ArgumentError f * f for W in (zeros(3,4), @MMatrix zeros(3,4)) mW, nW = size(W)