Skip to content

Commit

Permalink
Handle wrapper types over GPU arrays correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2023
1 parent ada5366 commit 01469be
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.21.0"
version = "2.21.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
6 changes: 3 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -110,7 +110,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
end

# A === nothing case
function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -119,7 +119,7 @@ function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::Opera
end

# Ambiguity handling
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray,
assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
Expand Down
34 changes: 29 additions & 5 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
if A isa AbstractSparseMatrixCSC
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
elseif A isa GPUArraysCore.AnyGPUArray
fact = lu(A; check = false)
elseif !ArrayInterface.can_setindex(typeof(A))
fact = lu(A, alg.pivot, check = false)
else
Expand All @@ -98,6 +100,16 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, A, b
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
end

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
if alg isa LUFactorization
return lu(A; check=false)
else
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check=false)
end
end

const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
Expand Down Expand Up @@ -143,7 +155,7 @@ end
function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AnyGPUArray)
fact = qr!(A, alg.pivot)
else
fact = qr(A) # CUDA.jl does not allow other args!
Expand All @@ -160,6 +172,12 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
end

function init_cacheval(alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
A isa GPUArraysCore.AnyGPUArray && return qr(A)
return qr(A, alg.pivot)
end

const PREALLOCATED_QR = ArrayInterface.qr_instance(rand(1, 1))

function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
Expand Down Expand Up @@ -204,6 +222,8 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky(A; shift = alg.shift, check = false, perm = alg.perm)
elseif A isa GPUArraysCore.AnyGPUArray
fact = cholesky(A; check = false)
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot; check = false)
else
Expand All @@ -218,9 +238,13 @@ function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl,
cholesky(A)
end

function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, b, u, Pl,
Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
cholesky(A; check=false)
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
end

Expand Down Expand Up @@ -968,7 +992,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray,
A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -999,7 +1023,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
if A isa SparseMatrixCSC || A isa GPUArraysCore.AnyGPUArray || A isa SMatrix
fact = cholesky(Symmetric((A)' * A); check = false)
else
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
Expand Down
16 changes: 16 additions & 0 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,19 @@ using BlockDiagonals

@test solve(prob1, SimpleGMRES(; blocksize = 2)).u solve(prob2, SimpleGMRES()).u
end

# Test Dispatches for Adjoint/Transpose Types
A = Matrix(Hermitian(rand(5, 5) + I)) |> cu
b = rand(5) |> cu
prob1 = LinearProblem(A', b)
prob2 = LinearProblem(transpose(A), b)

@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
sol = solve(prob1, alg; alias_A = false)
@test norm(A' * sol.u .- b) < 1e-5
@show sol.u, A' \ b

sol = solve(prob2, alg; alias_A = false)
@test norm(transpose(A) * sol.u .- b) < 1e-5
end

0 comments on commit 01469be

Please sign in to comment.