Skip to content

Commit

Permalink
Merge pull request #321 from SciML/cholesky16
Browse files Browse the repository at this point in the history
workaround Cholesky factorization limitations on v1.6
  • Loading branch information
ChrisRackauckas authored Jun 7, 2023
2 parents 9cece80 + a1fedb8 commit c6a2743
Show file tree
Hide file tree
Showing 25 changed files with 493 additions and 493 deletions.
32 changes: 16 additions & 16 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ DocMeta.setdocmeta!(LinearSolve, :DocTestSetup, :(using LinearSolve); recursive
include("pages.jl")

makedocs(sitename = "LinearSolve.jl",
authors = "Chris Rackauckas",
modules = [LinearSolve, LinearSolve.SciMLBase],
clean = true, doctest = false, linkcheck = true,
strict = [
:doctest,
:linkcheck,
:parse_error,
:example_block,
# Other available options are
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/LinearSolve/stable/"),
pages = pages)
authors = "Chris Rackauckas",
modules = [LinearSolve, LinearSolve.SciMLBase],
clean = true, doctest = false, linkcheck = true,
strict = [
:doctest,
:linkcheck,
:parse_error,
:example_block,
# Other available options are
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/LinearSolve/stable/"),
pages = pages)

deploydocs(;
repo = "github.com/SciML/LinearSolve.jl",
push_preview = true)
repo = "github.com/SciML/LinearSolve.jl",
push_preview = true)
12 changes: 6 additions & 6 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

pages = ["index.md",
"Tutorials" => Any["tutorials/linear.md"
"tutorials/caching_interface.md"],
"tutorials/caching_interface.md"],
"Basics" => Any["basics/LinearProblem.md",
"basics/common_solver_opts.md",
"basics/OperatorAssumptions.md",
"basics/Preconditioners.md",
"basics/FAQ.md"],
"basics/common_solver_opts.md",
"basics/OperatorAssumptions.md",
"basics/Preconditioners.md",
"basics/FAQ.md"],
"Solvers" => Any["solvers/solvers.md"],
"Advanced" => Any["advanced/developing.md"
"advanced/custom.md"],
"advanced/custom.md"],
"Release Notes" => "release_notes.md",
]
2 changes: 1 addition & 1 deletion docs/src/advanced/developing.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ basic machinery. A simplified version is:
struct MyLUFactorization{P} <: SciMLBase.AbstractLinearAlgorithm end

function init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
verbose)
verbose)
lu!(convert(AbstractMatrix, A))
end

Expand Down
2 changes: 1 addition & 1 deletion docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ b = rand(n)
weights = rand(n)
realprec = lu(rand(n, n)) # some random preconditioner
Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(weights)),
realprec)
realprec)
Pr = Diagonal(weights)
prob = LinearProblem(A, b)
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using CUDA, LinearAlgebra, LinearSolve, SciMLBase
using SciMLBase: AbstractSciMLOperator

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
kwargs...)
if cache.isfresh
fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u)
cache = LinearSolve.set_cacheval(cache, fact)
Expand Down
70 changes: 35 additions & 35 deletions ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using LinearAlgebra
using HYPRE.LibHYPRE: HYPRE_Complex
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning
using SciMLBase: LinearProblem, SciMLBase
using UnPack: @unpack
using Setfield: @set!
Expand All @@ -21,8 +21,8 @@ mutable struct HYPRECache
end

function LinearSolve.init_cacheval(alg::HYPREAlgorithm, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
return HYPRECache(nothing, nothing, nothing, nothing, true, true, true)
end

Expand Down Expand Up @@ -54,21 +54,21 @@ end
# fill!(similar(b, size(A, 2)), false) since HYPREArrays are not AbstractArrays.

function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
args...;
alias_A = false, alias_b = false,
# TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful
# even if it is not AbstractArray.
abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
reltol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
# TODO: Implement length() for HYPREVector in HYPRE.jl?
maxiters::Int = prob.b isa HYPREVector ? 1000 : length(prob.b),
verbose::Bool = false,
Pl = LinearAlgebra.I,
Pr = LinearAlgebra.I,
assumptions = OperatorAssumptions(),
kwargs...)
args...;
alias_A = false, alias_b = false,
# TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful
# even if it is not AbstractArray.
abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
reltol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
# TODO: Implement length() for HYPREVector in HYPRE.jl?
maxiters::Int = prob.b isa HYPREVector ? 1000 : length(prob.b),
verbose::Bool = false,
Pl = LinearAlgebra.I,
Pr = LinearAlgebra.I,
assumptions = OperatorAssumptions(),
kwargs...)
@unpack A, b, u0, p = prob

