Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Steady State Adjoints for Non-Vector Parameters #913

Merged
merged 2 commits into from
Oct 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "7.41.1"
version = "7.42.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -18,6 +18,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down Expand Up @@ -58,6 +59,7 @@ FiniteDiff = "2"
ForwardDiff = "0.10"
FunctionProperties = "0.1"
FunctionWrappersWrappers = "0.1"
Functors = "0.4"
GPUArraysCore = "0.1"
LinearSolve = "2"
OrdinaryDiffEq = "6.19.1"
Expand Down
4 changes: 3 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using StaticArraysCore
using ADTypes
using SparseDiffTools
using SciMLOperators
using Functors
import TruncatedStacktraces

import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache
Expand All @@ -32,13 +33,14 @@ import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Pr
abstract type SensitivityFunction end
abstract type TransformedFunction end

import SciMLBase: unwrapped_f
import SciMLBase: unwrapped_f, _unwrap_val

import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
AbstractSecondOrderSensitivityAlgorithm,
AbstractShadowingSensitivityAlgorithm

include("parameters_handling.jl")
include("sensitivity_algorithms.jl")
include("derivative_wrappers.jl")
include("sensitivity_interface.jl")
Expand Down
42 changes: 21 additions & 21 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,13 @@
if inplace_sensitivity(S)
f(dy, y, p, t)
else
dy[:] .= vec(f(y, p, t))
recursive_copyto!(dy, vec(f(y, p, t)))
end
else
if inplace_sensitivity(S)
f(dy, y, p, t, W)
else
dy[:] .= vec(f(y, p, t, W))
recursive_copyto!(dy, vec(f(y, p, t, W)))
end
end
end
Expand Down Expand Up @@ -386,11 +386,11 @@
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(Tracker.data(_dy)))
dy !== nothing && recursive_copyto!(dy, Tracker.data(_dy))

tmp1, tmp2 = Tracker.data.(back(λ))
dλ !== nothing && (dλ[:] .= vec(tmp1))
dgrad !== nothing && (dgrad[:] .= vec(tmp2))
dλ !== nothing && recursive_copyto!(dλ, tmp1)
dgrad !== nothing && recursive_copyto!(dgrad, tmp2)
else
if W === nothing
_dy, back = Tracker.forward(y, p) do u, p
Expand All @@ -408,11 +408,11 @@
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(Tracker.data(_dy)))
dy !== nothing && recursive_copyto!(dy, Tracker.data(_dy))

tmp1, tmp2 = Tracker.data.(back(λ))
dλ !== nothing && (dλ[:] .= vec(tmp1))
dgrad !== nothing && (dgrad[:] .= vec(tmp2))
dλ !== nothing && recursive_copyto!(dλ, tmp1)
dgrad !== nothing && recursive_copyto!(dgrad, tmp2)
end
return
end
Expand Down Expand Up @@ -556,15 +556,15 @@
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(_dy))
dy !== nothing && recursive_copyto!(dy, _dy)

tmp1, tmp2 = back(λ)
dλ !== nothing && (dλ[:] .= vec(tmp1))
dλ !== nothing && recursive_copyto!(dλ, tmp1)
if dgrad !== nothing
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
else
!isempty(dgrad) && (dgrad[:] .= vec(tmp2))
!isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
end
end
else
Expand All @@ -579,20 +579,20 @@
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(_dy))
dy !== nothing && recursive_copyto!(dy, _dy)

tmp1, tmp2 = back(λ)
if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp1 !== nothing
dλ !== nothing && (dλ[:] .= vec(tmp1))
dλ !== nothing && recursive_copyto!(dλ, tmp1)
end

if dgrad !== nothing
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp2 !== nothing
!isempty(dgrad) && (dgrad[:] .= vec(tmp2))
!isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
end
end
end
Expand All @@ -616,7 +616,7 @@
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && (dy[:] .= vec(_dy))
dy !== nothing && recursive_copyto!(dy, _dy)

tmp1, tmp2 = back(λ)
if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
Expand All @@ -629,7 +629,7 @@
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp2 !== nothing
(dgrad[:] .= vec(tmp2))
recursive_copyto!(dgrad, tmp2)

Check warning on line 632 in src/derivative_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/derivative_wrappers.jl#L632

Added line #L632 was not covered by tests
end
end
return dy, dλ, dgrad
Expand Down Expand Up @@ -697,7 +697,7 @@

dλ !== nothing && (dλ .= tmp1)
dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) &&
(dgrad[:] .= vec(tmp2))
recursive_copyto!(dgrad, tmp2)
dy !== nothing && (dy .= tmp3)
else
if W === nothing
Expand All @@ -715,11 +715,11 @@
else
f(y, p, t, W)
end
dy[:] .= vec(out_)
recursive_copyto!(dy, out_)

Check warning on line 718 in src/derivative_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/derivative_wrappers.jl#L718

Added line #L718 was not covered by tests
end
dλ !== nothing && (dλ .= tmp1)
dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) &&
(dgrad[:] .= vec(tmp2))
recursive_copyto!(dgrad, tmp2)
dy !== nothing && (dy .= tmp3)
end
return
Expand Down Expand Up @@ -755,7 +755,7 @@

