Skip to content

Commit

Permalink
Fix NormalCholesky on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 8, 2023
1 parent e60a10a commit 300c4a9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.20.1"
version = "2.20.2"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
4 changes: 2 additions & 2 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray,
A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand All @@ -921,7 +921,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray

Check warning on line 924 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L924

Added line #L924 was not covered by tests
fact = cholesky(Symmetric((A)' * A, :L); check = false)
else
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
Expand Down
4 changes: 2 additions & 2 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ function test_interface(alg, prob1, prob2)
return
end

@testset "CudaOffloadFactorization" begin
test_interface(CudaOffloadFactorization(), prob1, prob2)
@testset "$alg" for alg in (CudaOffloadFactorization(), NormalCholeskyFactorization())
test_interface(alg, prob1, prob2)
end

@testset "Simple GMRES: restart = $restart" for restart in (true, false)
Expand Down

0 comments on commit 300c4a9

Please sign in to comment.