Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle wrapper types over GPU arrays correctly #447

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.21.0"
version = "2.21.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
6 changes: 3 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})

Check warning on line 104 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L104

Added line #L104 was not covered by tests
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -110,7 +110,7 @@
end

# A === nothing case
function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool})

Check warning on line 113 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L113

Added line #L113 was not covered by tests
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -119,7 +119,7 @@
end

# Ambiguity handling
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray,

Check warning on line 122 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L122

Added line #L122 was not covered by tests
assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
Expand Down
35 changes: 30 additions & 5 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
if A isa AbstractSparseMatrixCSC
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
elseif A isa GPUArraysCore.AnyGPUArray
fact = lu(A; check = false)

Check warning on line 82 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L82

Added line #L82 was not covered by tests
elseif !ArrayInterface.can_setindex(typeof(A))
fact = lu(A, alg.pivot, check = false)
else
Expand All @@ -98,6 +100,17 @@
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
end

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},

Check warning on line 103 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L103

Added line #L103 was not covered by tests
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
if alg isa LUFactorization
return lu(A; check=false)

Check warning on line 107 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
else
A isa GPUArraysCore.AnyGPUArray && return nothing
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check=false)

Check warning on line 110 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L109-L110

Added lines #L109 - L110 were not covered by tests
end
end

const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
Expand Down Expand Up @@ -143,7 +156,7 @@
function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AnyGPUArray)

Check warning on line 159 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L159

Added line #L159 was not covered by tests
fact = qr!(A, alg.pivot)
else
fact = qr(A) # CUDA.jl does not allow other args!
Expand All @@ -160,6 +173,12 @@
ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
end

function init_cacheval(alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,

Check warning on line 176 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L176

Added line #L176 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
A isa GPUArraysCore.AnyGPUArray && return qr(A)
return qr(A, alg.pivot)

Check warning on line 179 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L178-L179

Added lines #L178 - L179 were not covered by tests
end

const PREALLOCATED_QR = ArrayInterface.qr_instance(rand(1, 1))

function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
Expand Down Expand Up @@ -204,6 +223,8 @@
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky(A; shift = alg.shift, check = false, perm = alg.perm)
elseif A isa GPUArraysCore.AnyGPUArray
fact = cholesky(A; check = false)

Check warning on line 227 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L226-L227

Added lines #L226 - L227 were not covered by tests
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot; check = false)
else
Expand All @@ -218,9 +239,13 @@
cholesky(A)
end

function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, b, u, Pl,

Check warning on line 242 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L242

Added line #L242 was not covered by tests
Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
cholesky(A; check=false)

Check warning on line 244 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L244

Added line #L244 was not covered by tests
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
end

Expand Down Expand Up @@ -968,7 +993,7 @@
const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray,
A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -999,7 +1024,7 @@
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
if A isa SparseMatrixCSC || A isa GPUArraysCore.AnyGPUArray || A isa SMatrix

Check warning on line 1027 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L1027

Added line #L1027 was not covered by tests
fact = cholesky(Symmetric((A)' * A); check = false)
else
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
Expand Down
3 changes: 2 additions & 1 deletion test/gpu/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
19 changes: 18 additions & 1 deletion test/gpu/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve, CUDA, LinearAlgebra, SparseArrays
using LinearSolve, CUDA, LinearAlgebra, SparseArrays, StableRNGs
using Test

CUDA.allowscalar(false)
Expand Down Expand Up @@ -73,3 +73,20 @@ using BlockDiagonals

@test solve(prob1, SimpleGMRES(; blocksize = 2)).u ≈ solve(prob2, SimpleGMRES()).u
end

# Test Dispatches for Adjoint/Transpose Types
rng = StableRNG(0)

A = Matrix(Hermitian(rand(rng, 5, 5) + I)) |> cu
b = rand(rng, 5) |> cu
prob1 = LinearProblem(A', b)
prob2 = LinearProblem(transpose(A), b)

@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
sol = solve(prob1, alg; alias_A = false)
@test norm(A' * sol.u .- b) < 1e-5

sol = solve(prob2, alg; alias_A = false)
@test norm(transpose(A) * sol.u .- b) < 1e-5
end
Loading