if StochasticDiffEq.is_diagonal_noise(prob)
pJt = transpose(λ) .* transpose(pJ)
dgrad[:] .= vec(pJt)
recursive_copyto!(dgrad, pJt)
else
m = size(prob.noise_rate_prototype)[2]
for i in 1:m
Expand Down Expand Up @@ -808,7 +808,7 @@

if StochasticDiffEq.is_diagonal_noise(prob)
Jt = transpose(λ) .* transpose(J)
dλ[:] .= vec(Jt)
recursive_copyto!(dλ, Jt)
else
for i in 1:m
tmp = λ' * J[((i - 1) * m + 1):(i * m), :]
Expand Down
45 changes: 45 additions & 0 deletions src/parameters_handling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# NOTE: `fmap` can handle all these cases without us defining them, but it often makes the
# code type unstable. So we define them here to make the code type stable.
# Handle Non-Array Parameters in a Generic Fashion
"""
    recursive_copyto!(dst, src)

`dst[:] .= vec(src)` for generic `dst` and `src`. This is used to handle non-array
parameters! Mutates the leaf arrays of `dst` in place.
"""
recursive_copyto!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, src)
recursive_copyto!(dst::Tuple, src::Tuple) = map(recursive_copyto!, dst, src)
function recursive_copyto!(dst::NamedTuple{F}, src::NamedTuple{F}) where {F}
    # Recurse field-by-field; the mutation happens at the array leaves.
    return map(recursive_copyto!, values(dst), values(src))
end
# Fallback for arbitrary (Functors.jl-compatible) structures of the same type.
recursive_copyto!(dst::T, src::T) where {T} = fmap(recursive_copyto!, dst, src)

Check warning on line 13 in src/parameters_handling.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters_handling.jl#L13

Added line #L13 was not covered by tests

"""
    recursive_neg!(x)

`x .*= -1` for generic `x`. This is used to handle non-array parameters!
Mutates the leaf arrays of `x` in place.
"""
recursive_neg!(x::AbstractArray) = (x .*= -1)
recursive_neg!(x::Tuple) = map(recursive_neg!, x)
recursive_neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_neg!, values(x)))
# Fallback for arbitrary (Functors.jl-compatible) structures.
recursive_neg!(x) = fmap(recursive_neg!, x)

Check warning on line 23 in src/parameters_handling.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters_handling.jl#L21-L23

Added lines #L21 - L23 were not covered by tests

"""
    recursive_sub!(dst, src)

`dst .-= src` for generic `dst` and `src`. This is used to handle non-array parameters!
Mutates the leaf arrays of `dst` in place.
"""
# `axpy!(-1, src, dst)` keeps the strict equal-length check (unlike broadcasting).
recursive_sub!(dst::AbstractArray, src::AbstractArray) = axpy!(-1, src, dst)
function recursive_sub!(dst::Tuple, src::Tuple)
    return map(recursive_sub!, dst, src)
end
function recursive_sub!(dst::NamedTuple{F}, src::NamedTuple{F}) where {F}
    return NamedTuple{F}(map(recursive_sub!, values(dst), values(src)))
end
# Fallback for arbitrary (Functors.jl-compatible) structures of the same type.
recursive_sub!(dst::T, src::T) where {T} = fmap(recursive_sub!, dst, src)

Check warning on line 34 in src/parameters_handling.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters_handling.jl#L34

Added line #L34 was not covered by tests

"""
    allocate_vjp(λ, x)

`similar(λ, size(x))` for generic `x`. This is used to handle non-array parameters!
Returns a container mirroring the structure of `x`, with an uninitialized array
(of `λ`'s storage type) in place of each array leaf.
"""
allocate_vjp(λ::AbstractArray, x::AbstractArray) = similar(λ, size(x))
function allocate_vjp(λ::AbstractArray, x::Tuple)
    return map(xᵢ -> allocate_vjp(λ, xᵢ), x)
end
function allocate_vjp(λ::AbstractArray, x::NamedTuple{F}) where {F}
    return NamedTuple{F}(map(xᵢ -> allocate_vjp(λ, xᵢ), values(x)))
end
# Fallback for arbitrary (Functors.jl-compatible) structures.
allocate_vjp(λ::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp, λ), x)

Check warning on line 45 in src/parameters_handling.jl

View check run for this annotation

Codecov / codecov/patch

src/parameters_handling.jl#L45

Added line #L45 was not covered by tests
31 changes: 15 additions & 16 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@
differentiation with special seeding. The total set of choices are:

+ `nothing`: uses an automatic algorithm to automatically choose the vjp.
This is the default and recommended for most users.
This is the default and recommended for most users.
+ `false`: the Jacobian is constructed via FiniteDiff.jl
+ `true`: the Jacobian is constructed via ForwardDiff.jl
+ `TrackerVJP`: Uses Tracker.jl for the vjp.
Expand Down Expand Up @@ -1067,6 +1067,7 @@
- `linsolve`: the linear solver used in the adjoint solve. Defaults to `nothing`,
which uses a polyalgorithm to choose an efficient
algorithm automatically.
- `linsolve_kwargs`: keyword arguments to be passed to the linear solver.

