Skip to content

Commit

Permalink
Merge pull request #415 from SciML/complex
Browse files Browse the repository at this point in the history
Handle complex number dispatches in AppleAccelerate and MKL
  • Loading branch information
ChrisRackauckas authored Nov 5, 2023
2 parents d863895 + 06e6f81 commit a9b5581
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 9 deletions.
97 changes: 97 additions & 0 deletions src/appleaccelerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,46 @@ function appleaccelerate_isavailable()
return true
end

function aa_getrf!(A::AbstractMatrix{<:ComplexF64};
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
end
ccall(("zgetrf_", libacc), Cvoid,
(Ref{Cint}, Ref{Cint}, Ptr{ComplexF64},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function aa_getrf!(A::AbstractMatrix{<:ComplexF32};
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
end
ccall(("cgetrf_", libacc), Cvoid,
(Ref{Cint}, Ref{Cint}, Ptr{ComplexF32},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function aa_getrf!(A::AbstractMatrix{<:Float64};
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
Expand Down Expand Up @@ -67,6 +107,55 @@ function aa_getrf!(A::AbstractMatrix{<:Float32};
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function aa_getrs!(trans::AbstractChar,
A::AbstractMatrix{<:ComplexF64},
ipiv::AbstractVector{Cint},
B::AbstractVecOrMat{<:ComplexF64};
info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
end
nrhs = size(B, 2)
ccall(("zgetrs_", libacc), Cvoid,
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, Ref{Cint},
Ptr{Cint}, Ptr{ComplexF64}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
end

function aa_getrs!(trans::AbstractChar,
A::AbstractMatrix{<:ComplexF32},
ipiv::AbstractVector{Cint},
B::AbstractVecOrMat{<:ComplexF32};
info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
end
nrhs = size(B, 2)
ccall(("cgetrs_", libacc), Cvoid,
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, Ref{Cint},
Ptr{Cint}, Ptr{ComplexF32}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B
end

function aa_getrs!(trans::AbstractChar,
A::AbstractMatrix{<:Float64},
ipiv::AbstractVector{Cint},
Expand Down Expand Up @@ -134,6 +223,14 @@ function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u,
PREALLOCATED_APPLE_LU
end

function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
A = rand(eltype(A), 0, 0)
luinst = ArrayInterface.lu_instance(A)
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
end

function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
kwargs...)
A = cache.A
Expand Down
10 changes: 3 additions & 7 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
__conditioning(assump) === OperatorCondition.WellConditioned)
if length(b) <= 10
DefaultAlgorithmChoice.GenericLUFactorization
elseif VERSION >= v"1.8" && appleaccelerate_isavailable() &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
(usemkl && length(b) <= 200)) &&
Expand All @@ -173,8 +171,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.LUFactorization
Expand All @@ -183,8 +180,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.QRFactorization
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.LUFactorization
Expand Down
99 changes: 98 additions & 1 deletion src/mkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,46 @@ to avoid allocations and does not require libblastrampoline.
"""
struct MKLLUFactorization <: AbstractFactorization end

function getrf!(A::AbstractMatrix{<:ComplexF64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
end
ccall((@blasfunc(zgetrf_), MKL_jll.libmkl_rt), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type
end

function getrf!(A::AbstractMatrix{<:ComplexF32};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
end
ccall((@blasfunc(cgetrf_), MKL_jll.libmkl_rt), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type
end

function getrf!(A::AbstractMatrix{<:Float64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
Expand Down Expand Up @@ -48,6 +88,56 @@ function getrf!(A::AbstractMatrix{<:Float32};
A, ipiv, info[], info #Error code is stored in LU factorization type
end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:ComplexF64},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:ComplexF64};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
end
nrhs = size(B, 2)
ccall(("zgetrs_", MKL_jll.libmkl_rt), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B
end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:ComplexF32},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:ComplexF32};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
end
nrhs = size(B, 2)
ccall(("cgetrs_", MKL_jll.libmkl_rt), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B
end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:Float64},
ipiv::AbstractVector{BlasInt},
Expand Down Expand Up @@ -106,12 +196,19 @@ const PREALLOCATED_MKL_LU = begin
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

function init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
PREALLOCATED_MKL_LU
end

function LinearSolve.init_cacheval(alg::MKLLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
A = rand(eltype(A), 0, 0)
ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
kwargs...)
A = cache.A
Expand Down
3 changes: 2 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,12 @@ end
for alg in test_algs
@testset "$alg" begin
test_interface(alg, prob1, prob2)
VERSION >= v"1.9" && (alg isa MKLLUFactorization || test_interface(alg, prob3, prob4))
VERSION >= v"1.9" && test_interface(alg, prob3, prob4)
end
end
if LinearSolve.appleaccelerate_isavailable()
test_interface(AppleAccelerateLUFactorization(), prob1, prob2)
test_interface(AppleAccelerateLUFactorization(), prob3, prob4)
end
end

Expand Down

0 comments on commit a9b5581

Please sign in to comment.