Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make preconditioners part of the solver rather than a random extra #514

Merged
merged 14 commits into from
Aug 8, 2024
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ JET = "0.8.28"
KLU = "0.6"
KernelAbstractions = "0.9.16"
Krylov = "0.9"
KrylovPreconditioners = "0.2"
KrylovKit = "0.8"
LazyArrays = "1.8, 2"
Libdl = "1.10"
Expand Down Expand Up @@ -135,6 +136,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Expand All @@ -148,4 +150,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
11 changes: 5 additions & 6 deletions docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ a few ways:

## How do I use IterativeSolvers solvers with a weighted tolerance vector?

IterativeSolvers.jl computes the norm after the application of the left preconditioner
`Pl`. Thus, in order to use a vector tolerance `weights`, one can mathematically
IterativeSolvers.jl computes the norm after the application of the left preconditioner.
Thus, in order to use a vector tolerance `weights`, one can mathematically
hack the system via the following formulation:

```@example FAQPrec
Expand All @@ -57,11 +57,10 @@ A = rand(n, n)
b = rand(n)

weights = [1e-1, 1]
Pl = LinearSolve.InvPreconditioner(Diagonal(weights))
Pr = Diagonal(weights)
precs = Returns((LinearSolve.InvPreconditioner(Diagonal(weights)), Diagonal(weights)))

prob = LinearProblem(A, b)
sol = solve(prob, KrylovJL_GMRES(), Pl = Pl, Pr = Pr)
sol = solve(prob, KrylovJL_GMRES(precs))

sol.u
```
Expand All @@ -84,5 +83,5 @@ Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(we
Pr = Diagonal(weights)

prob = LinearProblem(A, b)
sol = solve(prob, KrylovJL_GMRES(), Pl = Pl, Pr = Pr)
sol = solve(prob, KrylovJL_GMRES(precs=Returns((Pl,Pr))))
```
6 changes: 3 additions & 3 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveIterativeSolversExt

using LinearSolve, LinearAlgebra
using LinearSolve: LinearCache
using LinearSolve: LinearCache, DEFAULT_PRECS
import LinearSolve: IterativeSolversJL

if isdefined(Base, :get_extension)
Expand All @@ -12,9 +12,9 @@ end

function LinearSolve.IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.gmres_iterable!,
gmres_restart = 0, kwargs...)
gmres_restart = 0, precs = DEFAULT_PRECS, kwargs...)
return IterativeSolversJL(generate_iterator, gmres_restart,
args, kwargs)
precs, args, kwargs)
end

function LinearSolve.IterativeSolversJL_CG(args...; kwargs...)
Expand Down
5 changes: 3 additions & 2 deletions ext/LinearSolveKrylovKitExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module LinearSolveKrylovKitExt

using LinearSolve, KrylovKit, LinearAlgebra
using LinearSolve: LinearCache
using LinearSolve: LinearCache, DEFAULT_PRECS

function LinearSolve.KrylovKitJL(args...;
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
precs = DEFAULT_PRECS,
kwargs...)
return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs)
return KrylovKitJL(KrylovAlg, gmres_restart, precs, args, kwargs)
end

function LinearSolve.KrylovKitJL_CG(args...; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import PrecompileTools
import Krylov
using SciMLBase
import Preferences

const CRC = ChainRulesCore

@static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686
Expand Down
76 changes: 73 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,18 @@ end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(x, cache.p)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
setfield!(cache, :isfresh, true)
elseif name === :p
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(cache.A, x)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
elseif name === :b
# In case there is something that needs to be done when b is updated
update_cacheval!(cache, :b, x)
Expand Down Expand Up @@ -121,6 +132,8 @@ default_alias_b(::Any, ::Any, ::Any) = false
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true

DEFAULT_PRECS(A, p) = IdentityOperator(size(A)[1]), IdentityOperator(size(A)[2])

function __init_u0_from_Ab(A, b)
u0 = similar(b, size(A, 2))
fill!(u0, false)
Expand All @@ -136,12 +149,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
Pl = nothing,
Pr = nothing,
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob
(;A, b, u0, p) = prob

A = if alias_A || A isa SMatrix
A
Expand All @@ -167,6 +180,24 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
reltol = real(eltype(prob.b))(reltol)
abstol = real(eltype(prob.b))(abstol)

precs = if hasproperty(alg, :precs)
isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs
else
DEFAULT_PRECS
end
_Pl, _Pr = precs(A, p)
if isnothing(Pl)
Pl = _Pl
else
# TODO: deprecate once all docs are updated to the new form
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
end
if isnothing(Pr)
Pr = _Pr
else
# TODO: deprecate once all docs are updated to the new form
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
end
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
Expand All @@ -179,6 +210,45 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
return cache
end


function SciMLBase.reinit!(cache::LinearCache;
oscardssmith marked this conversation as resolved.
Show resolved Hide resolved
A = nothing,
b = cache.b,
u = cache.u,
p = nothing,
reinit_cache = false,)
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache

precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS
Pl, Pr = if isnothing(A) || isnothing(p)
if isnothing(A)
A = cache.A
end
if isnothing(p)
p = cache.p
end
precs(A, p)
else
(cache.Pl, cache.Pr)
end
isfresh = true

if reinit_cache
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
else
cache.A = A
cache.b = b
cache.u = u
cache.p = p
cache.Pl = Pl
cache.Pr = Pr
cache.isfresh = true
end
end

function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
return solve(prob, nothing, args...; kwargs...)
end
Expand Down
6 changes: 4 additions & 2 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,10 @@ solvers.

Using this solver requires adding the package KrylovKit.jl, i.e. `using KrylovKit`
"""
struct KrylovKitJL{F, A, I, K} <: LinearSolve.AbstractKrylovSubspaceMethod
struct KrylovKitJL{F, I, P, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod
KrylovAlg::F
gmres_restart::I
precs::P
args::A
kwargs::K
end
Expand Down Expand Up @@ -306,9 +307,10 @@ A generic wrapper over the IterativeSolvers.jl solvers.

Using this solver requires adding the package IterativeSolvers.jl, i.e. `using IterativeSolvers`
"""
struct IterativeSolversJL{F, I, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod
struct IterativeSolversJL{F, I, P, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod
generate_iterator::F
gmres_restart::I
precs::P
args::A
kwargs::K
end
Expand Down
23 changes: 10 additions & 13 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ KrylovJL(args...; KrylovAlg = Krylov.gmres!,

A generic wrapper over the Krylov.jl krylov-subspace iterative solvers.
"""
struct KrylovJL{F, I, A, K} <: AbstractKrylovSubspaceMethod
struct KrylovJL{F, I, P, A, K} <: AbstractKrylovSubspaceMethod
KrylovAlg::F
gmres_restart::I
window::I
precs::P
args::A
kwargs::K
end

function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
gmres_restart = 0, window = 0,
precs = nothing,
kwargs...)
return KrylovJL(KrylovAlg, gmres_restart, window,
args, kwargs)
precs, args, kwargs)
end

default_alias_A(::KrylovJL, ::Any, ::Any) = true
Expand Down Expand Up @@ -231,8 +233,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
cache.isfresh = false
end

M = cache.Pl
N = cache.Pr
M, N = cache.Pl, cache.Pr

# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
M = _isidentity_struct(M) ? I : M
Expand All @@ -258,25 +259,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
end

args = (cacheval, cache.A, cache.b)
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
kwargs = (atol = atol, rtol, itmax, verbose,
ldiv = true, history = true, alg.kwargs...)

if cache.cacheval isa Krylov.CgSolver
N !== I &&
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
Krylov.solve!(args...; M = M,
kwargs...)
Krylov.solve!(args...; M, kwargs...)
elseif cache.cacheval isa Krylov.GmresSolver
Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
kwargs...)
Krylov.solve!(args...; M, N, restart = alg.gmres_restart > 0, kwargs...)
elseif cache.cacheval isa Krylov.BicgstabSolver
Krylov.solve!(args...; M = M, N = N,
kwargs...)
Krylov.solve!(args...; M, N, kwargs...)
elseif cache.cacheval isa Krylov.MinresSolver
N !== I &&
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
Krylov.solve!(args...; M = M,
kwargs...)
Krylov.solve!(args...; M, kwargs...)
else
Krylov.solve!(args...; kwargs...)
end
Expand Down
4 changes: 3 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators
using IterativeSolvers, KrylovKit, MKL_jll
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
using Test
import Random

Expand Down Expand Up @@ -267,10 +267,12 @@ end

@testset "KrylovJL" begin
kwargs = (; gmres_restart = 5)
precs = (A,p=nothing) -> (BlockJacobiPreconditioner(A, 2), I)
algorithms = (
("Default", KrylovJL(kwargs...)),
("CG", KrylovJL_CG(kwargs...)),
("GMRES", KrylovJL_GMRES(kwargs...)),
("GMRES_prec", KrylovJL_GMRES(;precs, ldiv=false, kwargs...)),
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
("MINRES", KrylovJL_MINRES(kwargs...))
)
Expand Down
Loading