From 4577e26f54be91aa5a885299eef907b52d3bb0d9 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 12 Oct 2023 18:03:37 -0400
Subject: [PATCH 1/2] Steady State Adjoints for Non-Vector Parameters

---
 Project.toml                  |  4 +++-
 src/SciMLSensitivity.jl       |  4 +++-
 src/derivative_wrappers.jl    | 42 ++++++++++++++++----------------
 src/parameters_handling.jl    | 45 +++++++++++++++++++++++++++++++++++
 src/sensitivity_algorithms.jl | 31 ++++++++++++------------
 src/steadystate_adjoint.jl    | 36 ++++++++++------------------
 6 files changed, 100 insertions(+), 62 deletions(-)
 create mode 100644 src/parameters_handling.jl

diff --git a/Project.toml b/Project.toml
index c098fdf18..2ab07212d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SciMLSensitivity"
 uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 authors = ["Christopher Rackauckas ", "Yingbo Ma "]
-version = "7.41.1"
+version = "7.42.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -18,6 +18,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
 FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -58,6 +59,7 @@ FiniteDiff = "2"
 ForwardDiff = "0.10"
 FunctionProperties = "0.1"
 FunctionWrappersWrappers = "0.1"
+Functors = "0.4"
 GPUArraysCore = "0.1"
 LinearSolve = "2"
 OrdinaryDiffEq = "6.19.1"
diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl
index 062132127..97b7be35f 100644
--- a/src/SciMLSensitivity.jl
+++ b/src/SciMLSensitivity.jl
@@ -17,6 +17,7 @@ using StaticArraysCore
 using ADTypes
 using SparseDiffTools
 using SciMLOperators
+using Functors
 import TruncatedStacktraces
 import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache
@@ -32,13 +33,14 @@ import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Pr
 
 abstract type SensitivityFunction end
 abstract type TransformedFunction end
 
-import SciMLBase: unwrapped_f
+import SciMLBase: unwrapped_f, _unwrap_val
 
 import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
     AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
     AbstractSecondOrderSensitivityAlgorithm,
     AbstractShadowingSensitivityAlgorithm
 
+include("parameters_handling.jl")
 include("sensitivity_algorithms.jl")
 include("derivative_wrappers.jl")
 include("sensitivity_interface.jl")
diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl
index 21ae19858..d644365c0 100644
--- a/src/derivative_wrappers.jl
+++ b/src/derivative_wrappers.jl
@@ -313,13 +313,13 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
         if inplace_sensitivity(S)
             f(dy, y, p, t)
         else
-            dy[:] .= vec(f(y, p, t))
+            recursive_copyto!(dy, vec(f(y, p, t)))
         end
     else
         if inplace_sensitivity(S)
             f(dy, y, p, t, W)
         else
-            dy[:] .= vec(f(y, p, t, W))
+            recursive_copyto!(dy, vec(f(y, p, t, W)))
         end
     end
@@ -386,11 +386,11 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::TrackerVJP, dgrad,
         end
 
         # Grab values from `_dy` before `back` in case mutated
-        dy !== nothing && (dy[:] .= vec(Tracker.data(_dy)))
+        dy !== nothing && recursive_copyto!(dy, Tracker.data(_dy))
 
         tmp1, tmp2 = Tracker.data.(back(λ))
-        dλ !== nothing && (dλ[:] .= vec(tmp1))
-        dgrad !== nothing && (dgrad[:] .= vec(tmp2))
+        dλ !== nothing && recursive_copyto!(dλ, tmp1)
+        dgrad !== nothing && recursive_copyto!(dgrad, tmp2)
     else
         if W === nothing
             _dy, back = Tracker.forward(y, p) do u, p
@@ -408,11 +408,11 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::TrackerVJP, dgrad,
         end
 
         # Grab values from `_dy` before `back` in case mutated
-        dy !== nothing && (dy[:] .= vec(Tracker.data(_dy)))
+        dy !== nothing && recursive_copyto!(dy, Tracker.data(_dy))
 
         tmp1, tmp2 = Tracker.data.(back(λ))
-        dλ !== nothing && (dλ[:] .= vec(tmp1))
-        dgrad !== nothing && (dgrad[:] .= vec(tmp2))
+        dλ !== nothing && recursive_copyto!(dλ, tmp1)
+        dgrad !== nothing && recursive_copyto!(dgrad, tmp2)
     end
     return
 end
@@ -556,15 +556,15 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad,
         end
 
         # Grab values from `_dy` before `back` in case mutated
-        dy !== nothing && (dy[:] .= vec(_dy))
+        dy !== nothing && recursive_copyto!(dy, _dy)
 
         tmp1, tmp2 = back(λ)
-        dλ !== nothing && (dλ[:] .= vec(tmp1))
+        dλ !== nothing && recursive_copyto!(dλ, tmp1)
         if dgrad !== nothing
             if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
                 throw(ZygoteVJPNothingError())
             else
-                !isempty(dgrad) && (dgrad[:] .= vec(tmp2))
+                !isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
             end
         end
     else
@@ -579,20 +579,20 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad,
         end
 
         # Grab values from `_dy` before `back` in case mutated
-        dy !== nothing && (dy[:] .= vec(_dy))
+        dy !== nothing && recursive_copyto!(dy, _dy)
 
         tmp1, tmp2 = back(λ)
         if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
             throw(ZygoteVJPNothingError())
         elseif tmp1 !== nothing
-            dλ !== nothing && (dλ[:] .= vec(tmp1))
+            dλ !== nothing && recursive_copyto!(dλ, tmp1)
         end
 
         if dgrad !== nothing
             if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
                 throw(ZygoteVJPNothingError())
             elseif tmp2 !== nothing
-                !isempty(dgrad) && (dgrad[:] .= vec(tmp2))
+                !isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
             end
         end
     end
@@ -616,7 +616,7 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
     end
 
     # Grab values from `_dy` before `back` in case mutated
-    dy !== nothing && (dy[:] .= vec(_dy))
+    dy !== nothing && recursive_copyto!(dy, _dy)
 
     tmp1, tmp2 = back(λ)
     if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
@@ -629,7 +629,7 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
         if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
             throw(ZygoteVJPNothingError())
         elseif tmp2 !== nothing
-            (dgrad[:] .= vec(tmp2))
+            recursive_copyto!(dgrad, tmp2)
         end
     end
     return dy, dλ, dgrad
@@ -697,7 +697,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
         dλ !== nothing && (dλ .= tmp1)
         dgrad !== nothing &&
             !(typeof(tmp2) <: DiffEqBase.NullParameters) &&
-            (dgrad[:] .= vec(tmp2))
+            recursive_copyto!(dgrad, tmp2)
         dy !== nothing && (dy .= tmp3)
     else
         if W === nothing
@@ -715,11 +715,11 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
             else
                 f(y, p, t, W)
             end
-            dy[:] .= vec(out_)
+            recursive_copyto!(dy, out_)
         end
         dλ !== nothing && (dλ .= tmp1)
         dgrad !== nothing &&
             !(typeof(tmp2) <: DiffEqBase.NullParameters) &&
-            (dgrad[:] .= vec(tmp2))
+            recursive_copyto!(dgrad, tmp2)
         dy !== nothing && (dy .= tmp3)
     end
     return
 end
@@ -755,7 +755,7 @@ function _jacNoise!(λ, y, p, t, S::TS, isnoise::Bool, dgrad, dλ,
 
         if StochasticDiffEq.is_diagonal_noise(prob)
             pJt = transpose(λ) .* transpose(pJ)
-            dgrad[:] .= vec(pJt)
+            recursive_copyto!(dgrad, pJt)
         else
             m = size(prob.noise_rate_prototype)[2]
             for i in 1:m
@@ -808,7 +808,7 @@ function _jacNoise!(λ, y, p, t, S::TS, isnoise::Bool, dgrad, dλ,
 
         if StochasticDiffEq.is_diagonal_noise(prob)
             Jt = transpose(λ) .* transpose(J)
-            dλ[:] .= vec(Jt)
+            recursive_copyto!(dλ, Jt)
         else
             for i in 1:m
                 tmp = λ' * J[((i - 1) * m + 1):(i * m), :]
diff --git a/src/parameters_handling.jl b/src/parameters_handling.jl
new file mode 100644
index 000000000..b8bb344da
--- /dev/null
+++ b/src/parameters_handling.jl
@@ -0,0 +1,45 @@
+# NOTE: `fmap` can handle all these cases without us defining them, but it often makes
+# the code type-unstable, so we define specialized methods here to keep it type-stable.
+# Handle non-array parameters in a generic fashion.
+"""
+    recursive_copyto!(y, x)
+
+`y[:] .= vec(x)` for generic `x` and `y`. This is used to handle non-array parameters!
+"""
+recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x)
+recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x)
+recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} =
+    map(recursive_copyto!, values(y), values(x))
+recursive_copyto!(y::T, x::T) where {T} = fmap(recursive_copyto!, y, x)
+
+"""
+    recursive_neg!(x)
+
+`x .*= -1` for generic `x`. This is used to handle non-array parameters!
+"""
+recursive_neg!(x::AbstractArray) = (x .*= -1)
+recursive_neg!(x::Tuple) = map(recursive_neg!, x)
+recursive_neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_neg!, values(x)))
+recursive_neg!(x) = fmap(recursive_neg!, x)
+
+"""
+    recursive_sub!(y, x)
+
+`y .-= x` for generic `x` and `y`. This is used to handle non-array parameters!
+"""
+recursive_sub!(y::AbstractArray, x::AbstractArray) = (y .-= x)
+recursive_sub!(y::Tuple, x::Tuple) = map(recursive_sub!, y, x)
+recursive_sub!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} =
+    NamedTuple{F}(map(recursive_sub!, values(y), values(x)))
+recursive_sub!(y::T, x::T) where {T} = fmap(recursive_sub!, y, x)
+
+"""
+    allocate_vjp(λ, x)
+
+`similar(λ, size(x))` for generic `x`. This is used to handle non-array parameters!
+"""
+allocate_vjp(λ::AbstractArray, x::AbstractArray) = similar(λ, size(x))
+allocate_vjp(λ::AbstractArray, x::Tuple) = allocate_vjp.((λ,), x)
+allocate_vjp(λ::AbstractArray, x::NamedTuple{F}) where {F} =
+    NamedTuple{F}(allocate_vjp.((λ,), values(x)))
+allocate_vjp(λ::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp, λ), x)
diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl
index 9ad688310..a1043dc8b 100644
--- a/src/sensitivity_algorithms.jl
+++ b/src/sensitivity_algorithms.jl
@@ -304,7 +304,7 @@ InterpolatingAdjoint(; chunk_size = 0, autodiff = true,
     differentiation with special seeding. The total set of choices are:
 
       + `nothing`: uses an automatic algorithm to automatically choose the vjp.
-        This is the default and recommended for most users.
+        This is the default and recommended for most users.
       + `false`: the Jacobian is constructed via FiniteDiff.jl
       + `true`: the Jacobian is constructed via ForwardDiff.jl
       + `TrackerVJP`: Uses Tracker.jl for the vjp.
@@ -1067,6 +1067,7 @@ SteadyStateAdjoint(; chunk_size = 0, autodiff = true,
 
   - `linsolve`: the linear solver used in the adjoint solve. Defaults to `nothing`,
     which uses a polyalgorithm to choose an efficient algorithm automatically.
+  - `linsolve_kwargs`: keyword arguments to be passed to the linear solver.
 
 For more details on the vjp choices, please consult the sensitivity algorithms
 documentation page or the docstrings of the vjp types.
@@ -1076,29 +1077,25 @@ documentation page or the docstrings of the vjp types.
 
 ## References
 
 Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007)
 """
-struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS} <:
+struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK} <:
        AbstractAdjointSensitivityAlgorithm{CS, AD, FDT}
     autojacvec::VJP
     linsolve::LS
+    linsolve_kwargs::LK
 end
 TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint
 
 Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true,
-    diff_type = Val{:central},
-    autojacvec = nothing, linsolve = nothing)
-    SteadyStateAdjoint{
-        chunk_size,
-        autodiff,
-        diff_type,
-        typeof(autojacvec),
-        typeof(linsolve),
-    }(autojacvec,
-        linsolve)
+    diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing,
+    linsolve_kwargs = (;))
+    return SteadyStateAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec),
+        typeof(linsolve), typeof(linsolve_kwargs)}(autojacvec, linsolve, linsolve_kwargs)
 end
 
-function setvjp(sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS},
-    vjp) where {CS, AD, FDT, VJP, LS}
-    SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS}(vjp, sensealg.linsolve)
+function setvjp(sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK},
+    vjp) where {CS, AD, FDT, VJP, LS, LK}
+    return SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS, LK}(vjp, sensealg.linsolve,
+        sensealg.linsolve_kwargs)
 end
 
 abstract type VJPChoice end
@@ -1314,7 +1311,9 @@ struct ForwardDiffOverAdjoint{A} <:
     adjalg::A
 end
 
-get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} = AutoReverseDiff(; compile = compile)
+get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where {compile} = AutoReverseDiff(; compile)
 get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote()
 get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme()
 get_autodiff_from_vjp(::TrackerVJP) = AutoTracker()
+get_autodiff_from_vjp(::Nothing) = AutoZygote()
+get_autodiff_from_vjp(b::Bool) = ifelse(b, AutoForwardDiff(), AutoFiniteDiff())
diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl
index e24ba04c1..3585399f8 100644
--- a/src/steadystate_adjoint.jl
+++ b/src/steadystate_adjoint.jl
@@ -1,14 +1,6 @@
-struct SteadyStateAdjointSensitivityFunction{
-    C <: AdjointDiffCache,
-    Alg <: SteadyStateAdjoint,
-    uType,
-    SType,
-    fType <: ODEFunction,
-    CV,
-    λType,
-    VJPType,
-    LS,
-} <: SensitivityFunction
+struct SteadyStateAdjointSensitivityFunction{C <: AdjointDiffCache,
+    Alg <: SteadyStateAdjoint, uType, SType, fType <: ODEFunction, CV, λType, VJPType,
+    LS} <: SensitivityFunction
     diffcache::C
     sensealg::Alg
     y::uType
@@ -31,10 +23,10 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp
 
     λ = zero(y)
     linsolve = needs_jac ? nothing : sensealg.linsolve
-    vjp = similar(λ, length(p))
+    vjp = allocate_vjp(λ, p)
 
-    SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec, λ, vjp,
-        linsolve)
+    return SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec,
+        λ, vjp, linsolve)
 end
 
 @noinline function SteadyStateAdjointProblem(sol, sensealg::SteadyStateAdjoint, alg,
@@ -51,10 +43,10 @@ end
 
     needs_jac = if has_adjoint(f)
         false
-        # TODO: What is the correct heuristic? Can we afford to compute Jacobian for
-        # cases where the length(u0) > 50 and if yes till what threshold
+        # TODO: What is the correct heuristic? Can we afford to compute Jacobian for
+        #       cases where the length(u0) > 50 and if yes till what threshold
     elseif sensealg.linsolve === nothing
-        length(u0) <= 50
+        length(u0) ≤ 50
     else
         LinearSolve.needs_concrete_A(sensealg.linsolve)
     end
@@ -100,7 +92,6 @@ end
     end
 
     if !needs_jac
-        # operator = VecJac(f, y, p; Val(DiffEqBase.isinplace(sol.prob)))
         __f = y -> f(y, p, nothing)
         operator = VecJac(__f, y; autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
         linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ))
@@ -109,15 +100,14 @@ end
     end
 
     # Zygote pullback function won't work with deepcopy
-    solve(linear_problem, linsolve; alias_A = true) # u is vec(λ)
+    solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ)
 
     try
         vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
     catch e
         if sense.sensealg.autojacvec === nothing
             @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp"
-            vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp,
-                dy = nothing)
+            vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp, dy = nothing)
         else
             @warn "AD choice of autojacvec failed in nonlinear solve adjoint"
             throw(e)
@@ -132,10 +122,10 @@ end
         @unpack g_grad_config = diffcache
         gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2])
     end
-        dgdp_val .-= vjp
+        recursive_sub!(dgdp_val, vjp)
         return dgdp_val
     else
-        vjp .*= -1
+        recursive_neg!(vjp)
        return vjp
    end
end

From 8ccac8b94e4b691e6d3c06cc0d3f612a2759277a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 12 Oct 2023 18:08:15 -0400
Subject: [PATCH 2/2] Use axpy!

---
 src/parameters_handling.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/parameters_handling.jl b/src/parameters_handling.jl
index b8bb344da..293771944 100644
--- a/src/parameters_handling.jl
+++ b/src/parameters_handling.jl
@@ -27,7 +27,7 @@ recursive_neg!(x) = fmap(recursive_neg!, x)
 
 `y .-= x` for generic `x` and `y`. This is used to handle non-array parameters!
 """
-recursive_sub!(y::AbstractArray, x::AbstractArray) = (y .-= x)
+recursive_sub!(y::AbstractArray, x::AbstractArray) = axpy!(-1, x, y)
 recursive_sub!(y::Tuple, x::Tuple) = map(recursive_sub!, y, x)
 recursive_sub!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} =
     NamedTuple{F}(map(recursive_sub!, values(y), values(x)))
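
Usage note (illustrative, not part of the patches): a minimal sketch of how the new
helpers in src/parameters_handling.jl compose on NamedTuple parameters, assuming the
definitions above (and `Functors.fmap`) are in scope. The names `p`, `lam`, `vjp`, and
`dgdp` are hypothetical stand-ins for the quantities the adjoint code manipulates.

    # `p` mimics the non-vector parameters that SteadyStateAdjoint now supports.
    p    = (W = rand(3, 3), b = rand(3))
    lam  = rand(9)                  # adjoint variable (a plain vector)
    vjp  = allocate_vjp(lam, p)     # NamedTuple of buffers: a 3x3 matrix and a 3-vector
    dgdp = (W = zeros(3, 3), b = zeros(3))

    recursive_copyto!(vjp.W, p.W)   # leafwise `copyto!`
    recursive_sub!(dgdp, vjp)       # leafwise `dgdp .-= vjp` (via `axpy!` after patch 2/2)
    recursive_neg!(vjp)             # leafwise `vjp .*= -1`

The explicit Tuple/NamedTuple methods exist purely so these hot loops stay type-stable;
`fmap` remains the fallback for arbitrary Functors.jl-compatible parameter structs.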
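
Usage note (illustrative, not part of the patches): constructing the extended
`SteadyStateAdjoint` with the new `linsolve_kwargs` option. The solver and tolerance
choices below are hypothetical; the named tuple is splatted into the adjoint linear
solve, exactly as in the patched call
`solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...)`.

    using SciMLSensitivity, LinearSolve

    sensealg = SteadyStateAdjoint(; autojacvec = ZygoteVJP(),
        linsolve = KrylovJL_GMRES(),
        linsolve_kwargs = (; abstol = 1e-10, maxiters = 100))

Such a `sensealg` would then be passed as `solve(prob, alg; sensealg = sensealg)` when
differentiating through a steady-state or nonlinear solve, where the problem parameters
may now be a Tuple, NamedTuple, or any Functors.jl-compatible struct rather than a
plain vector.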