Skip to content

Commit

Permalink
Reduce mult_fill methods (#243)
Browse files Browse the repository at this point in the history
* reduce mult_fill methods

* call mult_axes with axes

* disallow offset arrays in matrix multiplication

* remove mult_axes

* bring back mult_zeros and mult_ones
  • Loading branch information
jishnub authored May 24, 2023
1 parent 852da1e commit 427e8d3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
48 changes: 19 additions & 29 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,25 @@ reverse(A::AbstractFill; dims=:) = A

## Algebraic identities


function mult_fill(a::AbstractFill, b::AbstractFillMatrix)
axes(a, 2) axes(b, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1), axes(b, 2)))
end

function mult_fill(a::AbstractFill, b::AbstractFillVector)
axes(a, 2) axes(b, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1),))
function _mult_fill(a::AbstractFill, b::AbstractFill, ax, ::Type{Fill})
val = getindex_value(a)*getindex_value(b)*size(a,2)
return Fill(val, ax)
end

function mult_ones(a::AbstractVector, b::AbstractMatrix)
axes(a, 2) axes(b, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2)))
function _mult_fill(a, b, ax, ::Type{OnesZeros}) where {OnesZeros}
ElType = promote_type(eltype(a), eltype(b))
return OnesZeros{ElType}(ax)
end

function mult_zeros(a, b::AbstractMatrix)
axes(a, 2) axes(b, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2)))
end
function mult_zeros(a, b::AbstractVector)
axes(a, 2) axes(b, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1),))
function mult_fill(a, b, T::Type = Fill)
Base.require_one_based_indexing(a, b)
size(a, 2) size(b, 1) &&
throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
ax_result = (axes(a, 1), axes(b)[2:end]...)
_mult_fill(a, b, ax_result, T)
end
mult_zeros(a, b) = mult_fill(a, b, Zeros)
mult_ones(a, b) = mult_fill(a, b, Ones)

*(a::AbstractFillVector, b::AbstractFillMatrix) = mult_fill(a,b)
*(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b)
Expand All @@ -74,9 +64,9 @@ end
*(a::ZerosVector, b::AbstractFillMatrix) = mult_zeros(a, b)
*(a::ZerosMatrix, b::AbstractFillMatrix) = mult_zeros(a, b)
*(a::ZerosMatrix, b::AbstractFillVector) = mult_zeros(a, b)
*(a::AbstractFillVector, b::ZerosMatrix) = mult_zeros(a,b)
*(a::AbstractFillMatrix, b::ZerosMatrix) = mult_zeros(a,b)
*(a::AbstractFillMatrix, b::ZerosVector) = mult_zeros(a,b)
*(a::AbstractFillVector, b::ZerosMatrix) = mult_zeros(a, b)
*(a::AbstractFillMatrix, b::ZerosMatrix) = mult_zeros(a, b)
*(a::AbstractFillMatrix, b::ZerosVector) = mult_zeros(a, b)

*(a::ZerosVector, b::AbstractMatrix) = mult_zeros(a, b)
*(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
Expand All @@ -93,11 +83,11 @@ end
*(a::Diagonal, b::ZerosMatrix) = mult_zeros(a, b)
function *(a::Diagonal, b::AbstractFillMatrix)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a.diag .* b # use special broadcast
parent(a) .* b # use special broadcast
end
function *(a::AbstractFillMatrix, b::Diagonal)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a .* permutedims(b.diag) # use special broadcast
a .* permutedims(parent(b)) # use special broadcast
end

@noinline function check_matmul_sizes(y::AbstractVector, A::AbstractMatrix, x::AbstractVector)
Expand Down
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,12 @@ end
@test_throws DimensionMismatch Fill(2,10)*Fill(3,2,12)
@test_throws DimensionMismatch Fill(2,3,10)*Fill(3,2,12)

f = Fill(1, (Base.IdentityUnitRange(1:3), Base.IdentityUnitRange(1:3)))
@test f * f === Fill(size(f,2), axes(f))

f = Fill(2, (Base.IdentityUnitRange(2:3), Base.IdentityUnitRange(2:3)))
@test_throws ArgumentError f * f

@test Ones(10)*Fill(3,1,12) Fill(3.0,10,12)
@test Ones(10,3)*Fill(3,3,12) Fill(9.0,10,12)
@test Ones(10,3)*Fill(3,3) Fill(9.0,10)
Expand All @@ -1272,6 +1278,12 @@ end
@test Zeros(10)*Zeros(1,12) Zeros(10,12)
@test Zeros(3,10)*Zeros(10,12) Zeros(3,12)
@test Zeros(3,10)*Zeros(10) Zeros(3)

f = Zeros((Base.IdentityUnitRange(1:4), Base.IdentityUnitRange(1:4)))
@test f * f === f

f = Zeros((Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(3:4)))
@test_throws ArgumentError f * f

for W in (zeros(3,4), @MMatrix zeros(3,4))
mW, nW = size(W)
Expand Down

0 comments on commit 427e8d3

Please sign in to comment.