Skip to content

Commit

Permalink
Change to ArrayInterfaceCore
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 22, 2022
1 parent 9cf1981 commit 752edd3
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = ["SciML"]
version = "1.16.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ArrayInterfaceCore = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Expand All @@ -21,7 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ArrayInterface = "3, 4, 5"
ArrayInterfaceCore = "0.1.1"
DocStringExtensions = "0.8"
GPUArrays = "8"
IterativeSolvers = "0.9.2"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/advanced/developing.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ is what is called at `init` time to create the first `cacheval`. Note that this
should match the type of the cache later used in `solve` as many algorithms, like
those in OrdinaryDiffEq.jl, expect type-groundedness in the linear solver definitions.
While there are cheaper ways to obtain this type for LU factorizations (specifically,
`ArrayInterface.lu_instance(A)`), for a demonstration this just performs an
`ArrayInterfaceCore.lu_instance(A)`), for a demonstration this just performs an
LU-factorization to get an `LU{T, Matrix{T}}` which it puts into the `cacheval`
so its typed for future use.

Expand Down
2 changes: 1 addition & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module LinearSolve

using ArrayInterface
using ArrayInterfaceCore
using RecursiveFactorization
using Base: cache_dependencies, Bool
import Base: eltype, adjoint, inv
Expand Down
12 changes: 6 additions & 6 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function defaultalg(A,b)
# whether MKL or OpenBLAS is being used
if (A === nothing && !(b isa GPUArrays.AbstractGPUArray)) || A isa Matrix
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterface.can_setindex(b)
ArrayInterfaceCore.can_setindex(b)
if length(b) <= 10
alg = GenericLUFactorization()
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
Expand All @@ -34,7 +34,7 @@ function defaultalg(A,b)

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif A !== nothing && ArrayInterface.isstructured(A)
elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
alg = GenericFactorization()

# This catches the case where A is a CuMatrix
Expand Down Expand Up @@ -64,7 +64,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
if A isa Matrix
b = cache.b
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterface.can_setindex(b)
ArrayInterfaceCore.can_setindex(b)
if length(b) <= 10
alg = GenericLUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
Expand Down Expand Up @@ -94,7 +94,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif ArrayInterface.isstructured(A)
elseif ArrayInterfaceCore.isstructured(A)
alg = GenericFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)

Expand Down Expand Up @@ -122,7 +122,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
# whether MKL or OpenBLAS is being used
if A isa Matrix
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterface.can_setindex(b)
ArrayInterfaceCore.can_setindex(b)
if length(b) <= 10
alg = GenericLUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
Expand Down Expand Up @@ -152,7 +152,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif ArrayInterface.isstructured(A)
elseif ArrayInterfaceCore.isstructured(A)
alg = GenericFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)

Expand Down
20 changes: 10 additions & 10 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function do_factorization(alg::GenericLUFactorization, A, b, u)
return fact
end

init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A))

# This could be a GenericFactorization perhaps?
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
Expand Down Expand Up @@ -205,18 +205,18 @@ function do_factorization(alg::GenericFactorization, A, b, u)
return fact
end

init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A))

init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu)}, A::Diagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = Diagonal(inv.(A.diag))
init_cacheval(alg::GenericFactorization{typeof(lu)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Diagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = Diagonal(inv.(A.diag))
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A)

init_cacheval(alg::GenericFactorization, A::Diagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = Diagonal(inv.(A.diag))
init_cacheval(alg::GenericFactorization, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization, A::Tridiagonal, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(A)
init_cacheval(alg::GenericFactorization, A::SymTridiagonal{T,V}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) where {T,V} = LinearAlgebra.LDLt{T,SymTridiagonal{T,V}}(A)

function init_cacheval(alg::Union{GenericFactorization,GenericFactorization{typeof(bunchkaufman!)},GenericFactorization{typeof(bunchkaufman)}},
Expand Down Expand Up @@ -277,5 +277,5 @@ end

RFLUFactorization(;pivot = Val(true), thread = Val(true)) = GenericFactorization(;fact_alg=RFWrapper(pivot,thread))

init_cacheval(alg::GenericFactorization{<:RFWrapper}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{<:RFWrapper}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{<:RFWrapper}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{<:RFWrapper}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterfaceCore.lu_instance(convert(AbstractMatrix,A))

0 comments on commit 752edd3

Please sign in to comment.