From 01469bea4022d6e1bb624946ad17407a01a00af9 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 18 Dec 2023 14:32:08 -0500
Subject: [PATCH] Handle wrapper types over GPU arrays correctly

---
 Project.toml         |  2 +-
 src/default.jl       |  6 +++---
 src/factorization.jl | 34 +++++++++++++++++++++++++++++-----
 test/gpu/cuda.jl     | 16 ++++++++++++++++
 4 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/Project.toml b/Project.toml
index c40adca6c..ba9907272 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LinearSolve"
 uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 authors = ["SciML"]
-version = "2.21.0"
+version = "2.21.1"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/src/default.jl b/src/default.jl
index 4c4b43a1f..0d2daf19c 100644
--- a/src/default.jl
+++ b/src/default.jl
@@ -101,7 +101,7 @@ end
     end
 end
 
-function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
+function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})
     if assump.condition === OperatorCondition.IllConditioned || !assump.issq
         DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
     else
@@ -110,7 +110,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
 end
 
 # A === nothing case
-function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
+function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool})
     if assump.condition === OperatorCondition.IllConditioned || !assump.issq
         DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
     else
@@ -119,7 +119,7 @@ function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::Opera
 end
 
 # Ambiguity handling
-function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
+function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray,
     assump::OperatorAssumptions{Bool})
     if assump.condition === OperatorCondition.IllConditioned || !assump.issq
         DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
diff --git a/src/factorization.jl b/src/factorization.jl
index 59ade8266..cb0c2b2e6 100644
--- a/src/factorization.jl
+++ b/src/factorization.jl
@@ -78,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
     if A isa AbstractSparseMatrixCSC
         return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
             check = false)
+    elseif A isa GPUArraysCore.AnyGPUArray
+        fact = lu(A; check = false)
     elseif !ArrayInterface.can_setindex(typeof(A))
         fact = lu(A, alg.pivot, check = false)
     else
@@ -98,6 +100,16 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, A, b
     ArrayInterface.lu_instance(convert(AbstractMatrix, A))
 end
 
+function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
+    A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
+    verbose::Bool, assumptions::OperatorAssumptions)
+    if alg isa LUFactorization
+        return lu(A; check=false)
+    else
+        return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check=false)
+    end
+end
+
 const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))
 
 function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
@@ -143,7 +155,7 @@ end
 function do_factorization(alg::QRFactorization, A, b, u)
     A = convert(AbstractMatrix, A)
     if ArrayInterface.can_setindex(typeof(A))
-        if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
+        if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AnyGPUArray)
             fact = qr!(A, alg.pivot)
         else
             fact = qr(A) # CUDA.jl does not allow other args!
@@ -160,6 +172,12 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
     ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
 end
 
+function init_cacheval(alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,
+    maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
+    A isa GPUArraysCore.AnyGPUArray && return qr(A)
+    return qr(A, alg.pivot)
+end
+
 const PREALLOCATED_QR = ArrayInterface.qr_instance(rand(1, 1))
 
 function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
@@ -204,6 +222,8 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
     A = convert(AbstractMatrix, A)
     if A isa SparseMatrixCSC
         fact = cholesky(A; shift = alg.shift, check = false, perm = alg.perm)
+    elseif A isa GPUArraysCore.AnyGPUArray
+        fact = cholesky(A; check = false)
     elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
         fact = cholesky!(A, alg.pivot; check = false)
     else
@@ -218,9 +238,13 @@ function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl,
     cholesky(A)
 end
 
+function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, b, u, Pl,
+    Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
+    cholesky(A; check=false)
+end
+
 function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
-    maxiters::Int, abstol, reltol, verbose::Bool,
-    assumptions::OperatorAssumptions)
+    maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
     ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
 end
 
@@ -968,7 +992,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
 const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())
 
 function init_cacheval(alg::NormalCholeskyFactorization,
-    A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray,
+    A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
         Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
     maxiters::Int, abstol, reltol, verbose::Bool,
     assumptions::OperatorAssumptions)
@@ -999,7 +1023,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
     A = cache.A
     A = convert(AbstractMatrix, A)
     if cache.isfresh
-        if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
+        if A isa SparseMatrixCSC || A isa GPUArraysCore.AnyGPUArray || A isa SMatrix
             fact = cholesky(Symmetric((A)' * A); check = false)
         else
             fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
diff --git a/test/gpu/cuda.jl b/test/gpu/cuda.jl
index 75a181026..5e0e5518e 100644
--- a/test/gpu/cuda.jl
+++ b/test/gpu/cuda.jl
@@ -73,3 +73,19 @@ using BlockDiagonals
 
     @test solve(prob1, SimpleGMRES(; blocksize = 2)).u ≈ solve(prob2, SimpleGMRES()).u
 end
+
+# Test Dispatches for Adjoint/Transpose Types
+A = Matrix(Hermitian(rand(5, 5) + I)) |> cu
+b = rand(5) |> cu
+prob1 = LinearProblem(A', b)
+prob2 = LinearProblem(transpose(A), b)
+
+@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
+    CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
+    sol = solve(prob1, alg; alias_A = false)
+    @test norm(A' * sol.u .- b) < 1e-5
+    @show sol.u, A' \ b
+
+    sol = solve(prob2, alg; alias_A = false)
+    @test norm(transpose(A) * sol.u .- b) < 1e-5
+end
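
Note for reviewers: the heart of this patch is the switch from GPUArraysCore.AbstractGPUArray to GPUArraysCore.AnyGPUArray in the dispatch signatures. AbstractGPUArray only matches bare device arrays, while AnyGPUArray is the union that also covers lazy wrappers such as Adjoint, Transpose, and views, which is exactly what LinearProblem(A', b) produces. A minimal sketch of the distinction, assuming CUDA.jl and GPUArraysCore.jl are installed (the snippet is illustrative and not part of the patch):

    using CUDA, GPUArraysCore

    A = cu(rand(Float32, 4, 4))             # CuArray{Float32, 2}

    A  isa GPUArraysCore.AbstractGPUArray   # true
    A' isa GPUArraysCore.AbstractGPUArray   # false: the Adjoint wrapper hides the device array type
    A' isa GPUArraysCore.AnyGPUArray        # true:  AnyGPUArray also matches wrapped GPU arrays,
                                            #        so these methods still catch A' and transpose(A)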
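A hedged end-to-end sketch of the behavior the new test exercises, i.e. solving a LinearProblem whose operator is an Adjoint-wrapped CuArray; the matrix, solver choice, and residual check below are illustrative rather than copied from the test suite:

    using LinearSolve, CUDA, LinearAlgebra

    A = cu(Matrix(Hermitian(rand(Float32, 5, 5) + I)))  # symmetric test matrix moved to the GPU
    b = cu(rand(Float32, 5))

    prob = LinearProblem(A', b)            # A' is Adjoint{Float32, <:CuArray}, not a bare CuArray
    sol = solve(prob, LUFactorization())   # with this patch the GPU factorization path is selected
    norm(A' * sol.u .- b)                  # residual should be small for this well-conditioned system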