Skip to content

Commit

Permalink
add explicit case back in for comparison. fixed a potential initializ…
Browse files Browse the repository at this point in the history
…ation issue
  • Loading branch information
andrewning committed Mar 15, 2023
1 parent 19f64d6 commit 1bff68a
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/unsteady.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,37 @@ Make adjoint more efficienet for time-marching analysis.
"""
explicit_unsteady(solve, initialize, perform_step, x, p=(); cache=nothing) = _explicit_unsteady(solve, initialize, perform_step, x, p, cache)

# If no AD, or Forward AD, just solve normally.
# If no AD just solve normally.
_explicit_unsteady(solve, initialize, perform_step, x, p, cache) = solve(x, p)


# Overloaded for ForwardDiff
function _explicit_unsteady(solve, initialize, perform_step, x::AbstractVector{<:ForwardDiff.Dual{T}}, p, cache) where {T}

# evaluate solver
xv = fd_value(x)
yv, tv = solve(xv, p) # TODO: if we have explicit (not adaptive) time steps we can skip this

# get solution dimensions
ny, nt = size(yv)

# initialize output and caches
yd = similar(x, ny, nt)

# --- Initial Time Step --- #
# solve for Jacobian-vector products
@views yd[:, 1] = initialize(tv[1], x, p)

# --- Additional Time Steps ---
perform_step! = _make_perform_step_inplace(perform_step)
@views for i = 2:nt
perform_step!(yd[:, i], yd[:, i-1], tv[i], tv[i-1], x, p)
end

return yd, tv
end


# ReverseDiff cases
_explicit_unsteady(solve, initialize, perform_step, x::ReverseDiff.TrackedArray, p, cache) = _explicit_unsteady_reverse_wrapper(solve, initialize, perform_step, x, p, cache)
_explicit_unsteady(solve, initialize, perform_step, x::AbstractVector{<:ReverseDiff.TrackedReal}, p, cache) = _explicit_unsteady_reverse_wrapper(solve, initialize, perform_step, x, p, cache)
Expand Down Expand Up @@ -108,11 +135,11 @@ function explicit_unsteady_cache(initialize, perform_step, x, p=(); compile=fals
y0 = initialize([1.0], x, p)

# allocate inputs
gyprev = similar(y0)
gyprev = ones(length(y0))
gt = ones(1)
gtprev = zeros(1)
gx = similar(x)
= similar(y0)
gx = ones(length(x))
= ones(length(y0))

# if out of place - make in-place
perform_step! = _make_perform_step_inplace(perform_step)
Expand Down

0 comments on commit 1bff68a

Please sign in to comment.