Skip to content

Commit

Permalink
32-bit BLIS form
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 12, 2023
1 parent afcc28e commit 703b3ed
Showing 1 changed file with 84 additions and 53 deletions.
137 changes: 84 additions & 53 deletions ext/LinearSolveBLISExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,90 +13,91 @@ using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCac
const global libblis = blis_jll.blis

function getrf!(A::AbstractMatrix{<:ComplexF64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
end
ccall((@blasfunc(zgetrf_), libblis), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
ccall(("zgetrf_", libblis), Cvoid,
(Ref{Cint}, Ref{Cint}, Ptr{ComplexF64},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function getrf!(A::AbstractMatrix{<:ComplexF32};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
end
ccall((@blasfunc(cgetrf_), libblis), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
ccall(("cgetrf_", libblis), Cvoid,
(Ref{Cint}, Ref{Cint}, Ptr{ComplexF32},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function getrf!(A::AbstractMatrix{<:Float64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
end
ccall((@blasfunc(dgetrf_), libblis), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
ccall(("dgetrf_", libblis), Cvoid,
(Ref{Cint}, Ref{Cint}, Ptr{Float64},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function getrf!(A::AbstractMatrix{<:Float32};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
info = Ref{Cint}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
end
ccall((@blasfunc(sgetrf_), libblis), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),

ccall(("sgetrf_", libblis), Cvoid,
(Ref{Cint}, Ref{Cint}, Ptr{Float32},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:ComplexF64},
ipiv::AbstractVector{BlasInt},
ipiv::AbstractVector{Cint},
B::AbstractVecOrMat{<:ComplexF64};
info = Ref{BlasInt}())
info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
Expand All @@ -109,19 +110,18 @@ function getrs!(trans::AbstractChar,
end
nrhs = size(B, 2)
ccall(("zgetrs_", libblis), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, Ref{Cint},
Ptr{Cint}, Ptr{ComplexF64}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B
end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:ComplexF32},
ipiv::AbstractVector{BlasInt},
ipiv::AbstractVector{Cint},
B::AbstractVecOrMat{<:ComplexF32};
info = Ref{BlasInt}())
info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
Expand All @@ -134,8 +134,8 @@ function getrs!(trans::AbstractChar,
end
nrhs = size(B, 2)
ccall(("cgetrs_", libblis), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, Ref{Cint},
Ptr{Cint}, Ptr{ComplexF32}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
Expand All @@ -144,9 +144,9 @@ end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:Float64},
ipiv::AbstractVector{BlasInt},
ipiv::AbstractVector{Cint},
B::AbstractVecOrMat{<:Float64};
info = Ref{BlasInt}())
info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
Expand All @@ -159,8 +159,8 @@ function getrs!(trans::AbstractChar,
end
nrhs = size(B, 2)
ccall(("dgetrs_", libblis), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
Expand All @@ -169,9 +169,9 @@ end

function getrs!(trans::AbstractChar,
A::AbstractMatrix{<:Float32},
ipiv::AbstractVector{BlasInt},
ipiv::AbstractVector{Cint},
B::AbstractVecOrMat{<:Float32};
info = Ref{BlasInt}())
info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
Expand All @@ -184,8 +184,8 @@ function getrs!(trans::AbstractChar,
end
nrhs = size(B, 2)
ccall(("sgetrs_", libblis), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint},
Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
Expand All @@ -195,24 +195,54 @@ end
default_alias_A(::BLISLUFactorization, ::Any, ::Any) = false
default_alias_b(::BLISLUFactorization, ::Any, ::Any) = false

const PREALLOCATED_BLIS_LU = begin
const PREALLOCATED_APPLE_LU = begin
A = rand(0, 0)
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
luinst = ArrayInterface.lu_instance(A)
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
PREALLOCATED_BLIS_LU
PREALLOCATED_APPLE_LU
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
A = rand(eltype(A), 0, 0)
ArrayInterface.lu_instance(A), Ref{BlasInt}()
luinst = ArrayInterface.lu_instance(A)
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
end

function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :BLISLUFactorization)
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
fact = LU(res[1:3]...), res[4]
cache.cacheval = fact
cache.isfresh = false
end

A, info = @get_cacheval(cache, :BLISLUFactorization)
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
m, n = size(A, 1), size(A, 2)
if m > n
Bc = copy(cache.b)
getrs!('N', A.factors, A.ipiv, Bc; info)
return copyto!(cache.u, 1, Bc, 1, n)
else
copyto!(cache.u, cache.b)
getrs!('N', A.factors, A.ipiv, cache.u; info)
end

SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

#=
function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;
kwargs...)
A = cache.A
Expand Down Expand Up @@ -244,5 +274,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
=#
end
=#

end

0 comments on commit 703b3ed

Please sign in to comment.