Skip to content

Commit

Permalink
Merge pull request #447 from SciML/ap/normal_cholesky_dispatches
Browse files Browse the repository at this point in the history
Handle wrapper types over GPU arrays correctly
  • Loading branch information
ChrisRackauckas authored Dec 18, 2023
2 parents ada5366 + 65d76be commit 3b4f4ed
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.21.0"
version = "2.21.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
6 changes: 3 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -110,7 +110,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
end

# A === nothing case
function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -119,7 +119,7 @@ function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::Opera
end

# Ambiguity handling
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray,
assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
Expand Down
35 changes: 30 additions & 5 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
if A isa AbstractSparseMatrixCSC
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
elseif A isa GPUArraysCore.AnyGPUArray
fact = lu(A; check = false)
elseif !ArrayInterface.can_setindex(typeof(A))
fact = lu(A, alg.pivot, check = false)
else
Expand All @@ -98,6 +100,17 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, A, b
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
end

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
        A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    # Build the cache entry for lazily-wrapped (Adjoint/Transpose) matrices.
    # `check = false` defers singularity detection to the solve step.
    if alg isa LUFactorization
        return lu(A; check = false)
    else
        # NOTE(review): generic pivoted LU on a GPU wrapper array is skipped
        # here (no cached factorization) — presumably unsupported; confirm.
        if A isa GPUArraysCore.AnyGPUArray
            return nothing
        end
        return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false)
    end
end

const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
Expand Down Expand Up @@ -143,7 +156,7 @@ end
function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AnyGPUArray)
fact = qr!(A, alg.pivot)
else
fact = qr(A) # CUDA.jl does not allow other args!
Expand All @@ -160,6 +173,12 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
end

function init_cacheval(alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
    # QR cache entry for lazily-wrapped (Adjoint/Transpose) matrices.
    if A isa GPUArraysCore.AnyGPUArray
        # GPU backends (e.g. CUDA.jl) do not accept a pivot argument to `qr`.
        return qr(A)
    end
    return qr(A, alg.pivot)
end

const PREALLOCATED_QR = ArrayInterface.qr_instance(rand(1, 1))

function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
Expand Down Expand Up @@ -204,6 +223,8 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky(A; shift = alg.shift, check = false, perm = alg.perm)
elseif A isa GPUArraysCore.AnyGPUArray
fact = cholesky(A; check = false)
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot; check = false)
else
Expand All @@ -218,9 +239,13 @@ function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl,
cholesky(A)
end

function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, b, u, Pl,
        Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
    # Cholesky cache entry for any GPU array (including wrapper types).
    # No pivot argument — GPU backends only support the plain form; `check = false`
    # defers positive-definiteness failure to the solve step.
    return cholesky(A; check = false)
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
    # Fallback path: build a cholesky instance (without performing the actual
    # factorization) so the cache is concretely typed for later `solve!` calls.
    # Fix: the scraped diff left both the old two-line signature and its new
    # one-line replacement in place, which is not valid Julia; this is the
    # merged, post-diff definition.
    ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
end

Expand Down Expand Up @@ -968,7 +993,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray,
A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -999,7 +1024,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
if A isa SparseMatrixCSC || A isa GPUArraysCore.AnyGPUArray || A isa SMatrix
fact = cholesky(Symmetric((A)' * A); check = false)
else
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
Expand Down
3 changes: 2 additions & 1 deletion test/gpu/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
19 changes: 18 additions & 1 deletion test/gpu/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve, CUDA, LinearAlgebra, SparseArrays
using LinearSolve, CUDA, LinearAlgebra, SparseArrays, StableRNGs
using Test

CUDA.allowscalar(false)
Expand Down Expand Up @@ -73,3 +73,20 @@ using BlockDiagonals

@test solve(prob1, SimpleGMRES(; blocksize = 2)).u ≈ solve(prob2, SimpleGMRES()).u
end

# Test Dispatches for Adjoint/Transpose Types
# Seeded RNG keeps the fixture deterministic across runs.
rng = StableRNG(0)

# Hermitian + I makes the matrix symmetric positive definite, so every
# factorization below (including Cholesky-based ones) is applicable.
A = cu(Matrix(Hermitian(rand(rng, 5, 5) + I)))
b = cu(rand(rng, 5))
prob1 = LinearProblem(A', b)
prob2 = LinearProblem(transpose(A), b)

@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
    CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
    # Adjoint-wrapped GPU matrix: residual of the solve must be small.
    sol = solve(prob1, alg; alias_A = false)
    @test norm(A' * sol.u .- b) < 1e-5

    # Transpose-wrapped GPU matrix: same check.
    sol = solve(prob2, alg; alias_A = false)
    @test norm(transpose(A) * sol.u .- b) < 1e-5
end

2 comments on commit 3b4f4ed

@avik-pal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/97364

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.21.1 -m "<description of version>" 3b4f4edcf9dd3f1cc42451405a759d62670aacb8
git push origin v2.21.1

Please sign in to comment.