-
-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Test case: ```julia using LinearSolve, blis_jll A = rand(4, 4) b = rand(4) prob = LinearProblem(A, b) sol = solve(prob,LinearSolve.BLISLUFactorization()) sol.u ``` throws: ```julia julia> sol = solve(prob,LinearSolve.BLISLUFactorization()) ERROR: TypeError: in ccall: first argument not a pointer or valid constant expression, expected Ptr, got a value of type Tuple{Symbol, Ptr{Nothing}} Stacktrace: [1] getrf!(A::Matrix{Float64}; ipiv::Vector{Int64}, info::Base.RefValue{Int64}, check::Bool) @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:67 [2] getrf! @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:55 [inlined] [3] #solve!#9 @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:222 [inlined] [4] solve! @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:216 [inlined] [5] #solve!#6 @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:209 [inlined] [6] solve! @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:208 [inlined] [7] #solve#5 @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:205 [inlined] [8] solve(::LinearProblem{…}, ::LinearSolve.BLISLUFactorization) @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:202 [9] top-level scope @ REPL[8]:1 Some type information was truncated. Use `show(err)` to see complete types. ```
- Loading branch information
1 parent
a455e27
commit 2fea1c2
Showing
3 changed files
with
253 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
module LinearSolveBLISExt | ||
|
||
using Libdl | ||
using blis_jll | ||
using LinearAlgebra | ||
using LinearSolve | ||
|
||
using LinearAlgebra: BlasInt, LU | ||
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, | ||
@blasfunc, chkargsok | ||
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase | ||
|
||
const global libblis = dlopen(blis_jll.blis_path) | ||
|
||
function getrf!(A::AbstractMatrix{<:ComplexF64}; | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), | ||
info = Ref{BlasInt}(), | ||
check = false) | ||
require_one_based_indexing(A) | ||
check && chkfinite(A) | ||
chkstride1(A) | ||
m, n = size(A) | ||
lda = max(1, stride(A, 2)) | ||
if isempty(ipiv) | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) | ||
end | ||
ccall((@blasfunc(zgetrf_), libblis), Cvoid, | ||
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, | ||
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), | ||
m, n, A, lda, ipiv, info) | ||
chkargsok(info[]) | ||
A, ipiv, info[], info #Error code is stored in LU factorization type | ||
end | ||
|
||
function getrf!(A::AbstractMatrix{<:ComplexF32}; | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), | ||
info = Ref{BlasInt}(), | ||
check = false) | ||
require_one_based_indexing(A) | ||
check && chkfinite(A) | ||
chkstride1(A) | ||
m, n = size(A) | ||
lda = max(1, stride(A, 2)) | ||
if isempty(ipiv) | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) | ||
end | ||
ccall((@blasfunc(cgetrf_), libblis), Cvoid, | ||
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, | ||
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), | ||
m, n, A, lda, ipiv, info) | ||
chkargsok(info[]) | ||
A, ipiv, info[], info #Error code is stored in LU factorization type | ||
end | ||
|
||
function getrf!(A::AbstractMatrix{<:Float64}; | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), | ||
info = Ref{BlasInt}(), | ||
check = false) | ||
require_one_based_indexing(A) | ||
check && chkfinite(A) | ||
chkstride1(A) | ||
m, n = size(A) | ||
lda = max(1, stride(A, 2)) | ||
if isempty(ipiv) | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) | ||
end | ||
ccall((@blasfunc(dgetrf_), libblis), Cvoid, | ||
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, | ||
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), | ||
m, n, A, lda, ipiv, info) | ||
chkargsok(info[]) | ||
A, ipiv, info[], info #Error code is stored in LU factorization type | ||
end | ||
|
||
function getrf!(A::AbstractMatrix{<:Float32}; | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), | ||
info = Ref{BlasInt}(), | ||
check = false) | ||
require_one_based_indexing(A) | ||
check && chkfinite(A) | ||
chkstride1(A) | ||
m, n = size(A) | ||
lda = max(1, stride(A, 2)) | ||
if isempty(ipiv) | ||
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) | ||
end | ||
ccall((@blasfunc(sgetrf_), libblis), Cvoid, | ||
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, | ||
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), | ||
m, n, A, lda, ipiv, info) | ||
chkargsok(info[]) | ||
A, ipiv, info[], info #Error code is stored in LU factorization type | ||
end | ||
|
||
function getrs!(trans::AbstractChar, | ||
A::AbstractMatrix{<:ComplexF64}, | ||
ipiv::AbstractVector{BlasInt}, | ||
B::AbstractVecOrMat{<:ComplexF64}; | ||
info = Ref{BlasInt}()) | ||
require_one_based_indexing(A, ipiv, B) | ||
LinearAlgebra.LAPACK.chktrans(trans) | ||
chkstride1(A, B, ipiv) | ||
n = LinearAlgebra.checksquare(A) | ||
if n != size(B, 1) | ||
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) | ||
end | ||
if n != length(ipiv) | ||
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) | ||
end | ||
nrhs = size(B, 2) | ||
ccall(("zgetrs_", libblis), Cvoid, | ||
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, | ||
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong), | ||
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, | ||
1) | ||
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) | ||
B | ||
end | ||
|
||
function getrs!(trans::AbstractChar, | ||
A::AbstractMatrix{<:ComplexF32}, | ||
ipiv::AbstractVector{BlasInt}, | ||
B::AbstractVecOrMat{<:ComplexF32}; | ||
info = Ref{BlasInt}()) | ||
require_one_based_indexing(A, ipiv, B) | ||
LinearAlgebra.LAPACK.chktrans(trans) | ||
chkstride1(A, B, ipiv) | ||
n = LinearAlgebra.checksquare(A) | ||
if n != size(B, 1) | ||
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) | ||
end | ||
if n != length(ipiv) | ||
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) | ||
end | ||
nrhs = size(B, 2) | ||
ccall(("cgetrs_", libblis), Cvoid, | ||
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, | ||
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong), | ||
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, | ||
1) | ||
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) | ||
B | ||
end | ||
|
||
function getrs!(trans::AbstractChar, | ||
A::AbstractMatrix{<:Float64}, | ||
ipiv::AbstractVector{BlasInt}, | ||
B::AbstractVecOrMat{<:Float64}; | ||
info = Ref{BlasInt}()) | ||
require_one_based_indexing(A, ipiv, B) | ||
LinearAlgebra.LAPACK.chktrans(trans) | ||
chkstride1(A, B, ipiv) | ||
n = LinearAlgebra.checksquare(A) | ||
if n != size(B, 1) | ||
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) | ||
end | ||
if n != length(ipiv) | ||
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) | ||
end | ||
nrhs = size(B, 2) | ||
ccall(("dgetrs_", libblis), Cvoid, | ||
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, | ||
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong), | ||
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, | ||
1) | ||
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) | ||
B | ||
end | ||
|
||
function getrs!(trans::AbstractChar, | ||
A::AbstractMatrix{<:Float32}, | ||
ipiv::AbstractVector{BlasInt}, | ||
B::AbstractVecOrMat{<:Float32}; | ||
info = Ref{BlasInt}()) | ||
require_one_based_indexing(A, ipiv, B) | ||
LinearAlgebra.LAPACK.chktrans(trans) | ||
chkstride1(A, B, ipiv) | ||
n = LinearAlgebra.checksquare(A) | ||
if n != size(B, 1) | ||
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) | ||
end | ||
if n != length(ipiv) | ||
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) | ||
end | ||
nrhs = size(B, 2) | ||
ccall(("sgetrs_", libblis), Cvoid, | ||
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt}, | ||
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong), | ||
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, | ||
1) | ||
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) | ||
B | ||
end | ||
|
||
default_alias_A(::BLISLUFactorization, ::Any, ::Any) = false | ||
default_alias_b(::BLISLUFactorization, ::Any, ::Any) = false | ||
|
||
const PREALLOCATED_BLIS_LU = begin | ||
A = rand(0, 0) | ||
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}() | ||
end | ||
|
||
function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr, | ||
maxiters::Int, abstol, reltol, verbose::Bool, | ||
assumptions::OperatorAssumptions) | ||
PREALLOCATED_BLIS_LU | ||
end | ||
|
||
function LinearSolve.init_cacheval(alg::BLISLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr, | ||
maxiters::Int, abstol, reltol, verbose::Bool, | ||
assumptions::OperatorAssumptions) | ||
A = rand(eltype(A), 0, 0) | ||
ArrayInterface.lu_instance(A), Ref{BlasInt}() | ||
end | ||
|
||
function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization; | ||
kwargs...) | ||
A = cache.A | ||
A = convert(AbstractMatrix, A) | ||
if cache.isfresh | ||
cacheval = @get_cacheval(cache, :BLISLUFactorization) | ||
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2]) | ||
fact = LU(res[1:3]...), res[4] | ||
cache.cacheval = fact | ||
cache.isfresh = false | ||
end | ||
|
||
y = ldiv!(cache.u, @get_cacheval(cache, :BLISLUFactorization)[1], cache.b) | ||
SciMLBase.build_linear_solution(alg, y, nothing, cache) | ||
|
||
#= | ||
A, info = @get_cacheval(cache, :BLISLUFactorization) | ||
LinearAlgebra.require_one_based_indexing(cache.u, cache.b) | ||
m, n = size(A, 1), size(A, 2) | ||
if m > n | ||
Bc = copy(cache.b) | ||
getrs!('N', A.factors, A.ipiv, Bc; info) | ||
return copyto!(cache.u, 1, Bc, 1, n) | ||
else | ||
copyto!(cache.u, cache.b) | ||
getrs!('N', A.factors, A.ipiv, cache.u; info) | ||
end | ||
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) | ||
=# | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters