Skip to content

Commit

Permalink
Merge pull request #33 from vpuri3/vp-precond
Browse files Browse the repository at this point in the history
default preconditioners
  • Loading branch information
ChrisRackauckas authored Dec 8, 2021
2 parents e6e42b0 + 730b501 commit bda20fe
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
[compat]
ArrayInterface = "3"
IterativeSolvers = "0.9.2"
Krylov = "0.7"
Krylov = "0.7.9"
KrylovKit = "0.5"
RecursiveFactorization = "0.2"
Reexport = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ module LinearSolve
using ArrayInterface
using RecursiveFactorization
using Base: cache_dependencies, Bool
import Base: eltype, adjoint, inv
using LinearAlgebra
using IterativeSolvers:Identity
using SparseArrays
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
using Setfield
Expand Down
15 changes: 11 additions & 4 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ function set_cacheval(cache::LinearCache, alg_cache)
return cache
end

function set_prec(cache, Pl, Pr)
@set! cache.Pl = Pl
@set! cache.Pr = Pr
return cache
end

init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing

SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
Expand All @@ -54,19 +60,20 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
reltol=eps(eltype(prob.A)),
maxiters=length(prob.b),
verbose=false,
Pl = nothing,
Pr = nothing,
kwargs...,
)
@unpack A, b, u0, p = prob

u0 = (u0 === nothing) ? zero(b) : u0
u0 = (u0 !== nothing) ? u0 : zero(b)
Pl = (Pl !== nothing) ? Pl : Identity()
Pr = (Pr !== nothing) ? Pr : Identity()

cacheval = init_cacheval(alg, A, b, u0)
isfresh = cacheval === nothing
Tc = isfresh ? Any : typeof(cacheval)

Pl = LinearAlgebra.I
Pr = LinearAlgebra.I

A = alias_A ? A : deepcopy(A)
b = alias_b ? b : deepcopy(b)

Expand Down
100 changes: 71 additions & 29 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,61 @@

#TODO: composed preconditioners, preconditioner setter for cache,
# detailed tests for wrappers

## Preconditioners

struct ScaleVector{T}
s::T
isleft::Bool
scaling_preconditioner(s) = I * s , I * (1/s)

struct ComposePreconditioner{Ti,To}
inner::Ti
outer::To
end

function LinearAlgebra.ldiv!(v::ScaleVector, x)
Base.eltype(A::ComposePreconditioner) = promote_type(eltype(A.inner), eltype(A.outer))
Base.adjoint(A::ComposePreconditioner) = ComposePreconditioner(A.outer', A.inner')
Base.inv(A::ComposePreconditioner) = InvComposePreconditioner(A)

function LinearAlgebra.ldiv!(A::ComposePreconditioner, x)
@unpack inner, outer = A

ldiv!(inner, x)
ldiv!(outer, x)
end

function LinearAlgebra.ldiv!(y, v::ScaleVector, x)
function LinearAlgebra.ldiv!(y, A::ComposePreconditioner, x)
@unpack inner, outer = A

ldiv!(y, inner, x)
ldiv!(outer, y)
end

struct ComposePreconditioner{Ti,To}
inner::Ti
outer::To
isleft::Bool
struct InvComposePreconditioner{Tp <: ComposePreconditioner}
P::Tp
end

function LinearAlgebra.ldiv!(v::ComposePreconditioner, x)
@unpack inner, outer, isleft = v
InvComposePreconditioner(inner, outer) = InvComposePreconditioner(ComposePreconditioner(inner, outer))

Base.eltype(A::InvComposePreconditioner) = Base.eltype(A.P)
Base.adjoint(A::InvComposePreconditioner) = InvComposePreconditioner(A.P')
Base.inv(A::InvComposePreconditioner) = deepcopy(A.P)

function LinearAlgebra.mul!(y, A::InvComposePreconditioner, x)
@unpack P = A
ldiv!(y, P, x)
end

function LinearAlgebra.ldiv!(y, v::ComposePreconditioner, x)
@unpack inner, outer, isleft = v
function get_preconditioner(Pi, Po)

ifPi = Pi !== Identity()
ifPo = Po !== Identity()

P =
if ifPi & ifPo
ComposePreconditioner(Pi, Po)
elseif ifPi | ifPo
ifPi ? Pi : Po
else
Identity()
end

return P
end

## Krylov.jl
Expand All @@ -41,10 +70,14 @@ struct KrylovJL{F,Tl,Tr,I,A,K} <: AbstractKrylovSubspaceMethod
kwargs::K
end

function KrylovJL(args...; KrylovAlg = Krylov.gmres!, Pl=I, Pr=I,
function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
Pl=nothing, Pr=nothing,
gmres_restart=0, window=0,
kwargs...)

Pl = (Pl === nothing) ? Identity() : Pl
Pr = (Pr === nothing) ? Identity() : Pr

return KrylovJL(KrylovAlg, Pl, Pr, gmres_restart, window,
args, kwargs)
end
Expand Down Expand Up @@ -132,6 +165,12 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
cache = set_cacheval(cache, solver)
end

M = get_preconditioner(alg.Pl, cache.Pl)
N = get_preconditioner(alg.Pr, cache.Pr)

M = (M === Identity()) ? I : inv(M)
N = (N === Identity()) ? I : inv(N)

atol = cache.abstol
rtol = cache.reltol
itmax = cache.maxiters
Expand All @@ -142,20 +181,20 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
alg.kwargs...)

