Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Steady State Adjoint #877

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
*.jl.*.cov
*.jl.mem
Manifest.toml
/docs/build/
/docs/build/
.vscode
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "7.40.0"
version = "7.41.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -19,6 +20,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -35,6 +37,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -48,9 +51,10 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ADTypes = "0.1, 0.2"
Adapt = "1.0, 2.0, 3.0"
ArrayInterface = "7"
BandedMatrices = "0.17"
Cassette = "0.3.6"
ChainRulesCore = "0.10.7, 1"
DiffEqBase = "6.93"
DiffEqBase = "6.130.5"
DiffEqCallbacks = "2.29"
DiffEqNoiseProcess = "4.1.4, 5.0"
DiffRules = "1"
Expand All @@ -60,6 +64,7 @@ Enzyme = "0.11.6"
FiniteDiff = "2"
ForwardDiff = "0.10"
FunctionWrappersWrappers = "0.1"
Functors = "0.4"
GPUArraysCore = "0.1"
LinearSolve = "2"
OrdinaryDiffEq = "6.19.1"
Expand Down
8 changes: 7 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ using StaticArraysCore
using ADTypes
using SparseDiffTools
using SciMLOperators
using BandedMatrices
using SparseArrays
using Functors
import TruncatedStacktraces

import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache
Expand All @@ -34,7 +37,7 @@ import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Pr
abstract type SensitivityFunction end
abstract type TransformedFunction end

import SciMLBase: unwrapped_f
import SciMLBase: unwrapped_f, _unwrap_val

import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
Expand Down Expand Up @@ -84,4 +87,7 @@ export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP

export StochasticTransformedFunction

export SSAdjointFullJacobianLinsolve, SSAdjointIterativeVJPLinsolve,
SSAdjointHeuristicLinsolve

