diff --git a/Project.toml b/Project.toml index ca4f6d76f..e9b1ac491 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,7 @@ version = "1.18.1" [deps] ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KLU = "ef3ab10e-7fda-4108-b977-705223b18434" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" @@ -23,7 +23,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] ArrayInterfaceCore = "0.1.1" DocStringExtensions = "0.8" -GPUArrays = "8" +GPUArraysCore = "0.1" IterativeSolvers = "0.9.2" KLU = "0.3.0" Krylov = "0.7.11, 0.8" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 16ace7803..1cb3ae379 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -14,7 +14,7 @@ using SuiteSparse using KLU using DocStringExtensions -import GPUArrays +import GPUArraysCore # wrap import Krylov diff --git a/src/default.jl b/src/default.jl index 68e06e816..301f201ae 100644 --- a/src/default.jl +++ b/src/default.jl @@ -9,7 +9,7 @@ function defaultalg(A,b) # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when # it makes sense according to the benchmarks, which is dependent on # whether MKL or OpenBLAS is being used - if (A === nothing && !(b isa GPUArrays.AbstractGPUArray)) || A isa Matrix + if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && ArrayInterfaceCore.can_setindex(b) if length(b) <= 10 @@ -39,7 +39,7 @@ function defaultalg(A,b) # This catches the case where A is a CuMatrix # Which does not have LU fully defined - elseif A isa GPUArrays.AbstractGPUArray || b isa GPUArrays.AbstractGPUArray + elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray alg = QRFactorization(false) # Not factorizable operator, default to only using A*x @@ -100,7 +100,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing, # This catches the case where A is a CuMatrix # Which does not have LU fully defined - elseif A isa GPUArrays.AbstractGPUArray + elseif A isa GPUArraysCore.AbstractGPUArray alg = QRFactorization(false) SciMLBase.solve(cache, alg, args...; kwargs...) @@ -158,7 +158,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, # This catches the case where A is a CuMatrix # Which does not have LU fully defined - elseif A isa GPUArrays.AbstractGPUArray + elseif A isa GPUArraysCore.AbstractGPUArray alg = QRFactorization(false) init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)