Skip to content

Commit

Permalink
Change from GPUArrays to GPUArraysCore
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jun 22, 2022
1 parent 6ce0709 commit 0e19abe
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version = "1.18.1"
[deps]
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Expand All @@ -23,7 +23,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
[compat]
ArrayInterfaceCore = "0.1.1"
DocStringExtensions = "0.8"
GPUArrays = "8"
GPUArraysCore = "0.1"
IterativeSolvers = "0.9.2"
KLU = "0.3.0"
Krylov = "0.7.11, 0.8"
Expand Down
2 changes: 1 addition & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using SuiteSparse
using KLU
using DocStringExtensions

import GPUArrays
import GPUArraysCore

# wrap
import Krylov
Expand Down
8 changes: 4 additions & 4 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function defaultalg(A,b)
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
# it makes sense according to the benchmarks, which is dependent on
# whether MKL or OpenBLAS is being used
if (A === nothing && !(b isa GPUArrays.AbstractGPUArray)) || A isa Matrix
if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterfaceCore.can_setindex(b)
if length(b) <= 10
Expand Down Expand Up @@ -39,7 +39,7 @@ function defaultalg(A,b)

# This catches the case where A is a CuMatrix
# Which does not have LU fully defined
elseif A isa GPUArrays.AbstractGPUArray || b isa GPUArrays.AbstractGPUArray
elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray
alg = QRFactorization(false)

# Not factorizable operator, default to only using A*x
Expand Down Expand Up @@ -100,7 +100,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,

# This catches the case where A is a CuMatrix
# Which does not have LU fully defined
elseif A isa GPUArrays.AbstractGPUArray
elseif A isa GPUArraysCore.AbstractGPUArray
alg = QRFactorization(false)
SciMLBase.solve(cache, alg, args...; kwargs...)

Expand Down Expand Up @@ -158,7 +158,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,

# This catches the case where A is a CuMatrix
# Which does not have LU fully defined
elseif A isa GPUArrays.AbstractGPUArray
elseif A isa GPUArraysCore.AbstractGPUArray
alg = QRFactorization(false)
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)

Expand Down

0 comments on commit 0e19abe

Please sign in to comment.