diff --git a/Project.toml b/Project.toml index 40a513794..1a4c17f20 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = ["SciML"] version = "1.16.3" [deps] -ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ArrayInterfaceCore = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" @@ -21,7 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] -ArrayInterface = "3, 4, 5" +ArrayInterfaceCore = "0.1.1" DocStringExtensions = "0.8" GPUArrays = "8" IterativeSolvers = "0.9.2" diff --git a/docs/src/advanced/developing.md b/docs/src/advanced/developing.md index f4835f496..dbc1e3a4d 100644 --- a/docs/src/advanced/developing.md +++ b/docs/src/advanced/developing.md @@ -43,7 +43,7 @@ is what is called at `init` time to create the first `cacheval`. Note that this should match the type of the cache later used in `solve` as many algorithms, like those in OrdinaryDiffEq.jl, expect type-groundedness in the linear solver definitions. While there are cheaper ways to obtain this type for LU factorizations (specifically, -`ArrayInterface.lu_instance(A)`), for a demonstration this just performs an +`ArrayInterfaceCore.lu_instance(A)`), for a demonstration this just performs an LU-factorization to get an `LU{T, Matrix{T}}` which it puts into the `cacheval` so its typed for future use. diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 91aa11681..16ace7803 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -1,6 +1,6 @@ module LinearSolve -using ArrayInterface +using ArrayInterfaceCore using RecursiveFactorization using Base: cache_dependencies, Bool import Base: eltype, adjoint, inv diff --git a/src/default.jl b/src/default.jl index e90294920..68e06e816 100644 --- a/src/default.jl +++ b/src/default.jl @@ -11,7 +11,7 @@ function defaultalg(A,b) # whether MKL or OpenBLAS is being used if (A === nothing && !(b isa GPUArrays.AbstractGPUArray)) || A isa Matrix if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && - ArrayInterface.can_setindex(b) + ArrayInterfaceCore.can_setindex(b) if length(b) <= 10 alg = GenericLUFactorization() elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) @@ -34,7 +34,7 @@ function defaultalg(A,b) # This catches the cases where a factorization overload could exist # For example, BlockBandedMatrix - elseif A !== nothing && ArrayInterface.isstructured(A) + elseif A !== nothing && ArrayInterfaceCore.isstructured(A) alg = GenericFactorization() # This catches the case where A is a CuMatrix @@ -64,7 +64,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing, if A isa Matrix b = cache.b if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && - ArrayInterface.can_setindex(b) + ArrayInterfaceCore.can_setindex(b) if length(b) <= 10 alg = GenericLUFactorization() SciMLBase.solve(cache, alg, args...; kwargs...) @@ -94,7 +94,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing, # This catches the cases where a factorization overload could exist # For example, BlockBandedMatrix - elseif ArrayInterface.isstructured(A) + elseif ArrayInterfaceCore.isstructured(A) alg = GenericFactorization() SciMLBase.solve(cache, alg, args...; kwargs...) @@ -122,7 +122,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, # whether MKL or OpenBLAS is being used if A isa Matrix if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && - ArrayInterface.can_setindex(b) + ArrayInterfaceCore.can_setindex(b) if length(b) <= 10 alg = GenericLUFactorization() init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) @@ -152,7 +152,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, # This catches the cases where a factorization overload could exist # For example, BlockBandedMatrix - elseif ArrayInterface.isstructured(A) + elseif ArrayInterfaceCore.isstructured(A) alg = GenericFactorization() init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) diff --git a/src/factorization.jl b/src/factorization.jl index 6bee0daaa..cc6320811 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -63,7 +63,7 @@ function do_factorization(alg::GenericLUFactorization, A, b, u) return fact end -init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A)) # This could be a GenericFactorization perhaps? Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization @@ -205,18 +205,18 @@ function do_factorization(alg::GenericFactorization, A, b, u) return fact end -init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) -init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A)) -init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) -init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A) +init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A) init_cacheval(alg::GenericFactorization{typeof(lu)}, A::Diagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = Diagonal(inv.(A.diag)) -init_cacheval(alg::GenericFactorization{typeof(lu)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::GenericFactorization{typeof(lu)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A) init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Diagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = Diagonal(inv.(A.diag)) -init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A) init_cacheval(alg::GenericFactorization, A::Diagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = Diagonal(inv.(A.diag)) -init_cacheval(alg::GenericFactorization, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::GenericFactorization, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A) init_cacheval(alg::GenericFactorization, A::SymTridiagonal{T,V}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) where {T,V} = LinearAlgebra.LDLt{T,SymTridiagonal{T,V}}(A) function init_cacheval(alg::Union{GenericFactorization,GenericFactorization{typeof(bunchkaufman!)},GenericFactorization{typeof(bunchkaufman)}}, @@ -277,5 +277,5 @@ end RFLUFactorization(;pivot = Val(true), thread = Val(true)) = GenericFactorization(;fact_alg=RFWrapper(pivot,thread)) -init_cacheval(alg::GenericFactorization{<:RFWrapper}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) -init_cacheval(alg::GenericFactorization{<:RFWrapper}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::GenericFactorization{<:RFWrapper}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::GenericFactorization{<:RFWrapper}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A))