From 53c4bdb13530f0aa5c4bec753fe78e05da692a25 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Aug 2023 17:51:50 -0400 Subject: [PATCH 1/8] Force Jacobian Construction --- .gitignore | 3 ++- src/SciMLSensitivity.jl | 2 +- src/adjoint_common.jl | 11 ++++++----- src/sensitivity_algorithms.jl | 27 +++++++++++++-------------- src/steadystate_adjoint.jl | 30 +++++++++++++++--------------- 5 files changed, 37 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index f3fc3a6b4..8b8e929b3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ *.jl.*.cov *.jl.mem Manifest.toml -/docs/build/ \ No newline at end of file +/docs/build/ +.vscode diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 3c4ffbc22..dcca83a9d 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -34,7 +34,7 @@ import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Pr abstract type SensitivityFunction end abstract type TransformedFunction end -import SciMLBase: unwrapped_f +import SciMLBase: unwrapped_f, _unwrap_val import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm, AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm, diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 82ca3b3e4..d22e699ee 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -29,8 +29,8 @@ TruncatedStacktraces.@truncate_stacktrace AdjointDiffCache return (AdjointDiffCache, y) """ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f, alg; - quad = false, - noiseterm = false, needs_jac = false) where {G, DG1, DG2} + quad = false, noiseterm = false, needs_jac = false, + jac_prototype = nothing) where {G, DG1, DG2} prob = sol.prob if prob isa Union{SteadyStateProblem, NonlinearProblem} @unpack u0, p = prob @@ -104,12 +104,13 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool) J = nothing else + # jac_prototype can be provided if we want to exploit sparsity + J_ = jac_prototype !== nothing ? jac_prototype : similar(u0, numindvar, numindvar) if SciMLBase.forwarddiffs_model_time(alg) # 1 chunk is fine because it's only t - J = dualcache(similar(u0, numindvar, numindvar), - ForwardDiff.pickchunksize(length(u0))) + J = dualcache(J_, ForwardDiff.pickchunksize(length(u0))) else - J = similar(u0, numindvar, numindvar) + J = J_ end end diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 9ad688310..0f3544e14 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1076,7 +1076,7 @@ documentation page or the docstrings of the vjp types. Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) """ -struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS} <: +struct SteadyStateAdjoint{CJ, BD, CS, AD, FDT, VJP, LS} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS @@ -1085,20 +1085,19 @@ end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true, - diff_type = Val{:central}, - autojacvec = nothing, linsolve = nothing) - SteadyStateAdjoint{ - chunk_size, - autodiff, - diff_type, - typeof(autojacvec), - typeof(linsolve), - }(autojacvec, - linsolve) + diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing, + concrete_jac = false, assume_uniform_blocked_diagonal = false) + return SteadyStateAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec), + typeof(linsolve), _unwrap_val(concrete_jac), + _unwrap_val(assume_uniform_blocked_diagonal)}(autojacvec, linsolve) end -function setvjp(sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS}, - vjp) where {CS, AD, FDT, VJP, LS} - SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS}(vjp, sensealg.linsolve) + +needs_concrete_jac(::SteadyStateAdjoint{CJ}) where {CJ} = Val(CJ) +assume_uniform_blocked_diagonal(::SteadyStateAdjoint{CJ, BD}) where {CJ, BD} = Val(BD) + +function setvjp(sensealg::SteadyStateAdjoint{CJ, BD, CS, AD, FDT, VJP, LS}, + vjp) where {CJ, BD, CS, AD, FDT, VJP, LS} + SteadyStateAdjoint{CJ, BD, CS, AD, FDT, typeof(vjp), LS}(vjp, sensealg.linsolve) end abstract type VJPChoice end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index e24ba04c1..fa389c35c 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -23,28 +23,28 @@ end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjointSensitivityFunction function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f, - colorvec, needs_jac) + colorvec, needs_jac, jac_prototype) @unpack p, u0 = sol.prob - diffcache, y = adjointdiffcache(g, sensealg, false, sol, dgdu, dgdp, f, alg; - quad = false, needs_jac) + diffcache, y = adjointdiffcache(g, sensealg, false, sol, dgdu, dgdp, f, alg; needs_jac, + jac_prototype) λ = zero(y) + # Override the choice of the user if we feel that it is not a fast enough choice. linsolve = needs_jac ? nothing : sensealg.linsolve vjp = similar(λ, length(p)) - SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec, λ, vjp, - linsolve) + return SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec, + λ, vjp, linsolve) end @noinline function SteadyStateAdjointProblem(sol, sensealg::SteadyStateAdjoint, alg, dgdu::DG1 = nothing, dgdp::DG2 = nothing, g::G = nothing; kwargs...) where {DG1, DG2, G} + # TODO: Sparsity Exploiting @unpack f, p, u0 = sol.prob - if sol.prob isa NonlinearProblem - f = ODEFunction(f) - end + sol.prob isa NonlinearProblem && (f = ODEFunction(f)) dgdu === nothing && dgdp === nothing && g === nothing && error("Either `dgdu`, `dgdp`, or `g` must be specified.") @@ -53,17 +53,17 @@ end false # TODO: What is the correct heuristic? Can we afford to compute Jacobian for # cases where the length(u0) > 50 and if yes till what threshold - elseif sensealg.linsolve === nothing - length(u0) <= 50 - else - LinearSolve.needs_concrete_A(sensealg.linsolve) - end + needs_jac = needs_concrete_jac(sensealg) || + (sensealg.linsolve === nothing && length(u0) ≤ 50) || + LinearSolve.needs_concrete_A(sensealg.linsolve) p === DiffEqBase.NullParameters() && error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") + # TODO: Specify jac_prototype for sparse problems + jac_prototype = nothing sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, - f, f.colorvec, needs_jac) + f, f.colorvec, needs_jac, jac_prototype) @unpack diffcache, y, sol, λ, vjp, linsolve = sense if needs_jac @@ -115,7 +115,7 @@ end vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing) catch e if sense.sensealg.autojacvec === nothing - @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp" + @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to nonlinear solve adjoint + numerical vjp" vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp, dy = nothing) else From 8ed708814cea44abdb55ff648dccf8f6f30ad587 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 28 Aug 2023 16:15:17 -0400 Subject: [PATCH 2/8] Allow users to control the linear solve in SteadyStateAdjoint --- Project.toml | 2 + src/SciMLSensitivity.jl | 5 ++ src/sensitivity_algorithms.jl | 94 ++++++++++++++++++++++++++++------- src/steadystate_adjoint.jl | 74 +++++++++++++-------------- 4 files changed, 119 insertions(+), 56 deletions(-) diff --git a/Project.toml b/Project.toml index fae2e7687..f2a508910 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "7.40.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" Cassette = "7057c7e9-c182-5462-911a-8362d720325c" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -35,6 +36,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index dcca83a9d..4671e7852 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -17,6 +17,8 @@ using StaticArraysCore using ADTypes using SparseDiffTools using SciMLOperators +using BandedMatrices +using SparseArrays import TruncatedStacktraces import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache @@ -84,4 +86,7 @@ export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP export StochasticTransformedFunction +export SSAdjointFullJacobianLinsolve, SSAdjointIterativeVJPLinsolve, + SSAdjointHeuristicLinsolve + end # module diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 0f3544e14..937ea15d9 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -304,7 +304,7 @@ InterpolatingAdjoint(; chunk_size = 0, autodiff = true, differentiation with special seeding. The total set of choices are: + `nothing`: uses an automatic algorithm to automatically choose the vjp. - This is the default and recommended for most users. + This is the default and recommended for most users. + `false`: the Jacobian is constructed via FiniteDiff.jl + `true`: the Jacobian is constructed via ForwardDiff.jl + `TrackerVJP`: Uses Tracker.jl for the vjp. @@ -1024,9 +1024,17 @@ Base.@pure function NILSAS(nseg, nstep, M = nothing; nseg, nstep, g) end +abstract type AbstractSSAdjointLinsolveMethod end + +struct SSAdjointFullJacobianLinsolve <: AbstractSSAdjointLinsolveMethod end +struct SSAdjointIterativeVJPLinsolve <: AbstractSSAdjointLinsolveMethod end +Base.@kwdef struct SSAdjointHeuristicLinsolve <: AbstractSSAdjointLinsolveMethod + auto_switch_threshold::Int = 50 +end + """ ```julia -SteadyStateAdjoint{CS, AD, FDT, VJP, LS} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} +SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} ``` An implementation of the adjoint differentiation of a nonlinear solve. Uses the @@ -1036,9 +1044,10 @@ implicit function theorem to directly compute the derivative of the solution to ## Constructor ```julia -SteadyStateAdjoint(; chunk_size = 0, autodiff = true, - diff_type = Val{:central}, - autojacvec = autodiff, linsolve = nothing) +SteadyStateAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, + autojacvec = nothing, linsolve = nothing, linsolve = nothing, + concrete_jac = false, uniform_blocked_diagonal_jacobian::Bool = false, + linsolve_method = SSAdjointHeuristicLinsolve(50)) ``` ## Keyword Arguments @@ -1065,8 +1074,26 @@ SteadyStateAdjoint(; chunk_size = 0, autodiff = true, is a boolean for whether to precompile the tape, which should only be done if there are no branches (`if` or `while` statements) in the `f` function. - `linsolve`: the linear solver used in the adjoint solve. Defaults to `nothing`, - which uses a polyalgorithm to choose an efficient - algorithm automatically. + which uses a polyalgorithm to choose an efficient algorithm automatically. + - `concrete_jac`: If `true`, ignore every other directive and mandatorily construct + the Jacobian. (default: `false`) + - `uniform_blocked_diagonal_jacobian`: If `true`, the jacobian is assumed to be uniformly + block diagonal with a block size = `div(length(u0), size(u0, ndims(u0)))`. This allows + using sparse differentiation to construct the Jacobian or specialized iterative linear + solvers. + - `linsolve_method`: The method used to solve the linear system. There are the following + choices: (it is recommended to use the default `SSAdjointHeuristicLinsolve`, unless + benchmarking or an analytic jacobian is known) + + + `SSAdjointFullJacobianLinsolve`: Construct the full Jacobian and solve the linear + system. This is typically efficient for small systems. + + `SSAdjointIterativeVJPLinsolve`: Use an iterative method to solve the linear system + with the vjp. This is typically efficient for large systems. + + `SSAdjointHeuristicLinsolve(; auto_switch_threshold = 50)`: Use a heuristic to + automatically determine if we should construct the jacobian or use an iterative + solve. If the Jacobian is constructed for any other reason + (like `concrete_jac=true`), we will not use an iterative solver (unless forced via + `linsolve`). For more details on the vjp choices, please consult the sensitivity algorithms documentation page or the docstrings of the vjp types. @@ -1076,28 +1103,54 @@ documentation page or the docstrings of the vjp types. Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) """ -struct SteadyStateAdjoint{CJ, BD, CS, AD, FDT, VJP, LS} <: - AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} +struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS + linsolve_method::LM + uniform_blocked_diagonal_jacobian::Bool end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing, - concrete_jac = false, assume_uniform_blocked_diagonal = false) - return SteadyStateAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec), - typeof(linsolve), _unwrap_val(concrete_jac), - _unwrap_val(assume_uniform_blocked_diagonal)}(autojacvec, linsolve) + concrete_jac = false, uniform_blocked_diagonal_jacobian::Bool = false, + linsolve_method = SSAdjointHeuristicLinsolve(50)) + return SteadyStateAdjoint{_unwrap_val(concrete_jac), chunk_size, autodiff, diff_type, + typeof(autojacvec), typeof(linsolve), typeof(linsolve_method)}(autojacvec, linsolve, + linsolve_method, uniform_blocked_diagonal_jacobian) end -needs_concrete_jac(::SteadyStateAdjoint{CJ}) where {CJ} = Val(CJ) -assume_uniform_blocked_diagonal(::SteadyStateAdjoint{CJ, BD}) where {CJ, BD} = Val(BD) +function needs_concrete_jac(S::SteadyStateAdjoint{CJ}, u0) where {CJ} + # Force Jacobian Construction + CJ && return true + # Check if the users wants us to use a specific method + lm_needs = needs_concrete_jac(S, S.linsolve_method, u0) + lm_needs && return true + # If nothing is true then only construct if the linear solver needs the jacobian + return S.linsolve === nothing ? false : LinearSolve.needs_concrete_A(S.linsolve) +end -function setvjp(sensealg::SteadyStateAdjoint{CJ, BD, CS, AD, FDT, VJP, LS}, - vjp) where {CJ, BD, CS, AD, FDT, VJP, LS} - SteadyStateAdjoint{CJ, BD, CS, AD, FDT, typeof(vjp), LS}(vjp, sensealg.linsolve) +function jacobian_adtype(S::SteadyStateAdjoint{CJ, CS, AD}) where {CJ, CS, AD} + # FIXME: Don't ignore the chunk size. Will need to verify upstream. + if S.uniform_blocked_diagonal_jacobian + return AD ? AutoSparseForwardDiff() : AutoSparseFiniteDiff() + else + return AD ? AutoForwardDiff() : AutoFiniteDiff() + end +end + +function setvjp(sensealg::SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM}, + vjp) where {CJ, CS, AD, FDT, VJP, LS, LM} + return SteadyStateAdjoint{CJ, CS, AD, FDT, typeof(vjp), LS, LM}(vjp, + sensealg.linsolve, sensealg.linsolve_method, + sensealg.uniform_blocked_diagonal_jacobian) +end + +needs_concrete_jac(::SteadyStateAdjoint, ::SSAdjointFullJacobianLinsolve, _) = true +needs_concrete_jac(::SteadyStateAdjoint, ::SSAdjointIterativeVJPLinsolve, _) = false +function needs_concrete_jac(S::SteadyStateAdjoint, L::SSAdjointHeuristicLinsolve, u0) + return length(u0) ≤ L.auto_switch_threshold end abstract type VJPChoice end @@ -1313,7 +1366,10 @@ struct ForwardDiffOverAdjoint{A} <: adjalg::A end -get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} = AutoReverseDiff(; compile = compile) +function get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} + return AutoReverseDiff(; compile) +end get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote() get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() get_autodiff_from_vjp(::TrackerVJP) = AutoTracker() +get_autodiff_from_vjp(::Nothing) = AutoZygote() diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index fa389c35c..716ea9076 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -1,14 +1,5 @@ -struct SteadyStateAdjointSensitivityFunction{ - C <: AdjointDiffCache, - Alg <: SteadyStateAdjoint, - uType, - SType, - fType <: ODEFunction, - CV, - λType, - VJPType, - LS, -} <: SensitivityFunction +struct SteadyStateAdjointSensitivityFunction{C<:AdjointDiffCache, Alg<:SteadyStateAdjoint, + uType, SType, fType<:ODEFunction, CV, λType, VJPType, LS,} <: SensitivityFunction diffcache::C sensealg::Alg y::uType @@ -39,29 +30,31 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp end @noinline function SteadyStateAdjointProblem(sol, sensealg::SteadyStateAdjoint, alg, - dgdu::DG1 = nothing, dgdp::DG2 = nothing, - g::G = nothing; kwargs...) where {DG1, DG2, G} - # TODO: Sparsity Exploiting + dgdu::DG1 = nothing, dgdp::DG2 = nothing, g::G = nothing; kwargs...) where {DG1, DG2, G} @unpack f, p, u0 = sol.prob - sol.prob isa NonlinearProblem && (f = ODEFunction(f)) + u0_len = length(u0) dgdu === nothing && dgdp === nothing && g === nothing && error("Either `dgdu`, `dgdp`, or `g` must be specified.") - needs_jac = if has_adjoint(f) - false - # TODO: What is the correct heuristic? Can we afford to compute Jacobian for - # cases where the length(u0) > 50 and if yes till what threshold - needs_jac = needs_concrete_jac(sensealg) || - (sensealg.linsolve === nothing && length(u0) ≤ 50) || - LinearSolve.needs_concrete_A(sensealg.linsolve) + needs_jac = needs_concrete_jac(sensealg, u0) + sensealg.uniform_blocked_diagonal_jacobian && (blocksize = u0_len ÷ size(u0, ndims(u0))) p === DiffEqBase.NullParameters() && error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") - # TODO: Specify jac_prototype for sparse problems - jac_prototype = nothing + if needs_jac + if sensealg.uniform_blocked_diagonal_jacobian + jac_prototype = sparse(BandedMatrix(similar(u0, u0_len, u0_len), + (blocksize, blocksize))) + sd = JacPrototypeSparsityDetection(; jac_prototype) + else + sd = NoSparsityDetection() + end + else + jac_prototype = nothing + end sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f, f.colorvec, needs_jac, jac_prototype) @unpack diffcache, y, sol, λ, vjp, linsolve = sense @@ -71,10 +64,12 @@ end f.jac(diffcache.J, y, p, nothing) else if DiffEqBase.isinplace(sol.prob) - jacobian!(diffcache.J, diffcache.uf, y, diffcache.f_cache, - sensealg, diffcache.jac_config) + # TODO: reuse diffcache.jac_config?? + sparse_jacobian!(diffcache.J, jacobian_adtype(sensealg), sd, diffcache.uf, + diffcache.f_cache, y) else - diffcache.J .= jacobian(diffcache.uf, y, sensealg) + sparse_jacobian!(diffcache.J, jacobian_adtype(sensealg), sd, diffcache.uf, + y) end end end @@ -100,24 +95,29 @@ end end if !needs_jac - # operator = VecJac(f, y, p; Val(DiffEqBase.isinplace(sol.prob))) - __f = y -> f(y, p, nothing) - operator = VecJac(__f, y; autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) - linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) + __f(x) = vec(f(reshape(x, size(y)), p, nothing)) + operator = VecJac(__f, vec(y); autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + + if linsolve === nothing && sensealg.uniform_blocked_diagonal_jacobian + @warn "linsolve not specified, and Jacobian is specified to be uniform blocked diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 + linsolve = SimpleGMRES(; blocksize, restart=false) + end + + A_ = operator else - linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ)) + A_ = diffcache.J' end - # Zygote pullback function won't work with deepcopy - solve(linear_problem, linsolve; alias_A = true) # u is vec(λ) + linear_problem = LinearProblem(A_, vec(dgdu_val'); u0=vec(λ)) + sol = solve(linear_problem, linsolve; alias_A=true) # u is vec(λ) try - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing) + vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad=vjp, dy=nothing) catch e if sense.sensealg.autojacvec === nothing @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to nonlinear solve adjoint + numerical vjp" - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp, - dy = nothing) + vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad=vjp, + dy=nothing) else @warn "AD choice of autojacvec failed in nonlinear solve adjoint" throw(e) From b3038f90f5e3666e8c1ce88aceae9cdc527d1332 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 28 Aug 2023 17:29:45 -0400 Subject: [PATCH 3/8] Testing updated SteadyStateAdjoint --- src/concrete_solve.jl | 98 +++++++++++++++++++++++++++++------ src/sensitivity_algorithms.jl | 4 ++ src/steadystate_adjoint.jl | 15 ++++-- test/steady_state.jl | 67 ++++++++++++++++++++++++ 4 files changed, 165 insertions(+), 19 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 4a9f7354b..1c848a053 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -6,19 +6,45 @@ const have_not_warned_vjp = Ref(true) const STACKTRACE_WITH_VJPWARN = Ref(false) +__unwrapped_f(f) = unwrapped_f(f) +__unwrapped_f(f::NonlinearFunction) = f.f + function inplace_vjp(prob, u0, p, verbose) du = copy(u0) + tspan_nothing = hasmethod(__unwrapped_f(prob.f), + Tuple{typeof(du), typeof(u0), typeof(p), Nothing}) + no_tspan = (!hasfield(typeof(prob), :tspan) || !(hasmethod(__unwrapped_f(prob.f), + Tuple{typeof(du), typeof(u0), typeof(p), typeof(first(prob.tspan))}))) && + !tspan_nothing + if !no_tspan + if tspan_nothing + __t = nothing + else + __t = first(prob.tspan) + end + end + ez = try f = unwrapped_f(prob.f) - function adfunc(out, u, _p, t) - f(out, u, _p, t) - nothing + if no_tspan + function adfunc_nlprob(out, u, _p) + f(out, u, _p) + nothing + end + Enzyme.autodiff(Enzyme.Reverse, adfunc_nlprob, Enzyme.Duplicated(du, du), + copy(u0), copy(p)) + true + else + function adfunc(out, u, _p, t) + f(out, u, _p, t) + nothing + end + Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, du), + copy(u0), copy(p), __t) + true end - Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, du), - copy(u0), copy(p), prob.tspan[1]) - true catch e if verbose || have_not_warned_vjp[] @warn "Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.\n" @@ -36,9 +62,17 @@ function inplace_vjp(prob, u0, p, verbose) compile = try f = unwrapped_f(prob.f) if DiffEqBase.isinplace(prob) - !hasbranching(f, copy(u0), u0, p, prob.tspan[1]) + if no_tspan + !hasbranching(f, copy(u0), u0, p) + else + !hasbranching(f, copy(u0), u0, p, __t) + end else - !hasbranching(f, u0, p, prob.tspan[1]) + if no_tspan + !hasbranching(f, u0, p) + else + !hasbranching(f, u0, p, __t) + end end catch false @@ -47,16 +81,48 @@ function inplace_vjp(prob, u0, p, verbose) vjp = try f = unwrapped_f(prob.f) if p === nothing || p isa SciMLBase.NullParameters - ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t - du1 = similar(u, size(u)) - f(du1, u, p, first(t)) - return vec(du1) + if no_tspan + ReverseDiff.GradientTape((copy(u0),)) do u + du1 = similar(u, size(u)) + f(du1, u, p) + return vec(du1) + end + else + if tspan_nothing + ReverseDiff.GradientTape((copy(u0),)) do u + du1 = similar(u, size(u)) + f(du1, u, p, nothing) + return vec(du1) + end + else + ReverseDiff.GradientTape((copy(u0), [__t])) do u, t + du1 = similar(u, size(u)) + f(du1, u, p, first(t)) + return vec(du1) + end + end end else - ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t - du1 = similar(u, size(u)) - f(du1, u, p, first(t)) - return vec(du1) + if no_tspan + ReverseDiff.GradientTape((copy(u0), p)) do u, p + du1 = similar(u, size(u)) + f(du1, u, p) + return vec(du1) + end + else + if tspan_nothing + ReverseDiff.GradientTape((copy(u0), p)) do u, p + du1 = similar(u, size(u)) + f(du1, u, p, nothing) + return vec(du1) + end + else + ReverseDiff.GradientTape((copy(u0), p, [__t])) do u, p, t + du1 = similar(u, size(u)) + f(du1, u, p, first(t)) + return vec(du1) + end + end end end ReverseDiffVJP(compile) diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 937ea15d9..504542fcf 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1373,3 +1373,7 @@ get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote() get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() get_autodiff_from_vjp(::TrackerVJP) = AutoTracker() get_autodiff_from_vjp(::Nothing) = AutoZygote() +function get_autodiff_from_vjp(b::Bool) + b && return AutoForwardDiff() + return AutoFiniteDiff() +end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 716ea9076..015222373 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -50,6 +50,7 @@ end (blocksize, blocksize))) sd = JacPrototypeSparsityDetection(; jac_prototype) else + jac_prototype = nothing sd = NoSparsityDetection() end else @@ -95,11 +96,19 @@ end end if !needs_jac - __f(x) = vec(f(reshape(x, size(y)), p, nothing)) - operator = VecJac(__f, vec(y); autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + if DiffEqBase.isinplace(sol.prob) + __f_iip(fx, x) = f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p, + nothing) + operator = VecJac(__f_iip, vec(diffcache.f_cache), vec(y); + autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + else + __f_oop(x) = vec(f(reshape(x, size(y)), p, nothing)) + operator = VecJac(__f_oop, vec(y); + autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + end if linsolve === nothing && sensealg.uniform_blocked_diagonal_jacobian - @warn "linsolve not specified, and Jacobian is specified to be uniform blocked diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 + @warn "linsolve not specified, and Jacobian is specified to be uniform block diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 linsolve = SimpleGMRES(; blocksize, restart=false) end diff --git a/test/steady_state.jl b/test/steady_state.jl index b3b4c3380..03970df22 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -587,3 +587,70 @@ end @test dp≈Zdp atol=1e-4 end end + +@testset "High Level Interface to Control Steady State Adjoint Internals" begin + u0 = zeros(32) + p = [2.0, 1.0] + + # Diagonal Jacobian Problem + prob = NonlinearProblem((u, p) -> u .- p[1] .+ p[2], u0, p) + solve1 = solve(remake(prob, p = p), NewtonRaphson()) + + function test_loss(p, prob; alg = NewtonRaphson(), + sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP())) + _prob = remake(prob, p = p) + sol = sum(solve(_prob, alg; sensealg)) + return sol + end + + test_loss(p, prob) + + dp1 = Zygote.gradient(p -> test_loss(p, prob), p)[1] + + @test dp1[1] ≈ 32 + @test dp1[2] ≈ -32 + + for uniform_blocked_diagonal_jacobian in (true, false), + linsolve_method in (SSAdjointFullJacobianLinsolve(), + SSAdjointIterativeVJPLinsolve(), SSAdjointHeuristicLinsolve()), + concrete_jac in (true, false) + + sensealg = SteadyStateAdjoint(; autojacvec = ZygoteVJP(), + uniform_blocked_diagonal_jacobian, linsolve_method, concrete_jac) + test_loss(p, prob; sensealg) + dp1 = Zygote.gradient(p -> test_loss(p, prob; sensealg), p)[1] + + @test dp1[1] ≈ 32 + @test dp1[2] ≈ -32 + end + + # Inplace version + prob = NonlinearProblem((du, u, p) -> (du .= u .- p[1] .+ p[2]), u0, p) + + function test_loss(p, prob; alg = NewtonRaphson(), sensealg = SteadyStateAdjoint()) + _prob = remake(prob, p = p) + sol = sum(solve(_prob, alg; sensealg)) + return sol + end + + test_loss(p, prob) + + dp1 = Zygote.gradient(p -> test_loss(p, prob), p)[1] + + @test dp1[1] ≈ 32 + @test dp1[2] ≈ -32 + + for uniform_blocked_diagonal_jacobian in (true, false), + linsolve_method in (SSAdjointFullJacobianLinsolve(), + SSAdjointIterativeVJPLinsolve(), SSAdjointHeuristicLinsolve()), + concrete_jac in (true, false) + + sensealg = SteadyStateAdjoint(; uniform_blocked_diagonal_jacobian, linsolve_method, + concrete_jac) + test_loss(p, prob; sensealg) + dp1 = Zygote.gradient(p -> test_loss(p, prob; sensealg), p)[1] + + @test dp1[1] ≈ 32 + @test dp1[2] ≈ -32 + end +end From 5c78ebddad40f12b93ba9051a06242ffc6ab6404 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 29 Aug 2023 11:26:41 -0400 Subject: [PATCH 4/8] Error on misspecifying arguments --- src/sensitivity_algorithms.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 504542fcf..789be4150 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1148,7 +1148,12 @@ function setvjp(sensealg::SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM}, end needs_concrete_jac(::SteadyStateAdjoint, ::SSAdjointFullJacobianLinsolve, _) = true -needs_concrete_jac(::SteadyStateAdjoint, ::SSAdjointIterativeVJPLinsolve, _) = false +function needs_concrete_jac(S::SteadyStateAdjoint, ::SSAdjointIterativeVJPLinsolve, _) + if S.linsolve !== nothing && LinearSolve.needs_concrete_A(S.linsolve) + error("$(S.linsolve) requires a concrete Matrix. Cannot be solved using the Iterative VJPs!") + end + return false +end function needs_concrete_jac(S::SteadyStateAdjoint, L::SSAdjointHeuristicLinsolve, u0) return length(u0) ≤ L.auto_switch_threshold end From 4084fd1beb5e97288cc7a1834f18fbf95bd65061 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 16 Sep 2023 15:09:13 -0400 Subject: [PATCH 5/8] Support Non Array Parameters for Steady State Adjoint --- Project.toml | 3 +++ src/SciMLSensitivity.jl | 1 + src/derivative_wrappers.jl | 12 ++++++++--- src/steadystate_adjoint.jl | 44 +++++++++++++++++++++++++------------- 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index f2a508910..d23cea222 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -50,6 +51,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ADTypes = "0.1, 0.2" Adapt = "1.0, 2.0, 3.0" ArrayInterface = "7" +BandedMatrices = "0.17" Cassette = "0.3.6" ChainRulesCore = "0.10.7, 1" DiffEqBase = "6.93" @@ -62,6 +64,7 @@ Enzyme = "0.11.6" FiniteDiff = "2" ForwardDiff = "0.10" FunctionWrappersWrappers = "0.1" +Functors = "0.4" GPUArraysCore = "0.1" LinearSolve = "2" OrdinaryDiffEq = "6.19.1" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 4671e7852..958beca4b 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -19,6 +19,7 @@ using SparseDiffTools using SciMLOperators using BandedMatrices using SparseArrays +using Functors import TruncatedStacktraces import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 21ae19858..1119804c2 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -534,6 +534,12 @@ function Base.showerror(io::IO, e::ZygoteVJPNothingError) print(io, ZYGOTEVJP_NOTHING_MESSAGE) end +recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x) +recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x) +recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} = + map(recursive_copyto!, values(y), values(x)) +recursive_copyto!(y, x) = fmap(recursive_copyto!, y, x) + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, W) where {TS <: SensitivityFunction} @unpack sensealg = S @@ -579,20 +585,20 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, end # Grab values from `_dy` before `back` in case mutated - dy !== nothing && (dy[:] .= vec(_dy)) + dy !== nothing && recursive_copyto!(dy, _dy) tmp1, tmp2 = back(λ) if tmp1 === nothing && !sensealg.autojacvec.allow_nothing throw(ZygoteVJPNothingError()) elseif tmp1 !== nothing - dλ !== nothing && (dλ[:] .= vec(tmp1)) + dλ !== nothing && recursive_copyto!(dλ, tmp1) end if dgrad !== nothing if tmp2 === nothing && !sensealg.autojacvec.allow_nothing throw(ZygoteVJPNothingError()) elseif tmp2 !== nothing - !isempty(dgrad) && (dgrad[:] .= vec(tmp2)) + !isempty(dgrad) && recursive_copyto!(dgrad, tmp2) end end end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 015222373..8fd1b2b81 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -1,5 +1,6 @@ -struct SteadyStateAdjointSensitivityFunction{C<:AdjointDiffCache, Alg<:SteadyStateAdjoint, - uType, SType, fType<:ODEFunction, CV, λType, VJPType, LS,} <: SensitivityFunction +struct SteadyStateAdjointSensitivityFunction{C <: AdjointDiffCache, + Alg <: SteadyStateAdjoint, + uType, SType, fType <: ODEFunction, CV, λType, VJPType, LS} <: SensitivityFunction diffcache::C sensealg::Alg y::uType @@ -13,6 +14,16 @@ end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjointSensitivityFunction +allocate_vjp(λ, x::Tuple) = allocate_vjp.((λ,), x) +allocate_vjp(λ, x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp(λ, values(x))) +allocate_vjp(λ, x::AbstractArray) = similar(λ, size(x)) +allocate_vjp(λ, x) = fmap(x -> similar(λ, size(x)), x) + +neg!(x::AbstractArray) = x .*= -1 +neg!(x::Tuple) = neg!.(x) +neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(neg!(values(x))) +neg!(x) = fmap(neg!, x) + function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f, colorvec, needs_jac, jac_prototype) @unpack p, u0 = sol.prob @@ -23,7 +34,7 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp λ = zero(y) # Override the choice of the user if we feel that it is not a fast enough choice. linsolve = needs_jac ? nothing : sensealg.linsolve - vjp = similar(λ, length(p)) + vjp = allocate_vjp(λ, p) return SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec, λ, vjp, linsolve) @@ -97,19 +108,21 @@ end if !needs_jac if DiffEqBase.isinplace(sol.prob) - __f_iip(fx, x) = f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p, - nothing) + function __f_iip(fx, x) + f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p, + nothing) + end operator = VecJac(__f_iip, vec(diffcache.f_cache), vec(y); - autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) else __f_oop(x) = vec(f(reshape(x, size(y)), p, nothing)) operator = VecJac(__f_oop, vec(y); - autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) end if linsolve === nothing && sensealg.uniform_blocked_diagonal_jacobian @warn "linsolve not specified, and Jacobian is specified to be uniform block diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 - linsolve = SimpleGMRES(; blocksize, restart=false) + linsolve = SimpleGMRES(; blocksize, restart = false) end A_ = operator @@ -117,16 +130,16 @@ end A_ = diffcache.J' end - linear_problem = LinearProblem(A_, vec(dgdu_val'); u0=vec(λ)) - sol = solve(linear_problem, linsolve; alias_A=true) # u is vec(λ) + linear_problem = LinearProblem(A_, vec(dgdu_val'); u0 = vec(λ)) + sol = solve(linear_problem, linsolve; alias_A = true) # u is vec(λ) try - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad=vjp, dy=nothing) + vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing) catch e if sense.sensealg.autojacvec === nothing @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to nonlinear solve adjoint + numerical vjp" - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad=vjp, - dy=nothing) + vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp, + dy = nothing) else @warn "AD choice of autojacvec failed in nonlinear solve adjoint" throw(e) @@ -134,6 +147,8 @@ end end if g !== nothing || dgdp !== nothing + # This code-path doesn't support Arbitrary Parameter Structures yet! + @assert vjp isa AbstractArray # compute del g/del p if dgdp !== nothing dgdp(dgdp_val, y, p, nothing, nothing) @@ -144,7 +159,6 @@ end dgdp_val .-= vjp return dgdp_val else - vjp .*= -1 - return vjp + return neg!(vjp) end end From 893239699f7e683191a7b04ff978d613c8904ebf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Oct 2023 12:56:39 -0400 Subject: [PATCH 6/8] Fix setvjp and allow specifying linsolve kwargs --- src/sensitivity_algorithms.jl | 18 ++++++++++-------- src/steadystate_adjoint.jl | 13 +++++++++---- test/steady_state.jl | 11 ++++------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 789be4150..be69421be 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1103,11 +1103,12 @@ documentation page or the docstrings of the vjp types. Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) """ -struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} +struct SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM <: AbstractSSAdjointLinsolveMethod, LK} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS linsolve_method::LM uniform_blocked_diagonal_jacobian::Bool + linsolve_kwargs::LK end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint @@ -1115,10 +1116,11 @@ TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing, concrete_jac = false, uniform_blocked_diagonal_jacobian::Bool = false, - linsolve_method = SSAdjointHeuristicLinsolve(50)) + linsolve_method = SSAdjointHeuristicLinsolve(50), linsolve_kwargs = (;)) return SteadyStateAdjoint{_unwrap_val(concrete_jac), chunk_size, autodiff, diff_type, - typeof(autojacvec), typeof(linsolve), typeof(linsolve_method)}(autojacvec, linsolve, - linsolve_method, uniform_blocked_diagonal_jacobian) + typeof(autojacvec), typeof(linsolve), typeof(linsolve_method), + typeof(linsolve_kwargs)}(autojacvec, linsolve, linsolve_method, + uniform_blocked_diagonal_jacobian, linsolve_kwargs) end function needs_concrete_jac(S::SteadyStateAdjoint{CJ}, u0) where {CJ} @@ -1140,11 +1142,11 @@ function jacobian_adtype(S::SteadyStateAdjoint{CJ, CS, AD}) where {CJ, CS, AD} end end -function setvjp(sensealg::SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM}, - vjp) where {CJ, CS, AD, FDT, VJP, LS, LM} - return SteadyStateAdjoint{CJ, CS, AD, FDT, typeof(vjp), LS, LM}(vjp, +function setvjp(sensealg::SteadyStateAdjoint{CJ, CS, AD, FDT, VJP, LS, LM, LK}, + vjp) where {CJ, CS, AD, FDT, VJP, LS, LM, LK} + return SteadyStateAdjoint{CJ, CS, AD, FDT, typeof(vjp), LS, LM, LK}(vjp, sensealg.linsolve, sensealg.linsolve_method, - sensealg.uniform_blocked_diagonal_jacobian) + sensealg.uniform_blocked_diagonal_jacobian, sensealg.linsolve_kwargs) end needs_concrete_jac(::SteadyStateAdjoint, ::SSAdjointFullJacobianLinsolve, _) = true diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 8fd1b2b81..0be42d36c 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -120,9 +120,14 @@ end autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) end - if linsolve === nothing && sensealg.uniform_blocked_diagonal_jacobian - @warn "linsolve not specified, and Jacobian is specified to be uniform block diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 - linsolve = SimpleGMRES(; blocksize, restart = false) + if sensealg.uniform_blocked_diagonal_jacobian + if linsolve === nothing + @warn "linsolve not specified, and Jacobian is specified to be uniform block diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 + linsolve = SimpleGMRES(; blocksize) + elseif linsolve isa SimpleGMRES && linsolve.blocksize ≤ 0 + linsolve = SimpleGMRES(; linsolve.restart, blocksize, linsolve.warm_start, + linsolve.memory) + end end A_ = operator @@ -131,7 +136,7 @@ end end linear_problem = LinearProblem(A_, vec(dgdu_val'); u0 = vec(λ)) - sol = solve(linear_problem, linsolve; alias_A = true) # u is vec(λ) + sol = solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ) try vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing) diff --git a/test/steady_state.jl b/test/steady_state.jl index 03970df22..6ffc415d0 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -295,11 +295,8 @@ end dp1d = Zygote.gradient(p -> sum(solve(prob, DynamicSS(Rodas5()), u0 = u0, p = p)), p) dp2d = Zygote.gradient(p -> sum((2.0 .- - solve(prob, - DynamicSS(Rodas5()), - u0 = u0, - p = p)) .^ - 2) / 2.0, p) + solve(prob, DynamicSS(Rodas5()); u0, p)) .^ 2) / + 2.0, p) @test res1≈dp1[1] rtol=1e-12 @test res2≈dp2[1] rtol=1e-12 @@ -395,8 +392,8 @@ end @test solve3.u≈solve4.u rtol=1e-10 function test_loss(p, prob; alg = NewtonRaphson()) - _prob = remake(prob, p = p) - sol = sum(solve(_prob, alg, + _prob = remake(prob; p) + sol = sum(solve(_prob, alg; sensealg = SteadyStateAdjoint(autojacvec = ReverseDiffVJP()))) return sol end From d5e0b00e05940e5352be3f6519776abbce83a146 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Oct 2023 13:50:30 -0400 Subject: [PATCH 7/8] Upper bound DiffEqBase for now --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index d23cea222..e80a577f9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLSensitivity" uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" authors = ["Christopher Rackauckas ", "Yingbo Ma "] -version = "7.40.0" +version = "7.41.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -54,7 +54,7 @@ ArrayInterface = "7" BandedMatrices = "0.17" Cassette = "0.3.6" ChainRulesCore = "0.10.7, 1" -DiffEqBase = "6.93" +DiffEqBase = "6.93 - 6.130.3" DiffEqCallbacks = "2.29" DiffEqNoiseProcess = "4.1.4, 5.0" DiffRules = "1" From 0be4c21c868bf4839caa2529db7b42735e5f912f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Oct 2023 15:08:18 -0400 Subject: [PATCH 8/8] Lower-Bound DiffEqBase --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e80a577f9..99bdc47b1 100644 --- a/Project.toml +++ b/Project.toml @@ -54,7 +54,7 @@ ArrayInterface = "7" BandedMatrices = "0.17" Cassette = "0.3.6" ChainRulesCore = "0.10.7, 1" -DiffEqBase = "6.93 - 6.130.3" +DiffEqBase = "6.130.5" DiffEqCallbacks = "2.29" DiffEqNoiseProcess = "4.1.4, 5.0" DiffRules = "1"