diff --git a/ext/LinearSolvePardisoExt.jl b/ext/LinearSolvePardisoExt.jl index 1c323bdf..98db67a0 100644 --- a/ext/LinearSolvePardisoExt.jl +++ b/ext/LinearSolvePardisoExt.jl @@ -26,22 +26,24 @@ function LinearSolve.init_cacheval(alg::PardisoJL, A = convert(AbstractMatrix, A) transposed_iparm = 1 - solver = if false && Pardiso.PARDISO_LOADED[] + solver = if Pardiso.PARDISO_LOADED[] solver = Pardiso.PardisoSolver() + Pardiso.pardisoinit(solver) solver_type !== nothing && Pardiso.set_solver!(solver, solver_type) solver else solver = Pardiso.MKLPardisoSolver() + Pardiso.pardisoinit(solver) nprocs !== nothing && Pardiso.set_nprocs!(solver, nprocs) - # for mkl 1 means conjugated an 2 means transposed. + + # for mkl 1 means conjugated and 2 means transposed. # https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-0/pardiso-iparm-parameter.html#IPARM37 transposed_iparm = 2 + solver end - Pardiso.pardisoinit(solver) # default initialization - if matrix_type !== nothing Pardiso.set_matrixtype!(solver, matrix_type) else @@ -118,9 +120,9 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs Pardiso.pardiso(cache.cacheval, A, eltype(A)[]) cache.isfresh = false end - Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE) Pardiso.pardiso(cache.cacheval, u, A, b) + return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end