if cache.cacheval isa Krylov.CgSolver
alg.Pr != LinearAlgebra.I &&
N !== I &&
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
Krylov.solve!(args...; M=alg.Pl,
Krylov.solve!(args...; M=M,
kwargs...)
elseif cache.cacheval isa Krylov.GmresSolver
Krylov.solve!(args...; M=alg.Pl, N=alg.Pr,
Krylov.solve!(args...; M=M, N=N,
kwargs...)
elseif cache.cacheval isa Krylov.BicgstabSolver
Krylov.solve!(args...; M=alg.Pl, N=alg.Pr,
Krylov.solve!(args...; M=M, N=N,
kwargs...)
elseif cache.cacheval isa Krylov.MinresSolver
alg.Pr != LinearAlgebra.I &&
N !== I &&
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
Krylov.solve!(args...; M=alg.Pl,
Krylov.solve!(args...; M=M,
kwargs...)
else
Krylov.solve!(args...; kwargs...)
Expand All @@ -177,9 +216,12 @@ end

function IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.gmres_iterable!,
Pl=IterativeSolvers.Identity(),
Pr=IterativeSolvers.Identity(),
Pl=nothing, Pr=nothing,
gmres_restart=0, kwargs...)

Pl = (Pl === nothing) ? Identity() : Pl
Pr = (Pr === nothing) ? Identity() : Pr

return IterativeSolversJL(generate_iterator, Pl, Pr, gmres_restart,
args, kwargs)
end
Expand All @@ -204,8 +246,8 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
@unpack A, b, u = cache

Pl = (alg.Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pl
Pr = (alg.Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pr
Pl = get_preconditioner(alg.Pl, cache.Pl)
Pr = get_preconditioner(alg.Pr, cache.Pr)

abstol = cache.abstol
reltol = cache.reltol
Expand All @@ -218,15 +260,15 @@ function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
alg.kwargs...)

iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
Pr != IterativeSolvers.Identity() &&
Pr !== Identity() &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, Pl;
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr, restart=restart,
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
Pr != IterativeSolvers.Identity() &&
Pr !== Identity() &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl,
abstol=abstol, reltol=reltol,
Expand Down
52 changes: 52 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ function test_interface(alg, prob1, prob2)
return
end

@testset "LinearSolve" begin

@testset "Default Linear Solver" begin
test_interface(nothing, prob1, prob2)

Expand Down Expand Up @@ -123,3 +125,53 @@ end
end
end
end

@testset "Preconditioners" begin
@testset "scaling_preconditioner" begin
s = rand()

x = rand(n,n)
y = rand(n,n)

Pl, Pr = LinearSolve.scaling_preconditioner(s)

mul!(y, Pl, x); @test y s * x
mul!(y, Pr, x); @test y s \ x

y .= x; ldiv!(Pl, x); @test x s \ y
y .= x; ldiv!(Pr, x); @test x s * y

ldiv!(y, Pl, x); @test y s \ x
ldiv!(y, Pr, x); @test y s * x

end

@testset "ComposePreconditioenr" begin
s1 = rand()
s2 = rand()

x = rand(n,n)
y = rand(n,n)

P1, _ = LinearSolve.scaling_preconditioner(s1)
P2, _ = LinearSolve.scaling_preconditioner(s2)

P = LinearSolve.ComposePreconditioner(P1,P2)
Pi = LinearSolve.InvComposePreconditioner(P)

@test Pi == LinearSolve.InvComposePreconditioner(P1, P2)
@test Pi == inv(P)
@test P == inv(Pi)
@test Pi' == inv(P')

# ComposePreconditioner
ldiv!(y, P, x); @test y ldiv!(P2, ldiv!(P1, x))
y .= x; ldiv!(P, x); @test x ldiv!(P2, ldiv!(P1, y))

# InvComposePreconditioner
mul!(y, Pi, x); @test y ldiv!(P2, ldiv!(P1, x))

end
end

end # testset

0 comments on commit bda20fe

Please sign in to comment.