end # module
11 changes: 6 additions & 5 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ TruncatedStacktraces.@truncate_stacktrace AdjointDiffCache
return (AdjointDiffCache, y)
"""
function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f, alg;
quad = false,
noiseterm = false, needs_jac = false) where {G, DG1, DG2}
quad = false, noiseterm = false, needs_jac = false,
jac_prototype = nothing) where {G, DG1, DG2}
prob = sol.prob
if prob isa Union{SteadyStateProblem, NonlinearProblem}
@unpack u0, p = prob
Expand Down Expand Up @@ -104,12 +104,13 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool)
J = nothing
else
# jac_prototype can be provided if we want to exploit sparsity
J_ = jac_prototype !== nothing ? jac_prototype : similar(u0, numindvar, numindvar)
if SciMLBase.forwarddiffs_model_time(alg)
# 1 chunk is fine because it's only t
J = dualcache(similar(u0, numindvar, numindvar),
ForwardDiff.pickchunksize(length(u0)))
J = dualcache(J_, ForwardDiff.pickchunksize(length(u0)))
else
J = similar(u0, numindvar, numindvar)
J = J_
end
end

Expand Down
98 changes: 82 additions & 16 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,45 @@
const have_not_warned_vjp = Ref(true)
const STACKTRACE_WITH_VJPWARN = Ref(false)

# Retrieve the raw user function from a SciML function wrapper.
# `NonlinearFunction` keeps the bare function in its `f` field, so it is
# accessed directly; every other wrapper goes through `unwrapped_f`.
function __unwrapped_f(f)
    return unwrapped_f(f)
end
function __unwrapped_f(f::NonlinearFunction)
    return f.f
end

function inplace_vjp(prob, u0, p, verbose)
du = copy(u0)

tspan_nothing = hasmethod(__unwrapped_f(prob.f),
Tuple{typeof(du), typeof(u0), typeof(p), Nothing})
no_tspan = (!hasfield(typeof(prob), :tspan) || !(hasmethod(__unwrapped_f(prob.f),
Tuple{typeof(du), typeof(u0), typeof(p), typeof(first(prob.tspan))}))) &&
!tspan_nothing
if !no_tspan
if tspan_nothing
__t = nothing
else
__t = first(prob.tspan)
end
end

ez = try
f = unwrapped_f(prob.f)

function adfunc(out, u, _p, t)
f(out, u, _p, t)
nothing
if no_tspan
function adfunc_nlprob(out, u, _p)
f(out, u, _p)
nothing
end
Enzyme.autodiff(Enzyme.Reverse, adfunc_nlprob, Enzyme.Duplicated(du, du),
copy(u0), copy(p))
true
else
function adfunc(out, u, _p, t)
f(out, u, _p, t)
nothing
end
Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, du),
copy(u0), copy(p), __t)
true
end
Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, du),
copy(u0), copy(p), prob.tspan[1])
true
catch e
if verbose || have_not_warned_vjp[]
@warn "Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.\n"
Expand All @@ -36,9 +62,17 @@ function inplace_vjp(prob, u0, p, verbose)
compile = try
f = unwrapped_f(prob.f)
if DiffEqBase.isinplace(prob)
!hasbranching(f, copy(u0), u0, p, prob.tspan[1])
if no_tspan
!hasbranching(f, copy(u0), u0, p)
else
!hasbranching(f, copy(u0), u0, p, __t)
end
else
!hasbranching(f, u0, p, prob.tspan[1])
if no_tspan
!hasbranching(f, u0, p)
else
!hasbranching(f, u0, p, __t)
end
end
catch
false
Expand All @@ -47,16 +81,48 @@ function inplace_vjp(prob, u0, p, verbose)
vjp = try
f = unwrapped_f(prob.f)
if p === nothing || p isa SciMLBase.NullParameters
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
if no_tspan
ReverseDiff.GradientTape((copy(u0),)) do u
du1 = similar(u, size(u))
f(du1, u, p)
return vec(du1)
end
else
if tspan_nothing
ReverseDiff.GradientTape((copy(u0),)) do u
du1 = similar(u, size(u))
f(du1, u, p, nothing)
return vec(du1)
end
else
ReverseDiff.GradientTape((copy(u0), [__t])) do u, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
end
end
end
else
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
if no_tspan
ReverseDiff.GradientTape((copy(u0), p)) do u, p
du1 = similar(u, size(u))
f(du1, u, p)
return vec(du1)
end
else
if tspan_nothing
ReverseDiff.GradientTape((copy(u0), p)) do u, p
du1 = similar(u, size(u))
f(du1, u, p, nothing)
return vec(du1)
end
else
ReverseDiff.GradientTape((copy(u0), p, [__t])) do u, p, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
end
end
end
end
ReverseDiffVJP(compile)
Expand Down
12 changes: 9 additions & 3 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,12 @@ function Base.showerror(io::IO, e::ZygoteVJPNothingError)
print(io, ZYGOTEVJP_NOTHING_MESSAGE)
end

"""
    recursive_copyto!(y, x)

Copy the values of `x` into `y` leaf-by-leaf while preserving structure.
Arrays are overwritten in place via `copyto!`; tuples and named tuples
(with matching field names) are traversed element-wise; any other structure
is walked with `Functors.fmap`, applying `recursive_copyto!` at each leaf.
"""
function recursive_copyto!(y::AbstractArray, x::AbstractArray)
    return copyto!(y, x)
end
function recursive_copyto!(y::Tuple, x::Tuple)
    return map(recursive_copyto!, y, x)
end
function recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
    return map(recursive_copyto!, values(y), values(x))
end
recursive_copyto!(y, x) = fmap(recursive_copyto!, y, x)

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
W) where {TS <: SensitivityFunction}
@unpack sensealg = S
Expand Down Expand Up @@ -579,20 +585,20 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad,
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(_dy))
dy !== nothing && recursive_copyto!(dy, _dy)

tmp1, tmp2 = back(λ)
if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp1 !== nothing
dλ !== nothing && (dλ[:] .= vec(tmp1))
dλ !== nothing && recursive_copyto!(dλ, tmp1)
end

if dgrad !== nothing
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp2 !== nothing
!isempty(dgrad) && (dgrad[:] .= vec(tmp2))
!isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
end
end
end
Expand Down
Loading
Loading