From f922d5ce5592e9bd22976d28f8cf139bd638272d Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Thu, 26 May 2016 16:32:08 -0400 Subject: [PATCH] Ensure that multiplication with QRQs doesn't fall back to generic_matmatmul! --- base/linalg/qr.jl | 54 +++++++++++++++++++++++++++++++++++------------ test/linalg/qr.jl | 3 +++ 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/base/linalg/qr.jl b/base/linalg/qr.jl index 3ccad5b12b66a..3b90c22b29253 100644 --- a/base/linalg/qr.jl +++ b/base/linalg/qr.jl @@ -265,7 +265,7 @@ end ### QB A_mul_B!{T<:BlasFloat}(A::QRCompactWYQ{T}, B::StridedVecOrMat{T}) = LAPACK.gemqrt!('L','N',A.factors,A.T,B) A_mul_B!{T<:BlasFloat}(A::QRPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormqr!('L','N',A.factors,A.τ,B) -function A_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T}) +function A_mul_B!(A::QRPackedQ, B::AbstractVecOrMat) mA, nA = size(A.factors) mB, nB = size(B,1), size(B,2) if mA != mB @@ -290,8 +290,8 @@ function A_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T}) B end -function (*){TA,Tb}(A::Union{QRPackedQ{TA},QRCompactWYQ{TA}}, b::StridedVector{Tb}) - TAb = promote_type(TA, Tb) +function (*)(A::Union{QRPackedQ,QRCompactWYQ}, b::StridedVector) + TAb = promote_type(eltype(A), eltype(b)) Anew = convert(AbstractMatrix{TAb}, A) if size(A.factors, 1) == length(b) bnew = copy_oftype(b, TAb) @@ -302,8 +302,8 @@ function (*){TA,Tb}(A::Union{QRPackedQ{TA},QRCompactWYQ{TA}}, b::StridedVector{T end A_mul_B!(Anew, bnew) end -function (*){TA,TB}(A::Union{QRPackedQ{TA},QRCompactWYQ{TA}}, B::StridedMatrix{TB}) - TAB = promote_type(TA, TB) +function (*)(A::Union{QRPackedQ,QRCompactWYQ}, B::StridedMatrix) + TAB = promote_type(eltype(A), eltype(B)) Anew = convert(AbstractMatrix{TAB}, A) if size(A.factors, 1) == size(B, 1) Bnew = copy_oftype(B, TAB) @@ -320,7 +320,7 @@ Ac_mul_B!{T<:BlasReal}(A::QRCompactWYQ{T}, B::StridedVecOrMat{T}) = LAPACK.gemqr Ac_mul_B!{T<:BlasComplex}(A::QRCompactWYQ{T}, B::StridedVecOrMat{T}) = LAPACK.gemqrt!('L','C',A.factors,A.T,B) Ac_mul_B!{T<:BlasReal}(A::QRPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormqr!('L','T',A.factors,A.τ,B) Ac_mul_B!{T<:BlasComplex}(A::QRPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormqr!('L','C',A.factors,A.τ,B) -function Ac_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T}) +function Ac_mul_B!(A::QRPackedQ, B::AbstractVecOrMat) mA, nA = size(A.factors) mB, nB = size(B,1), size(B,2) if mA != mB @@ -344,15 +344,28 @@ function Ac_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T}) end B end -function Ac_mul_B{TQ<:Number,TB<:Number,N}(Q::Union{QRPackedQ{TQ},QRCompactWYQ{TQ}}, B::StridedArray{TB,N}) - TQB = promote_type(TQ,TB) +function Ac_mul_B(Q::Union{QRPackedQ,QRCompactWYQ}, B::StridedVecOrMat) + TQB = promote_type(eltype(Q), eltype(B)) return Ac_mul_B!(convert(AbstractMatrix{TQB}, Q), copy_oftype(B, TQB)) end +### QBc/QcBc +for (f1, f2) in ((:A_mul_Bc, :A_mul_B!), + (:Ac_mul_Bc, :Ac_mul_B!)) + @eval begin + function ($f1)(Q::Union{QRPackedQ,QRCompactWYQ}, B::StridedVecOrMat) + TQB = promote_type(eltype(Q), eltype(B)) + Bc = similar(B, TQB, (size(B, 2), size(B, 1))) + ctranspose!(Bc, B) + return ($f2)(convert(AbstractMatrix{TQB}, Q), Bc) + end + end +end + ### AQ A_mul_B!{T<:BlasFloat}(A::StridedVecOrMat{T}, B::QRCompactWYQ{T}) = LAPACK.gemqrt!('R','N', B.factors, B.T, A) A_mul_B!{T<:BlasFloat}(A::StridedVecOrMat{T}, B::QRPackedQ{T}) = LAPACK.ormqr!('R', 'N', B.factors, B.τ, A) -function A_mul_B!{T}(A::StridedMatrix{T},Q::QRPackedQ{T}) +function A_mul_B!(A::StridedMatrix,Q::QRPackedQ) mQ, nQ = size(Q.factors) mA, nA = size(A,1), size(A,2) if nA != mQ @@ -377,8 +390,8 @@ function A_mul_B!{T}(A::StridedMatrix{T},Q::QRPackedQ{T}) A end -function (*){TA,TQ,N}(A::StridedArray{TA,N}, Q::Union{QRPackedQ{TQ},QRCompactWYQ{TQ}}) - TAQ = promote_type(TA, TQ) +function (*)(A::StridedMatrix, Q::Union{QRPackedQ,QRCompactWYQ}) + TAQ = promote_type(eltype(A), eltype(Q)) return A_mul_B!(copy_oftype(A, TAQ), convert(AbstractMatrix{TAQ}, Q)) end @@ -387,7 +400,7 @@ A_mul_Bc!{T<:BlasReal}(A::StridedVecOrMat{T}, B::QRCompactWYQ{T}) = LAPACK.gemqr A_mul_Bc!{T<:BlasComplex}(A::StridedVecOrMat{T}, B::QRCompactWYQ{T}) = LAPACK.gemqrt!('R','C',B.factors,B.T,A) A_mul_Bc!{T<:BlasReal}(A::StridedVecOrMat{T}, B::QRPackedQ{T}) = LAPACK.ormqr!('R','T',B.factors,B.τ,A) A_mul_Bc!{T<:BlasComplex}(A::StridedVecOrMat{T}, B::QRPackedQ{T}) = LAPACK.ormqr!('R','C',B.factors,B.τ,A) -function A_mul_Bc!{T}(A::AbstractMatrix{T},Q::QRPackedQ{T}) +function A_mul_Bc!(A::AbstractMatrix,Q::QRPackedQ) mQ, nQ = size(Q.factors) mA, nA = size(A,1), size(A,2) if nA != mQ @@ -411,8 +424,8 @@ function A_mul_Bc!{T}(A::AbstractMatrix{T},Q::QRPackedQ{T}) end A end -function A_mul_Bc{TA,TB}(A::AbstractMatrix{TA}, B::Union{QRCompactWYQ{TB},QRPackedQ{TB}}) - TAB = promote_type(TA,TB) +function A_mul_Bc(A::AbstractMatrix, B::Union{QRCompactWYQ,QRPackedQ}) + TAB = promote_type(eltype(A),eltype(B)) BB = convert(AbstractMatrix{TAB}, B) if size(A,2) == size(B.factors, 1) AA = similar(A, TAB, size(A)) @@ -425,6 +438,19 @@ function A_mul_Bc{TA,TB}(A::AbstractMatrix{TA}, B::Union{QRCompactWYQ{TB},QRPack end end +### AcQ/AcQc +for (f1, f2) in ((:Ac_mul_B, :A_mul_B!), + (:Ac_mul_Bc, :A_mul_Bc!)) + @eval begin + function ($f1)(A::StridedVecOrMat, Q::Union{QRPackedQ,QRCompactWYQ}) + TAQ = promote_type(eltype(A), eltype(Q)) + Ac = similar(A, TAQ, (size(A, 2), size(A, 1))) + ctranspose!(Ac, A) + return ($f2)(Ac, convert(AbstractMatrix{TAQ}, Q)) + end + end +end + A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, b::StridedVector{T}) = (A_ldiv_B!(UpperTriangular(A[:R]), sub(Ac_mul_B!(A[:Q], b), 1:size(A, 2))); b) A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedMatrix{T}) = (A_ldiv_B!(UpperTriangular(A[:R]), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2), 1:size(B, 2))); B) diff --git a/test/linalg/qr.jl b/test/linalg/qr.jl index b51615af979ec..48c29caea0377 100644 --- a/test/linalg/qr.jl +++ b/test/linalg/qr.jl @@ -48,6 +48,9 @@ debug && println("QR decomposition (without pivoting)") @test_throws KeyError qra[:Z] @test_approx_eq q'*full(q, thin = false) eye(n) @test_approx_eq q*full(q, thin = false)' eye(n) + @test_approx_eq q'*eye(n)' full(q, thin = false)' + @test_approx_eq full(q, thin = false)'q eye(n) + @test_approx_eq eye(n)'q' full(q, thin = false)' @test_approx_eq q*r a @test_approx_eq_eps a*(qra\b) b 3000ε @test_approx_eq full(qra) a