diff --git a/ext/LinearSolveHYPREExt.jl b/ext/LinearSolveHYPREExt.jl index 279aba75d..4dd06af14 100644 --- a/ext/LinearSolveHYPREExt.jl +++ b/ext/LinearSolveHYPREExt.jl @@ -86,12 +86,13 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, assumptions) Tc = typeof(cacheval) isfresh = true + precsisfresh = false cache = LinearCache{ typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), typeof(__issquare(assumptions)), typeof(sensealg) - }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + }(A, b, u0, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache end diff --git a/ext/LinearSolveIterativeSolversExt.jl b/ext/LinearSolveIterativeSolversExt.jl index cb4589642..198cc0a5e 100644 --- a/ext/LinearSolveIterativeSolversExt.jl +++ b/ext/LinearSolveIterativeSolversExt.jl @@ -90,6 +90,12 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max end function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...) + if cache.precsisfresh && !isnothing(alg.precs) + Pl, Pr = alg.precs(cache.Pl, cache.Pr) + cache.Pl = Pl + cache.Pr = Pr + cache.precsisfresh = false + end if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable) solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, diff --git a/src/common.jl b/src/common.jl index f212fe250..3f145fc69 100644 --- a/src/common.jl +++ b/src/common.jl @@ -73,6 +73,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S} alg::Talg cacheval::Tc # store alg cache here isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A + precsisfresh::Bool # false => PR,PL is set wrt A, true => update PR,PL wrt A Pl::Tl # preconditioners Pr::Tr abstol::Ttol @@ -85,18 +86,10 @@ end function Base.setproperty!(cache::LinearCache, name::Symbol, x) if name === :A - if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) - Pl, Pr = cache.alg.precs(x, cache.p) - setfield!(cache, :Pl, Pl) - setfield!(cache, :Pr, Pr) - end setfield!(cache, :isfresh, true) + setfield!(cache, :precsisfresh, true) elseif name === :p - if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) - Pl, Pr = cache.alg.precs(cache.A, x) - setfield!(cache, :Pl, Pl) - setfield!(cache, :Pr, Pr) - end + setfield!(cache, :precsisfresh, true) elseif name === :b # In case there is something that needs to be done when b is updated update_cacheval!(cache, :b, x) @@ -208,11 +201,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose, assumptions) isfresh = true + precsisfresh = false Tc = typeof(cacheval) cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache end @@ -223,27 +217,26 @@ function SciMLBase.reinit!(cache::LinearCache; b = cache.b, u = cache.u, p = nothing, - reinit_cache = false,) + reinit_cache = false, + reuse_precs = false) (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache - precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS - Pl, Pr = if isnothing(A) || isnothing(p) - if isnothing(A) - A = cache.A - end - if isnothing(p) - p = cache.p - end - precs(A, p) - else - (cache.Pl, cache.Pr) - end - isfresh = true + isfresh = !isnothing(A) + precsisfresh = !reuse_precs && (isfresh || !isnothing(p)) + isfresh |= cache.isfresh + precsisfresh |= cache.precsisfresh + + A = isnothing(A) ? cache.A : A + b = isnothing(b) ? cache.b : b + u = isnothing(u) ? cache.u : u + p = isnothing(p) ? cache.p : p + Pl = cache.Pl + Pr = cache.Pr if reinit_cache return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval), typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) else cache.A = A @@ -253,6 +246,7 @@ function SciMLBase.reinit!(cache::LinearCache; cache.Pl = Pl cache.Pr = Pr cache.isfresh = true + cache.precsisfresh = precsisfresh end end diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 16a50a27c..f487463c3 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -225,6 +225,12 @@ function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters::Int, abstol, re end function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) + if cache.precsisfresh && !isnothing(alg.precs) + Pl, Pr = alg.precs(cache.A, cache.p) + cache.Pl = Pl + cache.Pr = Pr + cache.precsisfresh = false + end if cache.isfresh solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose, diff --git a/test/basictests.jl b/test/basictests.jl index 6d274a07f..e9492e4d9 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -284,6 +284,30 @@ end end end + @testset "Reuse precs" begin + num_precs_calls = 0 + + function countingprecs(A, p = nothing) + num_precs_calls += 1 + (BlockJacobiPreconditioner(A, 2), I) + end + + n = 10 + A = spdiagm(-1 => -ones(n - 1), 0 => fill(10.0, n), 1 => -ones(n - 1)) + b = rand(n) + p = LinearProblem(A, b) + x0 = solve(p, KrylovJL_CG(precs = countingprecs, ldiv = false)) + cache = x0.cache + x0 = copy(x0) + for i in 4:(n - 3) + A[i, i + 3] -= 1.0e-4 + A[i - 3, i] -= 1.0e-4 + end + LinearSolve.reinit!(cache; A, reuse_precs = true) + x1 = copy(solve!(cache)) + @test all(x0 .< x1) && num_precs_calls == 1 + end + if VERSION >= v"1.9-" @testset "IterativeSolversJL" begin kwargs = (; gmres_restart = 5)