diff --git a/src/linalg.jl b/src/linalg.jl index 9cf91d29..af7fe50d 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -650,6 +650,62 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix return conj(dot(B, A)) end +function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractVector) + d = D.diag + if length(x) != length(y) || length(y) != length(d) + throw( + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), D has side dimension $(length(d))") + ) + end + nzvals = nonzeros(x) + nzinds = nonzeroinds(x) + s = zero(typeof(dot(first(x), first(D), first(y)))) + @inbounds for nzidx in eachindex(nzvals) + s += dot(nzvals[nzidx], d[nzinds[nzidx]], y[nzinds[nzidx]]) + end + return s +end + +dot(x::AbstractVector, D::Diagonal, y::AbstractSparseVector) = adjoint(dot(y, D', x)) + +function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector) + d = D.diag + if length(y) != length(x) || length(y) != length(d) + throw( + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))") + ) + end + xnzind = nonzeroinds(x) + ynzind = nonzeroinds(y) + xnzval = nonzeros(x) + ynzval = nonzeros(y) + s = zero(typeof(dot(first(x), first(D), first(y)))) + if isempty(xnzind) || isempty(ynzind) + return s + end + + x_idx = 1 + y_idx = 1 + x_idx_last = length(xnzind) + y_idx_last = length(ynzind) + + # go through the nonzero indices of a and b simultaneously + @inbounds while x_idx <= x_idx_last && y_idx <= y_idx_last + ix = xnzind[x_idx] + iy = ynzind[y_idx] + if ix == iy + s += dot(xnzval[x_idx], d[ix], ynzval[y_idx]) + x_idx += 1 + y_idx += 1 + elseif ix < iy + x_idx += 1 + else + y_idx += 1 + end + end + return s +end + ## triangular sparse handling ## triangular multiplication function LinearAlgebra.generic_trimatmul!(C::StridedVecOrMat, uploc, isunitc, tfun::Function, A::SparseMatrixCSCUnion, B::AbstractVecOrMat) diff --git a/test/linalg.jl b/test/linalg.jl index d3f004ca..afa6eb47 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -850,6 +850,20 @@ end @test dot(TA,WB) ≈ dot(Matrix(TA), WB) @test dot(TA,WC) ≈ dot(Matrix(TA), WC) end + for M in (A, B, C) + D = Diagonal(M * M') + a = spzeros(Complex{Float64}, size(D, 1)) + a[1:3] = rand(Complex{Float64}, 3) + b = spzeros(Complex{Float64}, size(D, 1)) + b[1:3] = rand(Complex{Float64}, 3) + @test dot(a, D, b) ≈ dot(a, sparse(D), b) + @test dot(b, D, a) ≈ dot(b, sparse(D), a) + @test dot(b, D, a) ≈ dot(b, D, collect(a)) + @test dot(b, D, a) ≈ dot(collect(b), D, a) + @test_throws DimensionMismatch dot(b, D, [a; 1]) + @test_throws DimensionMismatch dot([b; 1], D, a) + @test_throws DimensionMismatch dot([b; 1], D, [a; 1]) + end end @test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2)) @test_throws DimensionMismatch dot(rand(5,5),sprand(5,6,0.2))