For more details on the vjp choices, please consult the sensitivity algorithms
documentation page or the docstrings of the vjp types.
Expand All @@ -1076,29 +1077,25 @@
Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at
http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007)
"""
struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS} <:
struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK} <:
AbstractAdjointSensitivityAlgorithm{CS, AD, FDT}
autojacvec::VJP
linsolve::LS
linsolve_kwargs::LK
end

TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjoint

Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true,
diff_type = Val{:central},
autojacvec = nothing, linsolve = nothing)
SteadyStateAdjoint{
chunk_size,
autodiff,
diff_type,
typeof(autojacvec),
typeof(linsolve),
}(autojacvec,
linsolve)
diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing,
linsolve_kwargs=(;))
return SteadyStateAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec),
typeof(linsolve), typeof(linsolve_kwargs)}(autojacvec, linsolve, linsolve_kwargs)
end
function setvjp(sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS},
vjp) where {CS, AD, FDT, VJP, LS}
SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS}(vjp, sensealg.linsolve)
function setvjp(sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK},
vjp) where {CS, AD, FDT, VJP, LS, LK}
return SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS, LK}(vjp, sensealg.linsolve,
sensealg.linsolve_kwargs)
end

abstract type VJPChoice end
Expand Down Expand Up @@ -1314,7 +1311,9 @@
adjalg::A
end

get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} = AutoReverseDiff(; compile = compile)
get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where{compile} = AutoReverseDiff(; compile)

Check warning on line 1314 in src/sensitivity_algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/sensitivity_algorithms.jl#L1314

Added line #L1314 was not covered by tests
get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote()
get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme()
get_autodiff_from_vjp(::TrackerVJP) = AutoTracker()
get_autodiff_from_vjp(::Nothing) = AutoZygote()
get_autodiff_from_vjp(b::Bool) = ifelse(b, AutoForwardDiff(), AutoFiniteDiff())

Check warning on line 1319 in src/sensitivity_algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/sensitivity_algorithms.jl#L1318-L1319

Added lines #L1318 - L1319 were not covered by tests
36 changes: 13 additions & 23 deletions src/steadystate_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
struct SteadyStateAdjointSensitivityFunction{
C <: AdjointDiffCache,
Alg <: SteadyStateAdjoint,
uType,
SType,
fType <: ODEFunction,
CV,
λType,
VJPType,
LS,
} <: SensitivityFunction
struct SteadyStateAdjointSensitivityFunction{C <: AdjointDiffCache,
Alg <: SteadyStateAdjoint, uType, SType, fType <: ODEFunction, CV, λType, VJPType,
LS} <: SensitivityFunction
diffcache::C
sensealg::Alg
y::uType
Expand All @@ -31,10 +23,10 @@

λ = zero(y)
linsolve = needs_jac ? nothing : sensealg.linsolve
vjp = similar(λ, length(p))
vjp = allocate_vjp(λ, p)

SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec, λ, vjp,
linsolve)
return SteadyStateAdjointSensitivityFunction(diffcache, sensealg, y, sol, f, colorvec,
λ, vjp, linsolve)
end

@noinline function SteadyStateAdjointProblem(sol, sensealg::SteadyStateAdjoint, alg,
Expand All @@ -51,10 +43,10 @@

needs_jac = if has_adjoint(f)
false
# TODO: What is the correct heuristic? Can we afford to compute Jacobian for
# cases where the length(u0) > 50 and if yes till what threshold
# TODO: What is the correct heuristic? Can we afford to compute Jacobian for
# cases where the length(u0) > 50 and if yes till what threshold
elseif sensealg.linsolve === nothing
length(u0) <= 50
        length(u0) ≤ 50
else
LinearSolve.needs_concrete_A(sensealg.linsolve)
end
Expand Down Expand Up @@ -100,7 +92,6 @@
end

if !needs_jac
# operator = VecJac(f, y, p; Val(DiffEqBase.isinplace(sol.prob)))
__f = y -> f(y, p, nothing)
operator = VecJac(__f, y; autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ))
Expand All @@ -109,15 +100,14 @@
end

# Zygote pullback function won't work with deepcopy
solve(linear_problem, linsolve; alias_A = true) # u is vec(λ)
solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ)

try
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
catch e
if sense.sensealg.autojacvec === nothing
@warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp"
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp,
dy = nothing)
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, false, dgrad = vjp, dy = nothing)

Check warning on line 110 in src/steadystate_adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/steadystate_adjoint.jl#L110

Added line #L110 was not covered by tests
else
@warn "AD choice of autojacvec failed in nonlinear solve adjoint"
throw(e)
Expand All @@ -132,10 +122,10 @@
@unpack g_grad_config = diffcache
gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2])
end
dgdp_val .-= vjp
recursive_sub!(dgdp_val, vjp)
return dgdp_val
else
vjp .*= -1
recursive_neg!(vjp)
return vjp
end
end
Loading