Skip to content

Commit

Permalink
Merge pull request #16602 from JuliaLang/anj/qr2
Browse files Browse the repository at this point in the history
Ensure that multiplication with QRQs doesn't fall back to generic_matmatmul!
  • Loading branch information
tkelman committed May 29, 2016
2 parents 7becf07 + f922d5c commit 15626c3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
54 changes: 40 additions & 14 deletions base/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ end
### QB
A_mul_B!{T<:BlasFloat}(A::QRCompactWYQ{T}, B::StridedVecOrMat{T}) = LAPACK.gemqrt!('L','N',A.factors,A.T,B)
A_mul_B!{T<:BlasFloat}(A::QRPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormqr!('L','N',A.factors,A.τ,B)
function A_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T})
function A_mul_B!(A::QRPackedQ, B::AbstractVecOrMat)
mA, nA = size(A.factors)
mB, nB = size(B,1), size(B,2)
if mA != mB
Expand All @@ -290,8 +290,8 @@ function A_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T})
B
end

function (*){TA,Tb}(A::Union{QRPackedQ{TA},QRCompactWYQ{TA}}, b::StridedVector{Tb})
TAb = promote_type(TA, Tb)
function (*)(A::Union{QRPackedQ,QRCompactWYQ}, b::StridedVector)
TAb = promote_type(eltype(A), eltype(b))
Anew = convert(AbstractMatrix{TAb}, A)
if size(A.factors, 1) == length(b)
bnew = copy_oftype(b, TAb)
Expand All @@ -302,8 +302,8 @@ function (*){TA,Tb}(A::Union{QRPackedQ{TA},QRCompactWYQ{TA}}, b::StridedVector{T
end
A_mul_B!(Anew, bnew)
end
function (*){TA,TB}(A::Union{QRPackedQ{TA},QRCompactWYQ{TA}}, B::StridedMatrix{TB})
TAB = promote_type(TA, TB)
function (*)(A::Union{QRPackedQ,QRCompactWYQ}, B::StridedMatrix)
TAB = promote_type(eltype(A), eltype(B))
Anew = convert(AbstractMatrix{TAB}, A)
if size(A.factors, 1) == size(B, 1)
Bnew = copy_oftype(B, TAB)
Expand All @@ -320,7 +320,7 @@ Ac_mul_B!{T<:BlasReal}(A::QRCompactWYQ{T}, B::StridedVecOrMat{T}) = LAPACK.gemqr
Ac_mul_B!{T<:BlasComplex}(A::QRCompactWYQ{T}, B::StridedVecOrMat{T}) = LAPACK.gemqrt!('L','C',A.factors,A.T,B)
Ac_mul_B!{T<:BlasReal}(A::QRPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormqr!('L','T',A.factors,A.τ,B)
Ac_mul_B!{T<:BlasComplex}(A::QRPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormqr!('L','C',A.factors,A.τ,B)
function Ac_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T})
function Ac_mul_B!(A::QRPackedQ, B::AbstractVecOrMat)
mA, nA = size(A.factors)
mB, nB = size(B,1), size(B,2)
if mA != mB
Expand All @@ -344,15 +344,28 @@ function Ac_mul_B!{T}(A::QRPackedQ{T}, B::AbstractVecOrMat{T})
end
B
end
function Ac_mul_B{TQ<:Number,TB<:Number,N}(Q::Union{QRPackedQ{TQ},QRCompactWYQ{TQ}}, B::StridedArray{TB,N})
TQB = promote_type(TQ,TB)
function Ac_mul_B(Q::Union{QRPackedQ,QRCompactWYQ}, B::StridedVecOrMat)
TQB = promote_type(eltype(Q), eltype(B))
return Ac_mul_B!(convert(AbstractMatrix{TQB}, Q), copy_oftype(B, TQB))
end

### QBc/QcBc
for (f1, f2) in ((:A_mul_Bc, :A_mul_B!),
(:Ac_mul_Bc, :Ac_mul_B!))
@eval begin
function ($f1)(Q::Union{QRPackedQ,QRCompactWYQ}, B::StridedVecOrMat)
TQB = promote_type(eltype(Q), eltype(B))
Bc = similar(B, TQB, (size(B, 2), size(B, 1)))
ctranspose!(Bc, B)
return ($f2)(convert(AbstractMatrix{TQB}, Q), Bc)
end
end
end

### AQ
A_mul_B!{T<:BlasFloat}(A::StridedVecOrMat{T}, B::QRCompactWYQ{T}) = LAPACK.gemqrt!('R','N', B.factors, B.T, A)
A_mul_B!{T<:BlasFloat}(A::StridedVecOrMat{T}, B::QRPackedQ{T}) = LAPACK.ormqr!('R', 'N', B.factors, B.τ, A)
function A_mul_B!{T}(A::StridedMatrix{T},Q::QRPackedQ{T})
function A_mul_B!(A::StridedMatrix,Q::QRPackedQ)
mQ, nQ = size(Q.factors)
mA, nA = size(A,1), size(A,2)
if nA != mQ
Expand All @@ -377,8 +390,8 @@ function A_mul_B!{T}(A::StridedMatrix{T},Q::QRPackedQ{T})
A
end

function (*){TA,TQ,N}(A::StridedArray{TA,N}, Q::Union{QRPackedQ{TQ},QRCompactWYQ{TQ}})
TAQ = promote_type(TA, TQ)
function (*)(A::StridedMatrix, Q::Union{QRPackedQ,QRCompactWYQ})
TAQ = promote_type(eltype(A), eltype(Q))
return A_mul_B!(copy_oftype(A, TAQ), convert(AbstractMatrix{TAQ}, Q))
end

Expand All @@ -387,7 +400,7 @@ A_mul_Bc!{T<:BlasReal}(A::StridedVecOrMat{T}, B::QRCompactWYQ{T}) = LAPACK.gemqr
A_mul_Bc!{T<:BlasComplex}(A::StridedVecOrMat{T}, B::QRCompactWYQ{T}) = LAPACK.gemqrt!('R','C',B.factors,B.T,A)
A_mul_Bc!{T<:BlasReal}(A::StridedVecOrMat{T}, B::QRPackedQ{T}) = LAPACK.ormqr!('R','T',B.factors,B.τ,A)
A_mul_Bc!{T<:BlasComplex}(A::StridedVecOrMat{T}, B::QRPackedQ{T}) = LAPACK.ormqr!('R','C',B.factors,B.τ,A)
function A_mul_Bc!{T}(A::AbstractMatrix{T},Q::QRPackedQ{T})
function A_mul_Bc!(A::AbstractMatrix,Q::QRPackedQ)
mQ, nQ = size(Q.factors)
mA, nA = size(A,1), size(A,2)
if nA != mQ
Expand All @@ -411,8 +424,8 @@ function A_mul_Bc!{T}(A::AbstractMatrix{T},Q::QRPackedQ{T})
end
A
end
function A_mul_Bc{TA,TB}(A::AbstractMatrix{TA}, B::Union{QRCompactWYQ{TB},QRPackedQ{TB}})
TAB = promote_type(TA,TB)
function A_mul_Bc(A::AbstractMatrix, B::Union{QRCompactWYQ,QRPackedQ})
TAB = promote_type(eltype(A),eltype(B))
BB = convert(AbstractMatrix{TAB}, B)
if size(A,2) == size(B.factors, 1)
AA = similar(A, TAB, size(A))
Expand All @@ -425,6 +438,19 @@ function A_mul_Bc{TA,TB}(A::AbstractMatrix{TA}, B::Union{QRCompactWYQ{TB},QRPack
end
end

### AcQ/AcQc
for (f1, f2) in ((:Ac_mul_B, :A_mul_B!),
(:Ac_mul_Bc, :A_mul_Bc!))
@eval begin
function ($f1)(A::StridedVecOrMat, Q::Union{QRPackedQ,QRCompactWYQ})
TAQ = promote_type(eltype(A), eltype(Q))
Ac = similar(A, TAQ, (size(A, 2), size(A, 1)))
ctranspose!(Ac, A)
return ($f2)(Ac, convert(AbstractMatrix{TAQ}, Q))
end
end
end

A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, b::StridedVector{T}) = (A_ldiv_B!(UpperTriangular(A[:R]), sub(Ac_mul_B!(A[:Q], b), 1:size(A, 2))); b)
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedMatrix{T}) = (A_ldiv_B!(UpperTriangular(A[:R]), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2), 1:size(B, 2))); B)

Expand Down
3 changes: 3 additions & 0 deletions test/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ debug && println("QR decomposition (without pivoting)")
@test_throws KeyError qra[:Z]
@test_approx_eq q'*full(q, thin = false) eye(n)
@test_approx_eq q*full(q, thin = false)' eye(n)
@test_approx_eq q'*eye(n)' full(q, thin = false)'
@test_approx_eq full(q, thin = false)'q eye(n)
@test_approx_eq eye(n)'q' full(q, thin = false)'
@test_approx_eq q*r a
@test_approx_eq_eps a*(qra\b) b 3000ε
@test_approx_eq full(qra) a
Expand Down

0 comments on commit 15626c3

Please sign in to comment.