Skip to content

Commit

Permalink
Support Non Array Parameters for Steady State Adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 1, 2023
1 parent 444f69e commit 909c939
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down Expand Up @@ -50,6 +51,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ADTypes = "0.1, 0.2"
Adapt = "1.0, 2.0, 3.0"
ArrayInterface = "7"
BandedMatrices = "0.17"
Cassette = "0.3.6"
ChainRulesCore = "0.10.7, 1"
DiffEqBase = "6.93"
Expand All @@ -62,6 +64,7 @@ Enzyme = "0.11.6"
FiniteDiff = "2"
ForwardDiff = "0.10"
FunctionWrappersWrappers = "0.1"
Functors = "0.4"
GPUArraysCore = "0.1"
LinearSolve = "2"
OrdinaryDiffEq = "6.19.1"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using SparseDiffTools
using SciMLOperators
using BandedMatrices
using SparseArrays
using Functors
import TruncatedStacktraces

import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache
Expand Down
12 changes: 9 additions & 3 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,12 @@ function Base.showerror(io::IO, e::ZygoteVJPNothingError)
print(io, ZYGOTEVJP_NOTHING_MESSAGE)
end

recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x)
recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x)
recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} =
map(recursive_copyto!, values(y), values(x))
recursive_copyto!(y, x) = fmap(recursive_copyto!, y, x)

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
W) where {TS <: SensitivityFunction}
@unpack sensealg = S
Expand Down Expand Up @@ -579,20 +585,20 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad,
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(_dy))
dy !== nothing && recursive_copyto!(dy, _dy)

tmp1, tmp2 = back(λ)
if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp1 !== nothing
!== nothing && (dλ[:] .= vec(tmp1))
!== nothing && recursive_copyto!(dλ, tmp1)
end

if dgrad !== nothing
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp2 !== nothing
!isempty(dgrad) && (dgrad[:] .= vec(tmp2))
!isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
end
end
end
Expand Down
44 changes: 29 additions & 15 deletions src/steadystate_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
struct SteadyStateAdjointSensitivityFunction{C<:AdjointDiffCache, Alg<:SteadyStateAdjoint,
uType, SType, fType<:ODEFunction, CV, λType, VJPType, LS,} <: SensitivityFunction
struct SteadyStateAdjointSensitivityFunction{C <: AdjointDiffCache,
Alg <: SteadyStateAdjoint,
uType, SType, fType <: ODEFunction, CV, λType, VJPType, LS} <: SensitivityFunction
diffcache::C
sensealg::Alg
y::uType
Expand All @@ -13,6 +14,16 @@ end

TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjointSensitivityFunction

allocate_vjp(λ, x::Tuple) = allocate_vjp.((λ,), x)
allocate_vjp(λ, x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp(λ, values(x)))
allocate_vjp(λ, x::AbstractArray) = similar(λ, size(x))
allocate_vjp(λ, x) = fmap(x -> similar(λ, size(x)), x)

neg!(x::AbstractArray) = x .*= -1
neg!(x::Tuple) = neg!.(x)
neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(neg!(values(x)))
neg!(x) = fmap(neg!, x)

function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f,
colorvec, needs_jac, jac_prototype)
@unpack p, u0 = sol.prob
Expand All @@ -23,7 +34,7 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp
λ = zero(y)
# Override the choice of the user if we feel that it is not a fast enough choice.
linsolve = needs_jac ? nothing : sensealg.linsolve
vjp = similar(λ, length(p))
vjp = allocate_vjp(λ, p)

return SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec,
λ, vjp, linsolve)
Expand Down Expand Up @@ -97,43 +108,47 @@ end

if !needs_jac
if DiffEqBase.isinplace(sol.prob)
__f_iip(fx, x) = f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p,
nothing)
function __f_iip(fx, x)
f(reshape(fx, size(diffcache.f_cache)), reshape(x, size(y)), p,
nothing)
end
operator = VecJac(__f_iip, vec(diffcache.f_cache), vec(y);
autodiff=get_autodiff_from_vjp(sensealg.autojacvec))
autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
else
__f_oop(x) = vec(f(reshape(x, size(y)), p, nothing))
operator = VecJac(__f_oop, vec(y);
autodiff=get_autodiff_from_vjp(sensealg.autojacvec))
autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
end

if linsolve === nothing && sensealg.uniform_blocked_diagonal_jacobian
@warn "linsolve not specified, and Jacobian is specified to be uniform block diagonal. Using SimpleGMRES with blocksize $blocksize" maxlog=1
linsolve = SimpleGMRES(; blocksize, restart=false)
linsolve = SimpleGMRES(; blocksize, restart = false)
end

A_ = operator
else
A_ = diffcache.J'
end

linear_problem = LinearProblem(A_, vec(dgdu_val'); u0=vec(λ))
sol = solve(linear_problem, linsolve; alias_A=true) # u is vec(λ)
linear_problem = LinearProblem(A_, vec(dgdu_val'); u0 = vec(λ))
sol = solve(linear_problem, linsolve; alias_A = true) # u is vec(λ)

try
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad=vjp, dy=nothing)
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
catch e
if sense.sensealg.autojacvec === nothing
@warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to nonlinear solve adjoint + numerical vjp"
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad=vjp,
dy=nothing)
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp,
dy = nothing)
else
@warn "AD choice of autojacvec failed in nonlinear solve adjoint"
throw(e)
end
end

if g !== nothing || dgdp !== nothing
# This code-path doesn't support Arbitrary Parameter Structures yet!
@assert vjp isa AbstractArray
# compute del g/del p
if dgdp !== nothing
dgdp(dgdp_val, y, p, nothing, nothing)
Expand All @@ -144,7 +159,6 @@ end
dgdp_val .-= vjp
return dgdp_val
else
vjp .*= -1
return vjp
return neg!(vjp)
end
end

0 comments on commit 909c939

Please sign in to comment.