Skip to content

Commit

Permalink
Make MKL the default when it's available
Browse files Browse the repository at this point in the history
The benchmarks have pretty conclusively shown that MKL's LU factorization is just so much better than OpenBLAS that we should effectively always use it. What this does is make MKL_jll into a dependency of LinearSolve.jl and then uses the direct calls to the binary as part of the default algorithm when it's available (it won't be available on systems where MKL does not exist, like M2 macbooks). This uses the direct calls instead of LibBLASTrampoline and thus does not effect the user's global state, thus only being a local change that simply accelerates packages using LinearSolve (i.e. all of SciML).
  • Loading branch information
ChrisRackauckas committed Oct 3, 2023
1 parent 730f59c commit 5156b96
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 23 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Expand All @@ -37,7 +38,6 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

Expand All @@ -49,7 +49,6 @@ LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMKLExt = "MKL_jll"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"

Expand Down Expand Up @@ -91,7 +90,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Expand All @@ -101,4 +99,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]
4 changes: 4 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ PrecompileTools.@recompile_invalidations begin
import Krylov

using SciMLBase

using MKL_jll
end

using Reexport
@reexport using SciMLBase
using SciMLBase: _unwrap_val

const usemkl = MKL_jll.is_available()

abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
Expand Down
14 changes: 11 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17}
T13, T14, T15, T16, T17, T18}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -18,6 +18,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
CholeskyFactorization::T15
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
end

# Legacy fallback
Expand Down Expand Up @@ -162,19 +163,24 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.GenericLUFactorization
elseif appleaccelerate_isavailable()
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
(usemkl && length(b) <= 200)) &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.GenericLUFactorization
DefaultAlgorithmChoice.LUFactorization
end
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.LUFactorization
end
Expand Down Expand Up @@ -209,6 +215,8 @@ function algchoice_to_alg(alg::Symbol)
LDLtFactorization()
elseif alg === :LUFactorization
LUFactorization()
elseif alg === :MKLLUFactorization
MKLLUFactorization()
elseif alg === :QRFactorization
QRFactorization()
elseif alg === :DiagonalFactorization
Expand Down
17 changes: 1 addition & 16 deletions ext/LinearSolveMKLExt.jl → src/mkl.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
module LinearSolveMKLExt

using MKL_jll
using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing,
chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearAlgebra
const usemkl = MKL_jll.is_available()

using LinearSolve
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase

function getrf!(A::AbstractMatrix{<:Float64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
Expand Down Expand Up @@ -140,6 +127,4 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
=#
end

end
end

0 comments on commit 5156b96

Please sign in to comment.