From d7c4635aa118c3a7d95c43884c82ada3e14608ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Sep 2023 18:47:54 -0400 Subject: [PATCH] Allow external control of linear solver kwargs --- src/sensitivity_algorithms.jl | 10 ++++++---- src/steadystate_adjoint.jl | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 789be4150..c04e2d895 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1103,11 +1103,12 @@ documentation page or the docstrings of the vjp types. Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) """ -struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} +struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod, LK} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS linsolve_method::LM uniform_blocked_diagonal_jacobian::Bool + linsolve_kwargs::LK end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint @@ -1115,10 +1116,11 @@ TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing, concrete_jac = false, uniform_blocked_diagonal_jacobian::Bool = false, - linsolve_method = SSAdjointHeuristicLinsolve(50)) + linsolve_method = SSAdjointHeuristicLinsolve(50), linsolve_kwargs = (;)) return SteadyStateAdjoint{_unwrap_val(concrete_jac), chunk_size, autodiff, diff_type, - typeof(autojacvec), typeof(linsolve), typeof(linsolve_method)}(autojacvec, linsolve, - linsolve_method, uniform_blocked_diagonal_jacobian) + typeof(autojacvec), typeof(linsolve), typeof(linsolve_method), + typeof(linsolve_kwargs)}(autojacvec, linsolve, linsolve_method, + uniform_blocked_diagonal_jacobian, linsolve_kwargs) end function needs_concrete_jac(S::SteadyStateAdjoint{CJ}, u0) where {CJ} diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 8fd1b2b81..4be78d8dc 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -131,7 +131,7 @@ end end linear_problem = LinearProblem(A_, vec(dgdu_val'); u0 = vec(λ)) - sol = solve(linear_problem, linsolve; alias_A = true) # u is vec(λ) + sol = solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ) try vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)