Skip to content

Commit

Permalink
[CUSOLVER] Interface gesv! and gels! (#2406)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Dec 16, 2024
1 parent 19a08ef commit 4e9513b
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 65 deletions.
1 change: 1 addition & 0 deletions lib/cusolver/CUSOLVER.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("libcusolverMg.jl")
include("libcusolverRF.jl")

# low-level wrappers
include("helpers.jl")
include("error.jl")
include("base.jl")
include("sparse.jl")
Expand Down
60 changes: 60 additions & 0 deletions lib/cusolver/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,63 @@ function Base.convert(::Type{cusolverDirectMode_t}, direct::Char)
throw(ArgumentError("Unknown direction mode $direct."))
end
end

"""
    convert(::Type{cusolverIRSRefinement_t}, irs::String)

Translate the name of an iterative refinement solver into the matching
`CUSOLVER_IRS_REFINE_*` enum value.

Accepted names are `"NOT_SET"`, `"NONE"`, `"CLASSICAL"`, `"CLASSICAL_GMRES"`,
`"GMRES"`, `"GMRES_GMRES"` and `"GMRES_NOPCOND"`.

# Throws
- `ArgumentError` if `irs` is not one of the names listed above.
"""
function Base.convert(::Type{cusolverIRSRefinement_t}, irs::String)
    # Every branch must compare against `irs`: a bare string literal is not a
    # `Bool`, and using one as a condition raises a `TypeError` at run time.
    if irs == "NOT_SET"
        return CUSOLVER_IRS_REFINE_NOT_SET
    elseif irs == "NONE"
        return CUSOLVER_IRS_REFINE_NONE
    elseif irs == "CLASSICAL"
        return CUSOLVER_IRS_REFINE_CLASSICAL
    elseif irs == "CLASSICAL_GMRES"
        return CUSOLVER_IRS_REFINE_CLASSICAL_GMRES
    elseif irs == "GMRES"
        return CUSOLVER_IRS_REFINE_GMRES
    elseif irs == "GMRES_GMRES"
        return CUSOLVER_IRS_REFINE_GMRES_GMRES
    elseif irs == "GMRES_NOPCOND"
        return CUSOLVER_IRS_REFINE_GMRES_NOPCOND
    else
        throw(ArgumentError("Unknown iterative refinement solver $irs."))
    end
end

"""
    convert(::Type{cusolverPrecType_t}, T::String)

Translate a precision name such as `"R_32F"` or `"C_16BF"` into the matching
`CUSOLVER_*` precision enum value.

# Throws
- `ArgumentError` if `T` does not name one of the precisions handled below.
"""
function Base.convert(::Type{cusolverPrecType_t}, T::String)
    # Real precisions
    T == "R_16F" && return CUSOLVER_R_16F
    T == "R_16BF" && return CUSOLVER_R_16BF
    T == "R_TF32" && return CUSOLVER_R_TF32
    T == "R_32F" && return CUSOLVER_R_32F
    T == "R_64F" && return CUSOLVER_R_64F
    # Complex precisions
    T == "C_16F" && return CUSOLVER_C_16F
    T == "C_16BF" && return CUSOLVER_C_16BF
    T == "C_TF32" && return CUSOLVER_C_TF32
    T == "C_32F" && return CUSOLVER_C_32F
    T == "C_64F" && return CUSOLVER_C_64F
    throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!"))
end

"""
    convert(::Type{cusolverPrecType_t}, T::DataType)

Translate a Julia element type into the matching `CUSOLVER_*` precision enum
value. Only `Float32`, `Float64`, `Complex{Float32}` and `Complex{Float64}` are
handled.

# Throws
- `ArgumentError` for any other type.
"""
function Base.convert(::Type{cusolverPrecType_t}, T::DataType)
    T === Float32 && return CUSOLVER_R_32F
    T === Float64 && return CUSOLVER_R_64F
    T === Complex{Float32} && return CUSOLVER_C_32F
    T === Complex{Float64} && return CUSOLVER_C_64F
    throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!"))
end
108 changes: 108 additions & 0 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,114 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32),
end
end

