Skip to content

Commit

Permalink
Merge pull request #495 from SciML/ap/fix_le
Browse files Browse the repository at this point in the history
Fix the formatting
  • Loading branch information
avik-pal authored Apr 25, 2024
2 parents 89ea6ee + 69cabe9 commit 4afec5a
Show file tree
Hide file tree
Showing 3 changed files with 1,444 additions and 1,441 deletions.
125 changes: 64 additions & 61 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,64 @@
module LinearSolveCUDAExt

using CUDA
using LinearSolve
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
using SciMLBase: AbstractSciMLOperator

function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if LinearSolve.cudss_loaded(A)
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
else
if !LinearSolve.ALREADY_WARNED_CUDSS[]
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
LinearSolve.ALREADY_WARNED_CUDSS[] = true
end
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
end
end

function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
if !LinearSolve.CUDSS_LOADED[]
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
end
nothing
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
if cache.isfresh
fact = qr(CUDA.CuArray(cache.A))
cache.cacheval = fact
cache.isfresh = false
end
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
cache.u .= y
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
qr(CUDA.CuArray(A))
end

function LinearSolve.init_cacheval(::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function LinearSolve.init_cacheval(::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function LinearSolve.init_cacheval(::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

end
module LinearSolveCUDAExt

using CUDA
using LinearSolve
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
using SciMLBase: AbstractSciMLOperator

function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if LinearSolve.cudss_loaded(A)
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
else
if !LinearSolve.ALREADY_WARNED_CUDSS[]
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
LinearSolve.ALREADY_WARNED_CUDSS[] = true
end
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
end
end

function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
if !LinearSolve.CUDSS_LOADED[]
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
end
nothing
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
if cache.isfresh
fact = qr(CUDA.CuArray(cache.A))
cache.cacheval = fact
cache.isfresh = false
end
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
cache.u .= y
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
qr(CUDA.CuArray(A))
end

function LinearSolve.init_cacheval(
::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function LinearSolve.init_cacheval(
::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function LinearSolve.init_cacheval(
::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

end
Loading

0 comments on commit 4afec5a

Please sign in to comment.