Skip to content

Commit

Permalink
Try to fix #497
Browse files Browse the repository at this point in the history
Pardiso defaults for highly indefinite matrices.

This commit essentially reverts #89 and introduces a new
kwarg "cache_analysis" (default `false`) to PardisoJL() which, if true would
lead to the behaviour of #89.

Also, allow the user to overwrite all iparms modified by
the extension besides of 12.
  • Loading branch information
j-fu committed May 31, 2024
1 parent 270b56d commit fe0d790
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 29 deletions.
57 changes: 31 additions & 26 deletions ext/LinearSolvePardisoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
reltol,
verbose::Bool,
assumptions::LinearSolve.OperatorAssumptions)
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
@unpack nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm = alg
A = convert(AbstractMatrix, A)

solver = if Pardiso.PARDISO_LOADED[]
Expand Down Expand Up @@ -52,22 +52,6 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
end
verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)

# pass in vector of tuples like [(iparm::Int, key::Int) ...]
if iparm !== nothing
for i in iparm
Pardiso.set_iparm!(solver, i...)
end
end

if dparm !== nothing
for d in dparm
Pardiso.set_dparm!(solver, d...)
end
end

# Make sure to say it's transposed because its CSC not CSR
Pardiso.set_iparm!(solver, 12, 1)

#=
Note: It is recommended to use IPARM(11)=1 (scaling) and IPARM(13)=1 (matchings) for
highly indefinite symmetric matrices e.g. from interior point optimizations or saddle point problems.
Expand All @@ -79,10 +63,10 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
be changed to Pardiso.ANALYSIS_NUM_FACT in the solver loop otherwise instabilities
occur in the example https://github.com/SciML/OrdinaryDiffEq.jl/issues/1569
=#
Pardiso.set_iparm!(solver, 11, 0)
Pardiso.set_iparm!(solver, 13, 0)

Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
if cache_analysis
Pardiso.set_iparm!(solver, 11, 0)
Pardiso.set_iparm!(solver, 13, 0)
end