# gesv
"""
    gesv!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0,
          tol=0.0, tol_inner=0.0) -> (X, info)

Solve the square linear system `A * X = B` with cuSOLVER's iterative refinement
solver (`cusolverDnIRSXgesv`), writing the solution into `X`.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver's fallback option.
- `residual_history`: when `true`, request the residual history on the returned
  `info` object.
- `irs_precision`: lowest precision used by the solver. `"AUTO"` selects the
  precision matching `T`; otherwise the name must be one of the precisions
  accepted for `T` (see the table in the body).
- `refinement_solver`: name of the refinement solver, forwarded to
  `cusolverDnIRSParamsSetRefinementSolver`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: forwarded to the matching
  `cusolverDnIRSParamsSet*` call only when non-zero; a zero value leaves the
  parameter object's own setting untouched.

# Returns
The tuple `(X, info)` where `info` is the `CuSolverIRSInformation` object that
was passed to the solver.

# Throws
- `ErrorException` if `irs_precision` is not supported for element type `T`.
- Whatever `checksquare` raises if `A` is not square, and whatever
  `chklapackerror` raises for the solver status flag.
"""
function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true,
               residual_history::Bool=false, irs_precision::String="AUTO",
               refinement_solver::String="CLASSICAL", maxiters::Int=0,
               maxiters_inner::Int=0, tol::Float64=0.0,
               tol_inner::Real=0.0) where T <: BlasFloat
    # NOTE: `tol_inner` used to be written `tol_inner=Float64=0.0`, which is a
    # chained assignment rather than a type annotation. It is typed `Real` so
    # that any value the untyped keyword used to accept still works.

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    n = checksquare(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    # Precisions accepted as lowest precision for each element type. The first
    # entry is the one matching `T` and is what "AUTO" resolves to.
    # `BlasFloat` only contains these four types, so the last branch is ComplexF64.
    supported = if T == Float32
        ("R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == Float64
        ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == ComplexF32
        ("C_32F", "C_16F", "C_16BF", "C_TF32")
    else
        ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32")
    end
    if irs_precision == "AUTO"
        irs_precision = first(supported)
    else
        irs_precision in supported || error("$irs_precision is not supported.")
    end

    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)
    # Zero means "not specified": only override what the caller asked for.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    if fallback
        cusolverDnIRSParamsEnableFallback(params)
    else
        cusolverDnIRSParamsDisableFallback(params)
    end
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Workspace size query handed to `with_workspace`.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver status flag back from the handle's info buffer and
    # turn a non-zero status into an exception.
    flag = @allowscalar dh.info[1]
    chklapackerror(BlasInt(flag))

    return X, info
end

# gels
"""
    gels!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0,
          tol=0.0, tol_inner=0.0) -> (X, info)

Solve the least-squares problem for the `m × n` matrix `A` and right-hand
side(s) `B` with cuSOLVER's iterative refinement solver (`cusolverDnIRSXgels`),
writing the solution into `X`.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver's fallback option.
- `residual_history`: when `true`, request the residual history on the returned
  `info` object.
- `irs_precision`: lowest precision used by the solver. `"AUTO"` selects the
  precision matching `T`; otherwise the name must be one of the precisions
  accepted for `T` (see the table in the body).
- `refinement_solver`: name of the refinement solver, forwarded to
  `cusolverDnIRSParamsSetRefinementSolver`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: forwarded to the matching
  `cusolverDnIRSParamsSet*` call only when non-zero; a zero value leaves the
  parameter object's own setting untouched.

# Returns
The tuple `(X, info)` where `info` is the `CuSolverIRSInformation` object that
was passed to the solver.

# Throws
- `ErrorException` if `irs_precision` is not supported for element type `T`.
- Whatever `chklapackerror` raises for the solver status flag.
"""
function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true,
               residual_history::Bool=false, irs_precision::String="AUTO",
               refinement_solver::String="CLASSICAL", maxiters::Int=0,
               maxiters_inner::Int=0, tol::Float64=0.0,
               tol_inner::Real=0.0) where T <: BlasFloat
    # NOTE: `tol_inner` used to be written `tol_inner=Float64=0.0`, which is a
    # chained assignment rather than a type annotation. It is typed `Real` so
    # that any value the untyped keyword used to accept still works.

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    m, n = size(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    # Precisions accepted as lowest precision for each element type. The first
    # entry is the one matching `T` and is what "AUTO" resolves to.
    # `BlasFloat` only contains these four types, so the last branch is ComplexF64.
    supported = if T == Float32
        ("R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == Float64
        ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == ComplexF32
        ("C_32F", "C_16F", "C_16BF", "C_TF32")
    else
        ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32")
    end
    if irs_precision == "AUTO"
        irs_precision = first(supported)
    else
        irs_precision in supported || error("$irs_precision is not supported.")
    end

    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)
    # Zero means "not specified": only override what the caller asked for.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    if fallback
        cusolverDnIRSParamsEnableFallback(params)
    else
        cusolverDnIRSParamsDisableFallback(params)
    end
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Workspace size query handed to `with_workspace`.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgels(dh, params, info, m, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver status flag back from the handle's info buffer and
    # turn a non-zero status into an exception.
    flag = @allowscalar dh.info[1]
    chklapackerror(BlasInt(flag))

    return X, info
end

# LAPACK
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
@eval begin
Expand Down
14 changes: 0 additions & 14 deletions lib/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Wrapper around a `cusolverDnParams_t` handle. The handle is created with
# `cusolverDnCreateParams` and `cusolverDnDestroyParams` is registered as the
# finalizer of the wrapper.
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        handle = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle)
        wrapper = new(handle[])
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnDestroyParams, wrapper)
    end
