Skip to content

Commit

Permalink
Bugfix in matrix multiplication involving adj/trans (#360)
Browse files Browse the repository at this point in the history
* Bugfix in matrix multiplication

* _mulfill implementation for all matrix orderings

* Loop in copyfirstrow!

* Revert some unnecessary changes
  • Loading branch information
jishnub authored Apr 22, 2024
1 parent 7177c59 commit 8509b6b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 32 deletions.
92 changes: 60 additions & 32 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,21 @@ for (T, f) in ((:Adjoint, :adjoint), (:Transpose, :transpose))
end
end

function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
# unnecessary indirection, added for ambiguity resolution
function _mulfill!(C::AbstractMatrix, A::AbstractFillMatrix, B::AbstractFillMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
ABα = getindex_value(A) * getindex_value(B) * alpha * size(B,1)
if iszero(beta)
C .= ABα
else
C .= ABα .+ C .* beta
end
C
return C
end

function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
_mulfill!(C, A, B, alpha, beta)
return C
end

function copyfirstcol!(C)
Expand All @@ -229,50 +235,72 @@ function copyfirstcol!(C)
end
return C
end
function copyfirstcol!(C::Union{Adjoint, Transpose})
# in this case, we copy the first row of the parent to others
Cp = parent(C)
for colind in axes(Cp, 2)
Cp[2:end, colind] .= Cp[1, colind]

_firstcol(C::AbstractMatrix) = first(eachcol(C))

function copyfirstrow!(C)
# C[begin+1:end, ind] .= permutedims(_firstrow(C))
# we loop here as the aliasing check isn't smart enough to
# detect that the two sides don't alias, and ends up materializing the RHS
for (ind, v) in pairs(_firstrow(C))
C[begin+1:end, ind] .= Ref(v)
end
return C
end
_firstrow(C::AbstractMatrix) = first(eachrow(C))

_firstcol(C::AbstractMatrix) = view(C, :, 1)
_firstcol(C::Union{Adjoint, Transpose}) = view(parent(C), 1, :)

function _mulfill!(C, A, B::AbstractFillMatrix, alpha, beta)
function _mulfill!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractFillMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
iszero(size(B,2)) && return C # no columns in B and C, empty matrix
if iszero(beta)
# the mat-vec product sums along the rows of A
mul!(_firstcol(C), A, _firstcol(B), alpha, beta)
copyfirstcol!(C)
else
# the mat-vec product sums along the rows of A, which produces the first column of ABα
# allocate a temporary column vector to store the result
v = A * (_firstcol(B) * alpha)
C .= v .+ C .* beta
end
return C
end
function _mulfill!(C::AbstractMatrix, A::AbstractFillMatrix, B::AbstractMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(size(B,2))
return rmul!(C, beta)
iszero(size(A,1)) && return C # no rows in A and C, empty matrix
Aval = getindex_value(A)
if iszero(beta)
Crow = _firstrow(C)
# sum along the columns of B
Crow .= Ref(Aval) .* sum.(eachcol(B)) .* alpha
copyfirstrow!(C)
else
# sum along the columns of B, and allocate the result.
# This is the first row of ABα
ABα_row = Ref(Aval) .* sum.(eachcol(B)) .* alpha
C .= permutedims(ABα_row) .+ C .* beta
end
mul!(_firstcol(C), A, view(B, :, 1), alpha, beta)
copyfirstcol!(C)
C
return C
end

function mul!(C::StridedMatrix, A::StridedMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
_mulfill!(C, A, B, alpha, beta)
return C
end

for T in (:Adjoint, :Transpose)
@eval function mul!(C::StridedMatrix, A::$T{<:Any, <:StridedMatrix}, B::AbstractFillMatrix, alpha::Number, beta::Number)
_mulfill!(C, A, B, alpha, beta)
end
end

function mul!(C::StridedMatrix, A::AbstractFillMatrix, B::StridedMatrix, alpha::Number, beta::Number)
check_matmul_sizes(C, A, B)
for (colC, colB) in zip(eachcol(C), eachcol(B))
mul!(colC, A, colB, alpha, beta)
end
C
_mulfill!(C, A, B, alpha, beta)
return C
end

for (T, f) in ((:Adjoint, :adjoint), (:Transpose, :transpose))
@eval function mul!(C::StridedMatrix, A::AbstractFillMatrix, B::$T{<:Any, <:StridedMatrix}, alpha::Number, beta::Number)
_mulfill!($f(C), $f(B), $f(A), alpha, beta)
C
for T in (:Adjoint, :Transpose)
@eval begin
function mul!(C::StridedMatrix, A::$T{<:Any, <:StridedMatrix}, B::AbstractFillMatrix, alpha::Number, beta::Number)
_mulfill!(C, A, B, alpha, beta)
return C
end
function mul!(C::StridedMatrix, A::AbstractFillMatrix, B::$T{<:Any, <:StridedMatrix}, alpha::Number, beta::Number)
_mulfill!(C, A, B, alpha, beta)
return C
end
end
end

Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,16 @@ end
@test transpose(A)*fillmat transpose(A)*Array(fillmat)
@test adjoint(A)*fillvec adjoint(A)*Array(fillvec)
@test adjoint(A)*fillmat adjoint(A)*Array(fillmat)

# inplace C = F * B' * alpha + C * beta
F = Fill(fv, m, k)
A = Array(F)
B = rand(T, n, k)
C = rand(T, m, n)
@testset for f in (adjoint, transpose)
@test mul!(copy(C), F, f(B)) mul!(copy(C), A, f(B))
@test mul!(copy(C), F, f(B), 1.0, 2.0) mul!(copy(C), A, f(B), 1.0, 2.0)
end
end

@testset "non-commutative" begin
Expand Down

0 comments on commit 8509b6b

Please sign in to comment.