if alg.solver_type == 1
# PARDISO uses a numerical factorization A = LU for the first system and
Expand All @@ -92,10 +76,30 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
Pardiso.set_iparm!(solver, 3, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
end

Pardiso.pardiso(solver,
u,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
b)
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
if iparm !== nothing
for i in iparm
Pardiso.set_iparm!(solver, i...)
end
end

if dparm !== nothing
for d in dparm
Pardiso.set_dparm!(solver, d...)
end
end

# Make sure to say it's transposed because its CSC not CSR
# This is also the only value which should not be overwritten by users
Pardiso.set_iparm!(solver, 12, 1)

if cache_analysis
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
Pardiso.pardiso(solver,
u,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
b)
end

return solver
end
Expand All @@ -105,7 +109,8 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs
A = convert(AbstractMatrix, A)

if cache.isfresh
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT
Pardiso.set_phase!(cache.cacheval, phase)
Pardiso.pardiso(cache.cacheval, A, eltype(A)[])
cache.isfresh = false
end
Expand Down
5 changes: 4 additions & 1 deletion src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,14 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
nprocs::Union{Int, Nothing}
solver_type::T1
matrix_type::T2
cache_analysis::Bool
iparm::Union{Vector{Tuple{Int, Int}}, Nothing}
dparm::Union{Vector{Tuple{Int, Int}}, Nothing}

function PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
solver_type = nothing,
matrix_type = nothing,
cache_analysis = false,
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
ext = Base.get_extension(@__MODULE__, :LinearSolvePardisoExt)
Expand All @@ -170,7 +172,8 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
T2 = typeof(matrix_type)
@assert T1 <: Union{Int, Nothing, ext.Pardiso.Solver}
@assert T2 <: Union{Int, Nothing, ext.Pardiso.MatrixType}
return new{T1, T2}(nprocs, solver_type, matrix_type, iparm, dparm)
return new{T1, T2}(
nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm)
end
end
end
Expand Down
35 changes: 33 additions & 2 deletions test/pardiso/pardiso.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve, SparseArrays, Random
using LinearSolve, SparseArrays, Random, LinearAlgebra
import Pardiso

A1 = sparse([1.0 0 -2 3
Expand All @@ -13,12 +13,22 @@ n = 4
e = ones(n)
e2 = ones(n - 1)
A2 = spdiagm(-1 => im * e2, 0 => lambda * e, 1 => -im * e2)

b2 = rand(n) + im * zeros(n)
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)

prob2 = LinearProblem(A2, b2)

for alg in (PardisoJL(), MKLPardisoFactorize(), MKLPardisoIterate())
for alg in (PardisoJL(), MKLPardisoFactorize())
u = solve(prob1, alg; cache_kwargs...).u
@test A1 * u b1

u = solve(prob2, alg; cache_kwargs...).u
@test eltype(u) <: Complex
@test A2 * u b2
end

for alg in (MKLPardisoIterate(),)
u = solve(prob1, alg; cache_kwargs...).u
@test A1 * u b1

Expand All @@ -27,6 +37,8 @@ for alg in (PardisoJL(), MKLPardisoFactorize(), MKLPardisoIterate())
@test_broken A2 * u b2
end



Random.seed!(10)
A = sprand(n, n, 0.8);
A2 = 2.0 .* A;
Expand All @@ -53,6 +65,25 @@ sol33 = solve(linsolve)
@test sol12.u sol32.u
@test sol13.u sol33.u


# Test for problem from #497
function makeA()
n = 60
colptr = [1, 4, 7, 11, 15, 17, 22, 26, 30, 34, 38, 40, 46, 50, 54, 58, 62, 64, 70, 74, 78, 82, 86, 88, 94, 98, 102, 106, 110, 112, 118, 122, 126, 130, 134, 136, 142, 146, 150, 154, 158, 160, 166, 170, 174, 178, 182, 184, 190, 194, 198, 202, 206, 208, 214, 218, 222, 224, 226, 228, 232]
rowval = [1, 3, 4, 1, 2, 4, 2, 4, 9, 10, 3, 5, 11, 12, 1, 3, 2, 4, 6, 11, 12, 2, 7, 9, 10, 2, 7, 8, 10, 8, 10, 15, 16, 9, 11, 17, 18, 7, 9, 2, 8, 10, 12, 17, 18, 8, 13, 15, 16, 8, 13, 14, 16, 14, 16, 21, 22, 15, 17, 23, 24, 13, 15, 8, 14, 16, 18, 23, 24, 14, 19, 21, 22, 14, 19, 20, 22, 20, 22, 27, 28, 21, 23, 29, 30, 19, 21, 14, 20, 22, 24, 29, 30, 20, 25, 27, 28, 20, 25, 26, 28, 26, 28, 33, 34, 27, 29, 35, 36, 25, 27, 20, 26, 28, 30, 35, 36, 26, 31, 33, 34, 26, 31, 32, 34, 32, 34, 39, 40, 33, 35, 41, 42, 31, 33, 26, 32, 34, 36, 41, 42, 32, 37, 39, 40, 32, 37, 38, 40, 38, 40, 45, 46, 39, 41, 47, 48, 37, 39, 32, 38, 40, 42, 47, 48, 38, 43, 45, 46, 38, 43, 44, 46, 44, 46, 51, 52, 45, 47, 53, 54, 43, 45, 38, 44, 46, 48, 53, 54, 44, 49, 51, 52, 44, 49, 50, 52, 50, 52, 57, 58, 51, 53, 59, 60, 49, 51, 44, 50, 52, 54, 59, 60, 50, 55, 57, 58, 50, 55, 56, 58, 56, 58, 57, 59, 55, 57, 50, 56, 58, 60]
nzval = [-0.64, 1.0, -1.0, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -1.0806825309567203, 1.0, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0]
A = SparseMatrixCSC(n, n, colptr, rowval, nzval)
return(A)
end

A=makeA()
u0=fill(0.1,size(A,2))
linprob = LinearProblem(A, A*u0)
u = LinearSolve.solve(linprob, PardisoJL(),verbose=true)
@test norm(u-u0) < 1.0e-14



# Testing and demonstrating Pardiso.set_iparm! for MKLPardisoSolver
solver = Pardiso.MKLPardisoSolver()
iparm = [
Expand Down

0 comments on commit fe0d790

Please sign in to comment.