end

# Hand the raw handle to `ccall` when a `CuSolverParameters` is passed.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end

# Xpotrf
function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
Expand Down
2 changes: 2 additions & 0 deletions lib/cusolver/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ function description(err)
"an internal operation failed"
elseif err.code == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED
"the matrix type is not supported."
elseif err.code == CUSOLVER_STATUS_NOT_SUPPORTED
"the parameter combination is not supported."
else
"no description for this error"
end
Expand Down
148 changes: 148 additions & 0 deletions lib/cusolver/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# cuSOLVER helper functions

## SparseQRInfo

# Wrapper around a `csrqrInfo_t` handle. The handle is created with
# `cusolverSpCreateCsrqrInfo` and `cusolverSpDestroyCsrqrInfo` is registered as
# the finalizer of the wrapper.
mutable struct SparseQRInfo
    info::csrqrInfo_t

    function SparseQRInfo()
        handle = Ref{csrqrInfo_t}()
        cusolverSpCreateCsrqrInfo(handle)
        wrapper = new(handle[])
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverSpDestroyCsrqrInfo, wrapper)
    end
end

# Hand the raw handle to `ccall` when a `SparseQRInfo` is passed.
function Base.unsafe_convert(::Type{csrqrInfo_t}, info::SparseQRInfo)
    return info.info
end


## SparseCholeskyInfo

# Wrapper around a `csrcholInfo_t` handle. The handle is created with
# `cusolverSpCreateCsrcholInfo` and `cusolverSpDestroyCsrcholInfo` is registered
# as the finalizer of the wrapper.
mutable struct SparseCholeskyInfo
    info::csrcholInfo_t

    function SparseCholeskyInfo()
        handle = Ref{csrcholInfo_t}()
        cusolverSpCreateCsrcholInfo(handle)
        wrapper = new(handle[])
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverSpDestroyCsrcholInfo, wrapper)
    end
end

# Hand the raw handle to `ccall` when a `SparseCholeskyInfo` is passed.
function Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo)
    return info.info
end


## CuSolverParameters

# Wrapper around a `cusolverDnParams_t` handle. The handle is created with
# `cusolverDnCreateParams` and `cusolverDnDestroyParams` is registered as the
# finalizer of the wrapper.
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        handle = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle)
        wrapper = new(handle[])
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnDestroyParams, wrapper)
    end
end

# Hand the raw handle to `ccall` when a `CuSolverParameters` is passed.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end


## CuSolverIRSParameters

# Wrapper around a `cusolverDnIRSParams_t` handle. The handle is created with
# `cusolverDnIRSParamsCreate` and `cusolverDnIRSParamsDestroy` is registered as
# the finalizer of the wrapper.
mutable struct CuSolverIRSParameters
    parameters::cusolverDnIRSParams_t

    function CuSolverIRSParameters()
        handle = Ref{cusolverDnIRSParams_t}()
        cusolverDnIRSParamsCreate(handle)
        wrapper = new(handle[])
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnIRSParamsDestroy, wrapper)
    end
end

# Hand the raw handle to `ccall` when a `CuSolverIRSParameters` is passed.
function Base.unsafe_convert(::Type{cusolverDnIRSParams_t}, params::CuSolverIRSParameters)
    return params.parameters
end

"""
    get_info(params::CuSolverIRSParameters, field::Symbol)

Query a setting stored in `params`. The only supported `field` is `:maxiters`,
which is read with `cusolverDnIRSParamsGetMaxIters` and returned as a `Cint`.
Any other `field` raises an error.
"""
function get_info(params::CuSolverIRSParameters, field::Symbol)
    # Reject unknown fields up front; `:maxiters` is the only one available.
    field == :maxiters || error("The information $field is incorrect.")
    value = Ref{Cint}()
    cusolverDnIRSParamsGetMaxIters(params, value)
    return value[]
end


## CuSolverIRSInformation

# Wrapper around a `cusolverDnIRSInfos_t` handle. The handle is created with
# `cusolverDnIRSInfosCreate` and `cusolverDnIRSInfosDestroy` is registered as
# the finalizer of the wrapper.
mutable struct CuSolverIRSInformation
    information::cusolverDnIRSInfos_t

    function CuSolverIRSInformation()
        handle = Ref{cusolverDnIRSInfos_t}()
        cusolverDnIRSInfosCreate(handle)
        wrapper = new(handle[])
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnIRSInfosDestroy, wrapper)
    end
end

# Hand the raw handle to `ccall` when a `CuSolverIRSInformation` is passed.
function Base.unsafe_convert(::Type{cusolverDnIRSInfos_t}, info::CuSolverIRSInformation)
    return info.information
end