A = A isa HYPREMatrix ? A : HYPREMatrix(A)
Expand All @@ -82,23 +82,23 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,

# Initialize internal alg cache
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
assumptions)
Tc = typeof(cacheval)
isfresh = true

cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions))
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters,
verbose, assumptions)
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions)),
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters,
verbose, assumptions)
return cache
end

# Solvers whose constructor requires passing the MPI communicator
const COMM_SOLVERS = Union{HYPRE.BiCGSTAB, HYPRE.FlexGMRES, HYPRE.GMRES, HYPRE.ParaSails,
HYPRE.PCG}
HYPRE.PCG}
create_solver(::Type{S}, comm) where {S <: COMM_SOLVERS} = S(comm)

# Solvers whose constructor should not be passed the MPI communicator
Expand All @@ -120,10 +120,10 @@ function create_solver(alg::HYPREAlgorithm, cache::LinearCache)

# Construct solver options
solver_options = (;
AbsoluteTol = cache.abstol,
MaxIter = cache.maxiters,
PrintLevel = Int(cache.verbose),
Tol = cache.reltol)
AbsoluteTol = cache.abstol,
MaxIter = cache.maxiters,
PrintLevel = Int(cache.verbose),
Tol = cache.reltol)

# Preconditioner (uses Pl even though it might not be a *left* preconditioner just *a*
# preconditioner)
Expand Down Expand Up @@ -211,16 +211,16 @@ function SciMLBase.solve!(cache::LinearCache, alg::HYPREAlgorithm, args...; kwar
stats = nothing

ret = SciMLBase.LinearSolution{T, N, typeof(cache.u), typeof(resid), typeof(alg),
typeof(cache), typeof(stats)}(cache.u, resid, alg, retc,
iters, cache, stats)
typeof(cache), typeof(stats)}(cache.u, resid, alg, retc,
iters, cache, stats)

return ret
end

# HYPREArrays are not AbstractArrays so perform some type-piracy
function SciMLBase.LinearProblem(A::HYPREMatrix, b::HYPREVector,
p = SciMLBase.NullParameters();
u0::Union{HYPREVector, Nothing} = nothing, kwargs...)
p = SciMLBase.NullParameters();
u0::Union{HYPREVector, Nothing} = nothing, kwargs...)
return LinearProblem{true}(A, b, p; u0 = u0, kwargs)
end

Expand Down
54 changes: 27 additions & 27 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,76 +11,76 @@ else
end

function LinearSolve.IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.gmres_iterable!,
gmres_restart = 0, kwargs...)
generate_iterator = IterativeSolvers.gmres_iterable!,
gmres_restart = 0, kwargs...)
return IterativeSolversJL(generate_iterator, gmres_restart,
args, kwargs)
args, kwargs)
end

function LinearSolve.IterativeSolversJL_CG(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.cg_iterator!,
kwargs...)
generate_iterator = IterativeSolvers.cg_iterator!,
kwargs...)
end
function LinearSolve.IterativeSolversJL_GMRES(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.gmres_iterable!,
kwargs...)
generate_iterator = IterativeSolvers.gmres_iterable!,
kwargs...)
end
function LinearSolve.IterativeSolversJL_BICGSTAB(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.bicgstabl_iterator!,
kwargs...)
generate_iterator = IterativeSolvers.bicgstabl_iterator!,
kwargs...)
end
function LinearSolve.IterativeSolversJL_MINRES(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.minres_iterable!,
kwargs...)
generate_iterator = IterativeSolvers.minres_iterable!,
kwargs...)
end

