From 909c93909c0a5e27fac198116fdcfd74d67e06de Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 16 Sep 2023 15:09:13 -0400 Subject: [PATCH] Support Non Array Parameters for Steady State Adjoint --- Project.toml | 3 +++ src/SciMLSensitivity.jl | 1 + src/derivative_wrappers.jl | 12 ++++++++--- src/steadystate_adjoint.jl | 44 +++++++++++++++++++++++++------------- 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index f2a508910..d23cea222 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -50,6 +51,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ADTypes = "0.1, 0.2" Adapt = "1.0, 2.0, 3.0" ArrayInterface = "7" +BandedMatrices = "0.17" Cassette = "0.3.6" ChainRulesCore = "0.10.7, 1" DiffEqBase = "6.93" @@ -62,6 +64,7 @@ Enzyme = "0.11.6" FiniteDiff = "2" ForwardDiff = "0.10" FunctionWrappersWrappers = "0.1" +Functors = "0.4" GPUArraysCore = "0.1" LinearSolve = "2" OrdinaryDiffEq = "6.19.1" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 4671e7852..958beca4b 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -19,6 +19,7 @@ using SparseDiffTools using SciMLOperators using BandedMatrices using SparseArrays +using Functors import TruncatedStacktraces import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 21ae19858..1119804c2 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -534,6 +534,12 @@ function Base.showerror(io::IO, e::ZygoteVJPNothingError) print(io, ZYGOTEVJP_NOTHING_MESSAGE) end +recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x) +recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x) +recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} = + map(recursive_copyto!, values(y), values(x)) +recursive_copyto!(y, x) = fmap(recursive_copyto!, y, x) + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, W) where {TS <: SensitivityFunction} @unpack sensealg = S @@ -579,20 +585,20 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, end # Grab values from `_dy` before `back` in case mutated - dy !== nothing && (dy[:] .= vec(_dy)) + dy !== nothing && recursive_copyto!(dy, _dy) tmp1, tmp2 = back(λ) if tmp1 === nothing && !sensealg.autojacvec.allow_nothing throw(ZygoteVJPNothingError()) elseif tmp1 !== nothing - dλ !== nothing && (dλ[:] .= vec(tmp1)) + dλ !== nothing && recursive_copyto!(dλ, tmp1) end if dgrad !== nothing if tmp2 === nothing && !sensealg.autojacvec.allow_nothing throw(ZygoteVJPNothingError()) elseif tmp2 !== nothing - !isempty(dgrad) && (dgrad[:] .= vec(tmp2)) + !isempty(dgrad) && recursive_copyto!(dgrad, tmp2) end end end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 015222373..8fd1b2b81 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -1,5 +1,6 @@ -struct SteadyStateAdjointSensitivityFunction{C<:AdjointDiffCache, Alg<:SteadyStateAdjoint, - uType, SType, fType<:ODEFunction, CV, λType, VJPType, LS,} <: SensitivityFunction +struct SteadyStateAdjointSensitivityFunction{C <: AdjointDiffCache, + Alg <: SteadyStateAdjoint, + uType, SType, fType <: ODEFunction, CV, λType, VJPType, LS} <: SensitivityFunction diffcache::C sensealg::Alg y::uType @@ -13,6 +14,16 @@ end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjointSensitivityFunction +allocate_vjp(λ, x::Tuple) = allocate_vjp.((λ,), x) +allocate_vjp(λ, x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp(λ, values(x))) +allocate_vjp(λ, x::AbstractArray) = similar(λ, size(x)) +allocate_vjp(λ, x) = fmap(x -> similar(λ, size(x)), x) + +neg!(x::AbstractArray) = x .*= -1 +neg!(x::Tuple) = neg!.(x) +neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(neg!(values(x))) +neg!(x) = fmap(neg!, x) + function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f, colorvec, needs_jac, jac_prototype) @unpack p, u0 = sol.prob @@ -23,7 +34,7 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp λ = zero(y) # Override the choice of the user if we feel that it is not a fast enough choice. linsolve = needs_jac ? nothing : sensealg.linsolve - vjp = similar(λ, length(p)) + vjp = allocate_vjp(λ, p) return SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec, λ, vjp, linsolve) @@ -97,19 +108,21 @@ end if !needs_jac if DiffEqBase.isinplace(sol.prob) - __f_iip(fx, x) = f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p, - nothing) + function __f_iip(fx, x) + f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p, + nothing) + end operator = VecJac(__f_iip, vec(diffcache.f_cache), vec(y); - autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) else __f_oop(x) = vec(f(reshape(x, size(y)), p, nothing)) operator = VecJac(__f_oop, vec(y); - autodiff=get_autodiff_from_vjp(sensealg.autojacvec)) + autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) end if linsolve === nothing && sensealg.uniform_blocked_diagonal_jacobian @warn "linsolve not specified, and Jacobian is specified to be uniform block diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1 - linsolve = SimpleGMRES(; blocksize, restart=false) + linsolve = SimpleGMRES(; blocksize, restart = false) end A_ = operator @@ -117,16 +130,16 @@ end A_ = diffcache.J' end - linear_problem = LinearProblem(A_, vec(dgdu_val'); u0=vec(λ)) - sol = solve(linear_problem, linsolve; alias_A=true) # u is vec(λ) + linear_problem = LinearProblem(A_, vec(dgdu_val'); u0 = vec(λ)) + sol = solve(linear_problem, linsolve; alias_A = true) # u is vec(λ) try - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad=vjp, dy=nothing) + vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing) catch e if sense.sensealg.autojacvec === nothing @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to nonlinear solve adjoint + numerical vjp" - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad=vjp, - dy=nothing) + vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp, + dy = nothing) else @warn "AD choice of autojacvec failed in nonlinear solve adjoint" throw(e) @@ -134,6 +147,8 @@ end end if g !== nothing || dgdp !== nothing + # This code-path doesn't support Arbitrary Parameter Structures yet! + @assert vjp isa AbstractArray # compute del g/del p if dgdp !== nothing dgdp(dgdp_val, y, p, nothing, nothing) @@ -144,7 +159,6 @@ end dgdp_val .-= vjp return dgdp_val else - vjp .*= -1 - return vjp + return neg!(vjp) end end