"""
    get_info(info::CuSolverIRSInformation, field::Symbol)

Query an integer statistic stored in `info` and return it as a `Cint`.

Supported values of `field`:
- `:niters`       (read with `cusolverDnIRSInfosGetNiters`)
- `:outer_niters` (read with `cusolverDnIRSInfosGetOuterNiters`)
- `:maxiters`     (read with `cusolverDnIRSInfosGetMaxIters`)

Any other `field` raises an error.
"""
function get_info(info::CuSolverIRSInformation, field::Symbol)
    # Pick the cuSOLVER getter for the requested field. All supported fields
    # are written into a `Ref{Cint}`, so the call itself is shared below.
    # NOTE: `:residual_history` (cusolverDnIRSInfosGetResidualHistory) is not
    # wired up here.
    getter = if field == :niters
        cusolverDnIRSInfosGetNiters
    elseif field == :outer_niters
        cusolverDnIRSInfosGetOuterNiters
    elseif field == :maxiters
        cusolverDnIRSInfosGetMaxIters
    else
        error("The information $field is incorrect.")
    end
    value = Ref{Cint}()
    getter(info, value)
    return value[]
end


## MatrixDescriptor

# Wrapper around a `cudaLibMgMatrixDesc_t` created with
# `cusolverMgCreateMatrixDesc` for the matrix `a` on the device grid `grid`.
# By default the block sizes are the full matrix dimensions and the element
# type is `eltype(a)`.
mutable struct MatrixDescriptor
    desc::cudaLibMgMatrixDesc_t

    function MatrixDescriptor(a, grid; rowblocks=size(a, 1), colblocks=size(a, 2),
                              elta=eltype(a))
        desc_ref = Ref{cudaLibMgMatrixDesc_t}()
        nrows = size(a, 1)
        ncols = size(a, 2)
        cusolverMgCreateMatrixDesc(desc_ref, nrows, ncols, rowblocks, colblocks, elta, grid)
        return new(desc_ref[])
    end
end

# Hand the raw descriptor to `ccall` when a `MatrixDescriptor` is passed.
function Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor)
    return obj.desc
end


## DeviceGrid

# Wrapper around a `cudaLibMgGrid_t` created with `cusolverMgCreateDeviceGrid`.
# The constructor asserts that `num_row_devs == 1`.
mutable struct DeviceGrid
    desc::cudaLibMgGrid_t

    function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)
        @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1."
        grid_ref = Ref{cudaLibMgGrid_t}()
        cusolverMgCreateDeviceGrid(grid_ref, num_row_devs, num_col_devs, deviceIds, mapping)
        return new(grid_ref[])
    end
end

# Hand the raw grid handle to `ccall` when a `DeviceGrid` is passed.
function Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid)
    return obj.desc
end
25 changes: 0 additions & 25 deletions lib/cusolver/multigpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,6 @@
# NOTE: in the cublasMg preview, which also relies on this functionality, a separate library
# called 'cudalibmg' is introduced. factor this out when we actually ship that.

# Wrapper around a `cudaLibMgMatrixDesc_t` created with
# `cusolverMgCreateMatrixDesc` for the matrix `a` on the device grid `grid`.
# By default the block sizes are the full matrix dimensions and the element
# type is `eltype(a)`.
mutable struct MatrixDescriptor
    desc::cudaLibMgMatrixDesc_t

    function MatrixDescriptor(a, grid; rowblocks=size(a, 1), colblocks=size(a, 2),
                              elta=eltype(a))
        desc_ref = Ref{cudaLibMgMatrixDesc_t}()
        nrows = size(a, 1)
        ncols = size(a, 2)
        cusolverMgCreateMatrixDesc(desc_ref, nrows, ncols, rowblocks, colblocks, elta, grid)
        return new(desc_ref[])
    end
end

# Hand the raw descriptor to `ccall` when a `MatrixDescriptor` is passed.
function Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor)
    return obj.desc
end

# Wrapper around a `cudaLibMgGrid_t` created with `cusolverMgCreateDeviceGrid`.
# The constructor asserts that `num_row_devs == 1`.
mutable struct DeviceGrid
    desc::cudaLibMgGrid_t

    function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)
        @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1."
        grid_ref = Ref{cudaLibMgGrid_t}()
        cusolverMgCreateDeviceGrid(grid_ref, num_row_devs, num_col_devs, deviceIds, mapping)
        return new(grid_ref[])
    end
end

# Hand the raw grid handle to `ccall` when a `DeviceGrid` is passed.
function Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid)
    return obj.desc
end

function allocateBuffers(n_row_devs, n_col_devs, mat::Matrix)
mat_row_block_size = div(size(mat, 1), n_row_devs)
mat_col_block_size = div(size(mat, 2), n_col_devs)
Expand Down
Loading

0 comments on commit 4e9513b

Please sign in to comment.