LinearSolve._isidentity_struct(::IterativeSolvers.Identity) = true
LinearSolve.default_alias_A(::IterativeSolversJL, ::Any, ::Any) = true
LinearSolve.default_alias_b(::IterativeSolversJL, ::Any, ::Any) = true

function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
restart = (alg.gmres_restart == 0) ? min(20, size(A, 1)) : alg.gmres_restart

kwargs = (abstol = abstol, reltol = reltol, maxiter = maxiters,
alg.kwargs...)
alg.kwargs...)

iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
!LinearSolve._isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, Pl;
kwargs...)
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
kwargs...)
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
!!LinearSolve._isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, alg.args...; Pl = Pl,
abstol = abstol, reltol = reltol,
max_mv_products = maxiters * 2,
alg.kwargs...)
abstol = abstol, reltol = reltol,
max_mv_products = maxiters * 2,
alg.kwargs...)
else # minres, qmr
alg.generate_iterator(u, A, b, alg.args...;
abstol = abstol, reltol = reltol, maxiter = maxiters,
alg.kwargs...)
abstol = abstol, reltol = reltol, maxiter = maxiters,
alg.kwargs...)
end
return iterable
end

function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
if cache.isfresh || !(typeof(alg) <: IterativeSolvers.GMRESIterable)
solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl,
cache.Pr,
cache.maxiters, cache.abstol, cache.reltol,
cache.verbose,
cache.assumptions)
cache.Pr,
cache.maxiters, cache.abstol, cache.reltol,
cache.verbose,
cache.assumptions)
cache.cacheval = solver
cache.isfresh = false
end
Expand Down Expand Up @@ -111,7 +111,7 @@ function purge_history!(iter::IterativeSolvers.GMRESIterable, x, b)
iter.b = b

iter.residual.current = IterativeSolvers.init!(iter.arnoldi, iter.x, iter.b, iter.Pl,
iter.Ax, initially_zero = true)
iter.Ax, initially_zero = true)
IterativeSolvers.init_residual!(iter.residual, iter.residual.current)
iter.β = iter.residual.current
nothing
Expand Down
8 changes: 4 additions & 4 deletions ext/LinearSolveKrylovKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using LinearSolve, KrylovKit, LinearAlgebra
using LinearSolve: LinearCache

function LinearSolve.KrylovKitJL(args...;
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
kwargs...)
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
kwargs...)
return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs)
end

Expand All @@ -28,7 +28,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovKitJL; kwargs...)
krylovdim = (alg.gmres_restart == 0) ? min(20, size(cache.A, 1)) : alg.gmres_restart

kwargs = (atol = atol, rtol = rtol, maxiter = maxiter, verbosity = verbosity,
krylovdim = krylovdim, alg.kwargs...)
krylovdim = krylovdim, alg.kwargs...)

x, info = KrylovKit.linsolve(cache.A, cache.b, cache.u; kwargs...)

Expand All @@ -37,7 +37,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovKitJL; kwargs...)
retcode = info.converged == 1 ? ReturnCode.Default : ReturnCode.ConvergenceFailure
iters = info.numiter
return SciMLBase.build_linear_solution(alg, cache.u, resid, cache; retcode = retcode,
iters = iters)
iters = iters)
end

end
26 changes: 13 additions & 13 deletions ext/LinearSolvePardisoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ LinearSolve.needs_concrete_A(alg::PardisoJL) = true
# TODO schur complement functionality

function LinearSolve.init_cacheval(alg::PardisoJL,
A,
b,
u,
Pl,
Pr,
maxiters::Int,
abstol,
reltol,
verbose::Bool,
assumptions::LinearSolve.OperatorAssumptions)
A,
b,
u,
Pl,
Pr,
maxiters::Int,
abstol,
reltol,
verbose::Bool,
assumptions::LinearSolve.OperatorAssumptions)
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
A = convert(AbstractMatrix, A)

Expand Down Expand Up @@ -93,9 +93,9 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
end

Pardiso.pardiso(solver,
u,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
b)
u,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
b)

return solver
end
Expand Down
Loading

0 comments on commit c6a2743

Please sign in to comment.