explicit unsteady adjoint: cache function to preallocate and setup functions. initialize moved to separate function
andrewning committed Mar 9, 2023
1 parent 03f3148 commit 1c4ab23
Showing 2 changed files with 114 additions and 89 deletions.
2 changes: 1 addition & 1 deletion src/ImplicitAD.jl
@@ -10,7 +10,7 @@ include("internals.jl")
export implicit
include("nonlinear.jl")

export explicit_unsteady, implicit_unsteady
export explicit_unsteady, implicit_unsteady, explicit_unsteady_cache
include("unsteady.jl")

export implicit_linear, apply_factorization
201 changes: 113 additions & 88 deletions src/unsteady.jl
@@ -28,141 +28,166 @@ function drdy_forward(residual, y::Number, yprev, t, tprev, x, p) # 1D case
end



# ------ Overloads for explicit_unsteady ----------

"""
explicit_unsteady(solve, perform_step, x, p=(); compile=false)
explicit_unsteady(solve, initialize, perform_step, x, p=(); cache=nothing)
Make an explicit time-marching analysis AD compatible (specifically with ForwardDiff and
ReverseDiff).
Make the adjoint more efficient for time-marching analyses.
# Arguments:
- `solve::function`: function `y, t = solve(x, p)`. Perform a time marching analysis,
and return the matrix `y = [y[1] y[2] y[3] ... y[N]] where `y[i]` is the state vector at time step `i` (a rows is a state, a columms is a timesteps)
and return the matrix `y = [y[1] y[2] y[3] ... y[N]]` where `y[i]` is the state vector at time step `i` (each row is a state, each column is a time step)
and the vector `t = [t[1], t[2], t[3] ... t[N]]` where `t[i]` is the corresponding time,
for input variables `x`, and fixed variables `p`.
- `initialize::function`: `y0 = initialize(t0, x, p)`. Return starting state variables.
- `perform_step::function`: Either `y[i] = perform_step(y[i-1], t[i], t[i-1], x, p)` or
in-place `perform_step!(y[i], y[i-1], t[i], t[i-1], x, p)`. Set the next set of state
variables `y[i]` (scalar or vector), given the previous state `y[i-1]` (scalar or vector),
current time `t[i]`, previous time `t[i-1]`, variables `x` (vector), and fixed parameters `p` (tuple).
- `x`: evaluation point
- `p`: fixed parameters, default is empty tuple.
- `x::vector{float}`: evaluation point
- `p::tuple`: fixed parameters, default is empty tuple.
# Keyword Arguments:
- `compile=false`: indicates whether a tape for the function `perform_step` can be safely
prerecorded. This flag is only used for reverse mode automatic differentiation and
should only be set to `true` if `perform_step` does not contain any branches. Otherwise,
this function may return incorrect gradients.
- `cache=nothing`: see `explicit_unsteady_cache`. If computing derivatives more than once, you should compute the
cache beforehand and save it for reuse across iterations. Otherwise, it will be created internally.
"""
function explicit_unsteady(solve, perform_step, x, p=(); compile=false)

perform_step! = perform_step

# if out of place - make in-place
if applicable(perform_step, 1.0, 1.0, 1.0, 1.0, 1.0) # out-of-place
perform_step! = (yw, yprevw, tw, tprevw, xw, pw) -> begin
yw .= perform_step(yprevw, tw, tprevw, xw, pw)
end
end

return _explicit_unsteady(solve, perform_step!, x, p, compile)
end

# If no AD, just solve normally.
_explicit_unsteady(solve, perform_step!, x, p, compile) = solve(x, p)

# Overloaded for ForwardDiff inputs, providing exact derivatives using Jacobian vector product.
function _explicit_unsteady(solve, perform_step!, x::AbstractVector{<:ForwardDiff.Dual{T}}, p, compile) where {T}

# evaluate solver
xv = fd_value(x)
yv, tv = solve(xv, p)
explicit_unsteady(solve, initialize, perform_step, x, p=(); cache=nothing) = _explicit_unsteady(solve, initialize, perform_step, x, p, cache)

# get solution dimensions
ny, nt = size(yv)
# If no AD, or Forward AD, just solve normally.
_explicit_unsteady(solve, initialize, perform_step, x, p, cache) = solve(x, p)

# initialize output and caches
yd = similar(x, ny, nt)

# --- Initial Time Step --- #
# ReverseDiff cases
_explicit_unsteady(solve, initialize, perform_step, x::ReverseDiff.TrackedArray, p, cache) = _explicit_unsteady_reverse_wrapper(solve, initialize, perform_step, x, p, cache)
_explicit_unsteady(solve, initialize, perform_step, x::AbstractVector{<:ReverseDiff.TrackedReal}, p, cache) = _explicit_unsteady_reverse_wrapper(solve, initialize, perform_step, x, p, cache)

# solve for Jacobian-vector products
perform_step!(view(yd, :, 1), view(yv, :, 1), tv[1], tv[1], x, p)
# just declaring dummy function for below (never called directly)
function _explicit_unsteady_reverse(solve, initialize, perform_step, x, p, cache) end

# --- Additional Time Steps ---
# ReverseDiff needs a single array output (y and t concatenated), so unpack it before returning to the user to match the return values of solve
function _explicit_unsteady_reverse_wrapper(solve, initialize, perform_step, x, p, cache)

for i = 2:nt
perform_step!(view(yd, :, i), view(yd, :, i-1), tv[i], tv[i-1], x, p)
end
yt = _explicit_unsteady_reverse(solve, initialize, perform_step, x, p, cache)

return yd, tv
# separate out state and time
return yt[1:end-1, :], yt[end, :]
end
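To illustrate the packing that this wrapper undoes (a small standalone sketch, not part of the diff): `_explicit_unsteady_reverse` returns the states with the time vector appended as a final row, so slicing recovers the two pieces:

```julia
# assumed shapes: ny states by nt time steps, plus the time vector as one extra row
y = [1.0 2.0 3.0;
     4.0 5.0 6.0]                        # 2 states at 3 time steps
t = [0.0, 0.1, 0.2]
yt = [y; t']                             # 3×3 single array, as returned by _explicit_unsteady_reverse
ysep, tsep = yt[1:end-1, :], yt[end, :]  # same unpacking as above: ysep == y, tsep == t
```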

# ReverseDiff needs single array output so unpack before returning to user
_explicit_unsteady(solve, perform_step!, x::ReverseDiff.TrackedArray, p, compile) = unpack_explicit_reverse(solve, perform_step!, x, p, compile)
_outofplace_explicit_unsteady(solve, perform_step!, x::AbstractVector{<:ReverseDiff.TrackedReal}, p, compile) = unpack_explicit_reverse(solve, perform_step!, x, p, compile)


# just declaring dummy function for below
function _explicit_unsteady_reverse(solve, perform_step!, x, p, compile) end

# unpack for user
function unpack_explicit_reverse(solve, perform_step!, x, p, compile)
yt = _explicit_unsteady_reverse(solve, perform_step!, x, p, compile)
return yt[1:end-1, :], yt[end, :]
# make perform_step behave as in-place
function _make_perform_step_inplace(perform_step)
if applicable(perform_step, 1.0, 1.0, 1.0, 1.0, 1.0) # out-of-place
return (yw, yprevw, tw, tprevw, xw, pw) -> begin
yw .= perform_step(yprevw, tw, tprevw, xw, pw)
end
else
return perform_step
end
end
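As a quick illustration of the `applicable` check above (hypothetical step functions, never executed here): a 5-argument method is treated as out-of-place and wrapped, while a 6-argument in-place method is returned unchanged:

```julia
step_oop(yprev, t, tprev, x, p) = yprev .+ (t - tprev) .* x            # out-of-place form
step_ip!(y, yprev, t, tprev, x, p) = (y .= yprev .+ (t - tprev) .* x)  # already in-place

applicable(step_oop, 1.0, 1.0, 1.0, 1.0, 1.0)   # true  -> wrapped so it writes into yw
applicable(step_ip!, 1.0, 1.0, 1.0, 1.0, 1.0)   # false -> used as-is
```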

# Provide a ChainRule rule for reverse mode
function ChainRulesCore.rrule(::typeof(_explicit_unsteady_reverse), solve, perform_step!, x, p, compile)
"""
explicit_unsteady_cache(initialize, perform_step, x, p=(); compile=false)
# evaluate solver
yv, tv = solve(x, p)
Initialize arrays and functions needed for explicit_unsteady
# get solution dimensions
ny, nt = size(yv)
# Arguments
- `initialize::function`: `y0 = initialize(t0, x, p)`. Return starting state variables.
- `perform_step::function`: Either `y[i] = perform_step(y[i-1], t[i], t[i-1], x, p)` or
in-place `perform_step!(y[i], y[i-1], t[i], t[i-1], x, p)`. Set the next set of state
variables `y[i]` (scalar or vector), given the previous state `y[i-1]` (scalar or vector),
current time `t[i]`, previous time `t[i-1]`, variables `x` (vector), and fixed parameters `p` (tuple).
- `x::vector{float}`: evaluation point
- `p::tuple`: fixed parameters, default is empty tuple.
- `compile::bool`: indicates whether a tape for the function `perform_step` can be
prerecorded. Will be much faster but should only be `true` if `perform_step` does not contain any branches.
Otherwise, ReverseDiff may return incorrect gradients.
"""
function explicit_unsteady_cache(initialize, perform_step, x, p=(); compile=false)

# create local copy of the output to guard against values getting overwritten
yv = copy(yv)
tv = copy(tv)
# need size of y
y0 = initialize([1.0], x, p)

# allocate inputs
gyprev = similar(yv, ny)
gt = ones(1) # gt and gtprev must differ so we catch the main branch (TODO: probably create an initialize function)
gyprev = similar(y0)
gt = ones(1)
gtprev = zeros(1)
gx = similar(x)
gλ = similar(yv, ny)
input = (gyprev, gt, gtprev, gx, gλ)
gλ = similar(y0)

# if out of place - make in-place
perform_step! = _make_perform_step_inplace(perform_step)

# allocate cache
TRD = eltype(ReverseDiff.track(perform_step!(zeros(ny), gyprev, gt[1], gtprev[1], gx, p)))
ycache = similar(yv, TRD, ny)
TRD = eltype(ReverseDiff.track(perform_step!(zeros(length(y0)), gyprev, gt[1], gtprev[1], gx, p)))
ycache = similar(y0, TRD)

# vector jacobian product
function fvjp(yprev, t, tprev, x, λ)
perform_step!(ycache, yprev, t[1], tprev[1], x, p)
return λ' * ycache
end

# if no tape, just perform reverse diff - branching ok
function vjp_notape(yprev, t, tprev, x, λ)
# config
input = (gyprev, gt, gtprev, gx, gλ)
cfg = ReverseDiff.GradientConfig(input)

# version with no tape
vjp_step = (yprev, t, tprev, x, λ) -> begin
ReverseDiff.gradient!((gyprev, gt, gtprev, gx, gλ), fvjp, (yprev, [t], [tprev], x, λ), cfg)
return gyprev, gx
end

vjp_tape = vjp_notape # default is to not use record tape
# --- repeat for initialize function -----
fvjp_init(t, x, λ) = λ' * initialize(t[1], x, p)

input_init = (gt, gx, gλ)
cfg_init = ReverseDiff.GradientConfig(input_init)

# construct and compile tape
if compile && nt > 1
cfg = ReverseDiff.GradientConfig(input)
vjp_init = (t, x, λ) -> begin
ReverseDiff.gradient!((gt, gx, gλ), fvjp_init, ([t], x, λ), cfg_init)
return gx
end
# --------------

# ----- compile tape
if compile
tape = ReverseDiff.compile(ReverseDiff.GradientTape(fvjp, input, cfg))

# use tape api for vjp (valid for cases with no branching)
vjp_tape = (yprev, t, tprev, x, λ) -> begin
vjp_step = (yprev, t, tprev, x, λ) -> begin
ReverseDiff.gradient!((gyprev, gt, gtprev, gx, gλ), tape, (yprev, [t], [tprev], x, λ))
return gyprev, gx
end

tape_init = ReverseDiff.compile(ReverseDiff.GradientTape(fvjp_init, input_init, cfg_init))

vjp_init = (t, x, λ) -> begin
ReverseDiff.gradient!((gt, gx, gλ), tape_init, ([t], x, λ))
return gx
end
end
# --------------

return vjp_step, vjp_init
end
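Reading off `fvjp` and `fvjp_init` above, the two closures returned here compute vector-Jacobian products of one step `f = perform_step` and of the initializer `g = initialize` (notation added here for explanation only):

$$
\texttt{vjp\_step}(y_{i-1}, t_i, t_{i-1}, x, \lambda) = \left( \left(\frac{\partial f}{\partial y_{i-1}}\right)^{\!\mathsf T}\lambda,\ \left(\frac{\partial f}{\partial x}\right)^{\!\mathsf T}\lambda \right),
\qquad
\texttt{vjp\_init}(t, x, \lambda) = \left(\frac{\partial g}{\partial x}\right)^{\!\mathsf T}\lambda .
$$

The `rrule` pullback below accumulates these products backward in time, which is the standard discrete adjoint for an explicit time march.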

# Provide a ChainRule rule for reverse mode
function ChainRulesCore.rrule(::typeof(_explicit_unsteady_reverse), solve, initialize, perform_step, x, p, cache)

# evaluate solver
yv, tv = solve(x, p)

# get solution dimensions
ny, nt = size(yv)

# # create local copy of the output to guard against values getting overwritten
# yv = copy(yv)
# tv = copy(tv)

# unpack cache
if isnothing(cache)
cache = explicit_unsteady_cache(initialize, perform_step, x, p, compile=false)
end
vjp_step, vjp_init = cache

function explicit_unsteady_pullback(ytbar)

@@ -177,34 +202,34 @@ function ChainRulesCore.rrule(::typeof(_explicit_unsteady_reverse), solve, perfo
# --- Additional Time Steps --- #
for i = nt:-1:2
@views λ = ybar[:, i]
@views Δybar, Δxbar = vjp_tape(yv[:, i-1], tv[i], tv[i-1], x, λ)
@views Δybar, Δxbar = vjp_step(yv[:, i-1], tv[i], tv[i-1], x, λ)
xbar .+= Δxbar
@views ybar[:, i-1] .+= Δybar
end

# --- Initial Time Step --- #
@views λ = ybar[:, 1]
@views Δybar, Δxbar = vjp_notape(yv[:, 1], tv[1], tv[1], xbar, λ) # separate branch
Δxbar = vjp_init(tv[1], x, λ)
xbar .+= Δxbar

else

# --- Initial Time Step --- #
@views λ = ybar[:, 1]
@views Δybar, Δxbar = vjp_notape(yv[:, 1], tv[1], tv[1], xbar, λ)
xbar = Δxbar
Δxbar = vjp_init(tv[1], x, λ)
xbar .+= Δxbar

end

return NoTangent(), NoTangent(), NoTangent(), xbar, NoTangent(), NoTangent()
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), xbar, NoTangent(), NoTangent()
end

return [yv; tv'], explicit_unsteady_pullback
end

# register above rule for ReverseDiff
ReverseDiff.@grad_from_chainrules _explicit_unsteady_reverse(solve, perform_step!, x::TrackedArray, p, compile)
ReverseDiff.@grad_from_chainrules _explicit_unsteady_reverse(solve, perform_step!, x::AbstractVector{<:TrackedReal}, p, compile)
ReverseDiff.@grad_from_chainrules _explicit_unsteady_reverse(solve, initialize, perform_step!, x::TrackedArray, p, cache)
ReverseDiff.@grad_from_chainrules _explicit_unsteady_reverse(solve, initialize, perform_step!, x::AbstractVector{<:TrackedReal}, p, cache)

# ------ Overloads for implicit_unsteady ----------
