Skip to content

Commit

Permalink
Merge pull request #387 from SciML/mklfactorization_default
Browse files Browse the repository at this point in the history
Make MKL the default when it's available
  • Loading branch information
ChrisRackauckas authored Oct 4, 2023
2 parents e487d3b + 0a4f96a commit a53f644
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 38 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Expand All @@ -37,7 +38,6 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

Expand All @@ -49,7 +49,6 @@ LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMKLExt = "MKL_jll"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"

Expand Down Expand Up @@ -91,7 +90,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Expand All @@ -101,4 +99,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ PrecompileTools.@recompile_invalidations begin
import Krylov

using SciMLBase

using MKL_jll
end

using Reexport
@reexport using SciMLBase
using SciMLBase: _unwrap_val

const usemkl = MKL_jll.is_available()

abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
Expand Down Expand Up @@ -91,6 +95,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
CholeskyFactorization
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand All @@ -100,6 +105,7 @@ end
include("common.jl")
include("factorization.jl")
include("appleaccelerate.jl")
include("mkl.jl")
include("simplelu.jl")
include("simplegmres.jl")
include("iterative_wrappers.jl")
Expand Down
14 changes: 11 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17}
T13, T14, T15, T16, T17, T18}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -18,6 +18,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
CholeskyFactorization::T15
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
end

# Legacy fallback
Expand Down Expand Up @@ -162,19 +163,24 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.GenericLUFactorization
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
(usemkl && length(b) <= 200)) &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.GenericLUFactorization
DefaultAlgorithmChoice.LUFactorization
end
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.LUFactorization
end
Expand Down Expand Up @@ -209,6 +215,8 @@ function algchoice_to_alg(alg::Symbol)
LDLtFactorization()
elseif alg === :LUFactorization
LUFactorization()
elseif alg === :MKLLUFactorization
MKLLUFactorization()
elseif alg === :QRFactorization
QRFactorization()
elseif alg === :DiagonalFactorization
Expand Down
10 changes: 0 additions & 10 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,16 +339,6 @@ A wrapper over the IterativeSolvers.jl MINRES.
"""
function IterativeSolversJL_MINRES end

"""
```julia
MKLLUFactorization()
```
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct MKLLUFactorization <: AbstractFactorization end

"""
```julia
MetalLUFactorization()
Expand Down
3 changes: 0 additions & 3 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ function __init__()
@require KrylovKit="0b1a1467-8014-51b9-945f-bf0ae24f4b77" begin
include("../ext/LinearSolveKrylovKitExt.jl")
end
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/LinearSolveEnzymeExt.jl")
end
Expand Down
32 changes: 16 additions & 16 deletions ext/LinearSolveMKLExt.jl → src/mkl.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
module LinearSolveMKLExt
"""
```julia
MKLLUFactorization()
```
using MKL_jll
using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing,
chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearAlgebra
const usemkl = MKL_jll.is_available()

using LinearSolve
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct MKLLUFactorization <: AbstractFactorization end

function getrf!(A::AbstractMatrix{<:Float64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
Expand Down Expand Up @@ -104,10 +101,15 @@ end
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false

function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
const PREALLOCATED_MKL_LU = begin
A = rand(0, 0)
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

function init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
PREALLOCATED_MKL_LU
end

function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
Expand Down Expand Up @@ -140,6 +142,4 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
=#
end

end
end
10 changes: 8 additions & 2 deletions test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ solve(prob)
prob = LinearProblem(rand(50, 50), rand(50))
solve(prob)

@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
if LinearSolve.usemkl
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.MKLLUFactorization
else
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.LUFactorization
end

prob = LinearProblem(rand(600, 600), rand(600))
solve(prob)

Expand Down

0 comments on commit a53f644

Please sign in to comment.