diff --git a/ext/LinearSolvePardisoExt.jl b/ext/LinearSolvePardisoExt.jl index 0b4cfbb1..a6d2188e 100644 --- a/ext/LinearSolvePardisoExt.jl +++ b/ext/LinearSolvePardisoExt.jl @@ -132,6 +132,11 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs @unpack A, b, u = cache A = convert(AbstractMatrix, A) if cache.isfresh + if hasproperty(alg, :precs) && !isnothing(alg.precs) + Pl, Pr = cache.alg.precs(x, cache.p) + cache.Pl = Pl + cache.Pr = Pr + end phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT Pardiso.set_phase!(cache.cacheval, phase) Pardiso.pardiso(cache.cacheval, A, eltype(A)[]) diff --git a/src/common.jl b/src/common.jl index f212fe25..fbd581c9 100644 --- a/src/common.jl +++ b/src/common.jl @@ -84,19 +84,8 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S} end function Base.setproperty!(cache::LinearCache, name::Symbol, x) - if name === :A - if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) - Pl, Pr = cache.alg.precs(x, cache.p) - setfield!(cache, :Pl, Pl) - setfield!(cache, :Pr, Pr) - end + if name === :A || name === :p setfield!(cache, :isfresh, true) - elseif name === :p - if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) - Pl, Pr = cache.alg.precs(cache.A, x) - setfield!(cache, :Pl, Pl) - setfield!(cache, :Pr, Pr) - end elseif name === :b # In case there is something that needs to be done when b is updated update_cacheval!(cache, :b, x) @@ -224,20 +213,7 @@ function SciMLBase.reinit!(cache::LinearCache; u = cache.u, p = nothing, reinit_cache = false,) - (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache - - precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS - Pl, Pr = if isnothing(A) || isnothing(p) - if isnothing(A) - A = cache.A - end - if isnothing(p) - p = cache.p - end - precs(A, p) - else - (cache.Pl, cache.Pr) - end + (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg, Pl, Pr) = cache isfresh = true if reinit_cache @@ -250,8 +226,6 @@ function SciMLBase.reinit!(cache::LinearCache; cache.b = b cache.u = u cache.p = p - cache.Pl = Pl - cache.Pr = Pr cache.isfresh = true end end diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 16a50a27..48905605 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -226,6 +226,11 @@ end function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) if cache.isfresh + if hasproperty(alg, :precs) && !isnothing(alg.precs) + Pl, Pr = cache.alg.precs(x, cache.p) + cache.Pl = Pl + cache.Pr = Pr + end solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose, cache.assumptions, zeroinit = false)