From 92dff9bfbdc9769103e526e1b27b8b5fa9148f08 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 12 Jun 2024 09:37:43 -0400 Subject: [PATCH 1/6] dont make du, u, p, t arguments to dolinsolve but get the from integrator --- src/misc_utils.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/misc_utils.jl b/src/misc_utils.jl index 10d55b27d9..dcb74c793a 100644 --- a/src/misc_utils.jl +++ b/src/misc_utils.jl @@ -81,8 +81,7 @@ macro threaded(option, ex) end function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing, - du = nothing, u = nothing, p = nothing, t = nothing, - weight = nothing, solverdata = nothing, + p = nothing, weight = nothing, solverdata = nothing, reltol = integrator === nothing ? nothing : integrator.opts.reltol) A !== nothing && (linsolve.A = A) b !== nothing && (linsolve.b = b) @@ -95,8 +94,12 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi _alg = unwrap_alg(integrator, true) - _Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev, - solverdata) + du, u, p, t = if isnothing(integrator) + nothing, nothing, nothing, nothing + else + integrator.du, integrator.u, integrator.p, integrator.t + end + _Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev, solverdata) if (_Pl !== nothing || _Pr !== nothing) __Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl __Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr From 75f8496a59e68f9b8e58175341dd269ccabaa321 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 27 Jun 2024 10:56:54 -0400 Subject: [PATCH 2/6] change precs signature --- src/misc_utils.jl | 25 ++++-------- src/perform_step/rosenbrock_perform_step.jl | 45 ++------------------- 2 files changed, 12 insertions(+), 58 deletions(-) diff --git a/src/misc_utils.jl b/src/misc_utils.jl index dcb74c793a..385ea75772 100644 --- a/src/misc_utils.jl +++ b/src/misc_utils.jl @@ -81,30 +81,21 @@ macro threaded(option, ex) end function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing, - p = nothing, weight = nothing, solverdata = nothing, reltol = integrator === nothing ? nothing : integrator.opts.reltol) A !== nothing && (linsolve.A = A) b !== nothing && (linsolve.b = b) linu !== nothing && (linsolve.u = linu) - Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer : - linsolve.Pl - Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer : - linsolve.Pr - _alg = unwrap_alg(integrator, true) - du, u, p, t = if isnothing(integrator) - nothing, nothing, nothing, nothing - else - integrator.du, integrator.u, integrator.p, integrator.t - end - _Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev, solverdata) - if (_Pl !== nothing || _Pr !== nothing) - __Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl - __Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr - linsolve.Pl = __Pl - linsolve.Pr = __Pr + if !isnothing(A) + _Pl, _Pr = _alg.precs(linsolve.A, integrator) + if (_Pl !== nothing || _Pr !== nothing) + __Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl + __Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr + linsolve.Pl = __Pl + linsolve.Pr = __Pr + end end linres = solve!(linsolve; reltol) diff --git a/src/perform_step/rosenbrock_perform_step.jl b/src/perform_step/rosenbrock_perform_step.jl index 0521dc9779..6e88b4960c 100644 --- a/src/perform_step/rosenbrock_perform_step.jl +++ b/src/perform_step/rosenbrock_perform_step.jl @@ -163,16 +163,7 @@ end integrator.opts.internalnorm, t) linsolve = cache.linsolve - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = γ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = γ)) - end + linres = dolinsolve(integrator, linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) @inbounds @simd ivdep for i in eachindex(u) k₁[i] = -linres.u[i] @@ -303,16 +294,7 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = γ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = γ)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) vecu = _vec(linres.u) veck₁ = _vec(k₁) @@ -662,16 +644,7 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) vecu = _vec(linres.u) veck1 = _vec(k1) @@ -857,17 +830,7 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) vecu = _vec(linres.u) veck1 = _vec(k1) From fff49b3c107b27035a2f981176ce6de35f35fa32 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 27 Jun 2024 11:59:44 -0400 Subject: [PATCH 3/6] fix default_precs --- src/OrdinaryDiffEq.jl | 4 ++-- src/misc_utils.jl | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 181d6d952d..8561f9a5e4 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -39,7 +39,7 @@ import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_UNSTABLE_CHECK import SciMLOperators: SciMLOperators, AbstractSciMLOperator, AbstractSciMLScalarOperator, - MatrixOperator, FunctionOperator, + MatrixOperator, FunctionOperator, IdentityOperator update_coefficients, update_coefficients!, DEFAULT_UPDATE_FUNC, isconstant @@ -126,7 +126,7 @@ const CompiledFloats = Union{Float32, Float64, import FunctionWrappersWrappers import Preferences -DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing +DEFAULT_PRECS(W, integrator) = (IdentityOperator(size(W, 1)), IdentityOperator(size(W, 2))) include("doc_utils.jl") include("misc_utils.jl") diff --git a/src/misc_utils.jl b/src/misc_utils.jl index 385ea75772..ae83fd18e3 100644 --- a/src/misc_utils.jl +++ b/src/misc_utils.jl @@ -90,12 +90,8 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi if !isnothing(A) _Pl, _Pr = _alg.precs(linsolve.A, integrator) - if (_Pl !== nothing || _Pr !== nothing) - __Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl - __Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr - linsolve.Pl = __Pl - linsolve.Pr = __Pr - end + linsolve.Pl = _Pl + linsolve.Pr = _Pr end linres = solve!(linsolve; reltol) From 837ebfc23eb2b6036dbcb4f85a2d1cb32b5d3c8e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 27 Jun 2024 14:21:51 -0400 Subject: [PATCH 4/6] typo --- src/OrdinaryDiffEq.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 8561f9a5e4..91a0f98d6d 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -39,7 +39,7 @@ import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_UNSTABLE_CHECK import SciMLOperators: SciMLOperators, AbstractSciMLOperator, AbstractSciMLScalarOperator, - MatrixOperator, FunctionOperator, IdentityOperator + MatrixOperator, FunctionOperator, IdentityOperator, update_coefficients, update_coefficients!, DEFAULT_UPDATE_FUNC, isconstant From fa05cbe3241b84772e8bcdf9d3777b7fae89a89c Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 28 Jun 2024 13:54:36 -0400 Subject: [PATCH 5/6] fixes --- src/caches/rosenbrock_caches.jl | 48 +++++++++------------------------ 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/src/caches/rosenbrock_caches.jl b/src/caches/rosenbrock_caches.jl index e2eb013809..c24b51068d 100644 --- a/src/caches/rosenbrock_caches.jl +++ b/src/caches/rosenbrock_caches.jl @@ -97,9 +97,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -143,9 +141,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -291,9 +287,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -377,9 +371,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -570,9 +562,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -615,9 +605,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -740,9 +728,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -802,9 +788,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -862,9 +846,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -922,9 +904,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -1041,9 +1021,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -1105,9 +1083,7 @@ function alg_cache( uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + Pl, Pr = wrapprecs(alg.precs(W, nothing)..., weight, tmp) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) From e0a90f9619b68eaf1a45838067d56e2898f45e58 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 1 Jul 2024 16:58:53 -0400 Subject: [PATCH 6/6] fixes --- src/misc_utils.jl | 6 +++--- src/perform_step/rosenbrock_perform_step.jl | 11 +---------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/misc_utils.jl b/src/misc_utils.jl index ae83fd18e3..0c8899d7e1 100644 --- a/src/misc_utils.jl +++ b/src/misc_utils.jl @@ -89,9 +89,9 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi _alg = unwrap_alg(integrator, true) if !isnothing(A) - _Pl, _Pr = _alg.precs(linsolve.A, integrator) - linsolve.Pl = _Pl - linsolve.Pr = _Pr + (;du, u, p, t) = integrator + p = isnothing(integrator) ? nothing : (du, u, p, t) + reinit!(linsolve; A, p) end linres = solve!(linsolve; reltol) diff --git a/src/perform_step/rosenbrock_perform_step.jl b/src/perform_step/rosenbrock_perform_step.jl index 6e88b4960c..99ce2dfb9a 100644 --- a/src/perform_step/rosenbrock_perform_step.jl +++ b/src/perform_step/rosenbrock_perform_step.jl @@ -50,16 +50,7 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = γ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = γ)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) vecu = _vec(linres.u) veck₁ = _vec(k₁)