From c9d0ab4b3ab5ab5ada87eb6bb2963698d7b88d67 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 5 Nov 2023 10:31:07 -0500 Subject: [PATCH] Handle complex number dispatches in AppleAccelerate and MKL --- src/appleaccelerate.jl | 99 +++++++++++++++++++++++++++++++++++++++++- src/default.jl | 10 ++--- src/mkl.jl | 99 +++++++++++++++++++++++++++++++++++++++++- test/basictests.jl | 3 +- 4 files changed, 201 insertions(+), 10 deletions(-) diff --git a/src/appleaccelerate.jl b/src/appleaccelerate.jl index db684e278..7c3eb8e12 100644 --- a/src/appleaccelerate.jl +++ b/src/appleaccelerate.jl @@ -26,6 +26,46 @@ function appleaccelerate_isavailable() return true end +function aa_getrf!(A::AbstractMatrix{<:ComplexF64}; + ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))), + info = Ref{Cint}(), + check = false) + require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1, stride(A, 2)) + if isempty(ipiv) + ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))) + end + ccall(("zgetrf_", libacc), Cvoid, + (Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, + Ref{Cint}, Ptr{Cint}, Ptr{Cint}), + m, n, A, lda, ipiv, info) + info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_")) + A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type +end + +function aa_getrf!(A::AbstractMatrix{<:ComplexF32}; + ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))), + info = Ref{Cint}(), + check = false) + require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1, stride(A, 2)) + if isempty(ipiv) + ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))) + end + ccall(("cgetrf_", libacc), Cvoid, + (Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, + Ref{Cint}, Ptr{Cint}, Ptr{Cint}), + m, n, A, lda, ipiv, info) + info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_")) + A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type +end + function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))), info = Ref{Cint}(), @@ -67,6 +107,55 @@ function aa_getrf!(A::AbstractMatrix{<:Float32}; A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type end +function aa_getrs!(trans::AbstractChar, + A::AbstractMatrix{<:ComplexF64}, + ipiv::AbstractVector{Cint}, + B::AbstractVecOrMat{<:ComplexF64}; + info = Ref{Cint}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("zgetrs_", libacc), Cvoid, + (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, Ref{Cint}, + Ptr{Cint}, Ptr{ComplexF64}, Ref{Cint}, Ptr{Cint}, Clong), + trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, + 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) +end + +function aa_getrs!(trans::AbstractChar, + A::AbstractMatrix{<:ComplexF32}, + ipiv::AbstractVector{Cint}, + B::AbstractVecOrMat{<:ComplexF32}; + info = Ref{Cint}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("cgetrs_", libacc), Cvoid, + (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, Ref{Cint}, + Ptr{Cint}, Ptr{ComplexF32}, Ref{Cint}, Ptr{Cint}, Clong), + trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, + 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) + B +end + function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, @@ -128,12 +217,20 @@ else nothing end -function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr, +function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A::AbstractMatrix{<:Float64}, b::AbstractArray{<:Float64}, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) PREALLOCATED_APPLE_LU end +function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + A = rand(eltype(A), 0, 0) + luinst = ArrayInterface.lu_instance(A) + LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}() +end + function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization; kwargs...) A = cache.A diff --git a/src/default.jl b/src/default.jl index 446824dbc..d3ee52b83 100644 --- a/src/default.jl +++ b/src/default.jl @@ -162,9 +162,7 @@ function defaultalg(A, b, assump::OperatorAssumptions) __conditioning(assump) === OperatorCondition.WellConditioned) if length(b) <= 10 DefaultAlgorithmChoice.GenericLUFactorization - elseif VERSION >= v"1.8" && appleaccelerate_isavailable() && - (A === nothing ? eltype(b) <: Union{Float32, Float64} : - eltype(A) <: Union{Float32, Float64}) + elseif VERSION >= v"1.8" && appleaccelerate_isavailable() DefaultAlgorithmChoice.AppleAccelerateLUFactorization elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) || (usemkl && length(b) <= 200)) && @@ -173,8 +171,7 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.RFLUFactorization #elseif A === nothing || A isa Matrix # alg = FastLUFactorization() - elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} : - eltype(A) <: Union{Float32, Float64}) + elseif usemkl DefaultAlgorithmChoice.MKLLUFactorization else DefaultAlgorithmChoice.LUFactorization @@ -183,8 +180,7 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.QRFactorization elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned DefaultAlgorithmChoice.SVDFactorization - elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} : - eltype(A) <: Union{Float32, Float64}) + elseif usemkl DefaultAlgorithmChoice.MKLLUFactorization else DefaultAlgorithmChoice.LUFactorization diff --git a/src/mkl.jl b/src/mkl.jl index fa0c3f5b1..3aeb6b663 100644 --- a/src/mkl.jl +++ b/src/mkl.jl @@ -8,6 +8,46 @@ to avoid allocations and does not require libblastrampoline. """ struct MKLLUFactorization <: AbstractFactorization end +function getrf!(A::AbstractMatrix{<:ComplexF64}; + ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), + info = Ref{BlasInt}(), + check = false) + require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1, stride(A, 2)) + if isempty(ipiv) + ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) + end + ccall((@blasfunc(zgetrf_), MKL_jll.libmkl_rt), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, + Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), + m, n, A, lda, ipiv, info) + chkargsok(info[]) + A, ipiv, info[], info #Error code is stored in LU factorization type +end + +function getrf!(A::AbstractMatrix{<:ComplexF32}; + ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), + info = Ref{BlasInt}(), + check = false) + require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1, stride(A, 2)) + if isempty(ipiv) + ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) + end + ccall((@blasfunc(cgetrf_), MKL_jll.libmkl_rt), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, + Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), + m, n, A, lda, ipiv, info) + chkargsok(info[]) + A, ipiv, info[], info #Error code is stored in LU factorization type +end + function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), info = Ref{BlasInt}(), @@ -48,6 +88,56 @@ function getrf!(A::AbstractMatrix{<:Float32}; A, ipiv, info[], info #Error code is stored in LU factorization type end +function getrs!(trans::AbstractChar, + A::AbstractMatrix{<:ComplexF64}, + ipiv::AbstractVector{BlasInt}, + B::AbstractVecOrMat{<:ComplexF64}; + info = Ref{BlasInt}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("zgetrs_", MKL_jll.libmkl_rt), Cvoid, + (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong), + trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, + 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) + B +end + +function getrs!(trans::AbstractChar, + A::AbstractMatrix{<:ComplexF32}, + ipiv::AbstractVector{BlasInt}, + B::AbstractVecOrMat{<:ComplexF32}; + info = Ref{BlasInt}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("cgetrs_", MKL_jll.libmkl_rt), Cvoid, + (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong), + trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, + 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) + B +end + function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{BlasInt}, @@ -106,12 +196,19 @@ const PREALLOCATED_MKL_LU = begin luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}() end -function init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr, +function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A::AbstractMatrix{<:Float64}, b::AbstractArray{<:Float64}, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) PREALLOCATED_MKL_LU end +function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + A = rand(eltype(A), 0, 0) + ArrayInterface.lu_instance(A), Ref{BlasInt}() +end + function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; kwargs...) A = cache.A diff --git a/test/basictests.jl b/test/basictests.jl index d31ecc074..93b22fbdd 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -235,11 +235,12 @@ end for alg in test_algs @testset "$alg" begin test_interface(alg, prob1, prob2) - VERSION >= v"1.9" && (alg isa MKLLUFactorization || test_interface(alg, prob3, prob4)) + VERSION >= v"1.9" && test_interface(alg, prob3, prob4) end end if LinearSolve.appleaccelerate_isavailable() test_interface(AppleAccelerateLUFactorization(), prob1, prob2) + test_interface(AppleAccelerateLUFactorization(), prob3, prob4) end end