Skip to content

Commit

Permalink
Allow external control of linear solver kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 25, 2023
1 parent 28cfa9c commit d7c4635
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1103,22 +1103,24 @@ documentation page or the docstrings of the vjp types.
Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at
http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007)
"""
struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT}
struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod, LK} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT}
autojacvec::VJP
linsolve::LS
linsolve_method::LM
uniform_blocked_diagonal_jacobian::Bool
linsolve_kwargs::LK
end

TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint

Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true,
diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing,
concrete_jac = false, uniform_blocked_diagonal_jacobian::Bool = false,
linsolve_method = SSAdjointHeuristicLinsolve(50))
linsolve_method = SSAdjointHeuristicLinsolve(50), linsolve_kwargs = (;))
return SteadyStateAdjoint{_unwrap_val(concrete_jac), chunk_size, autodiff, diff_type,
typeof(autojacvec), typeof(linsolve), typeof(linsolve_method)}(autojacvec, linsolve,
linsolve_method, uniform_blocked_diagonal_jacobian)
typeof(autojacvec), typeof(linsolve), typeof(linsolve_method),
typeof(linsolve_kwargs)}(autojacvec, linsolve, linsolve_method,
uniform_blocked_diagonal_jacobian, linsolve_kwargs)
end

function needs_concrete_jac(S::SteadyStateAdjoint{CJ}, u0) where {CJ}
Expand Down
2 changes: 1 addition & 1 deletion src/steadystate_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
end

linear_problem = LinearProblem(A_, vec(dgdu_val'); u0 = vec(λ))
sol = solve(linear_problem, linsolve; alias_A = true) # u is vec(λ)
sol = solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ)

try
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
Expand Down

0 comments on commit d7c4635

Please sign in to comment.