Skip to content

Commit

Permalink
fixed where shadow allocation happens
Browse files Browse the repository at this point in the history
  • Loading branch information
ArbitRandomUser committed Dec 27, 2023
1 parent bc6e347 commit 404be32
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
elseif autojacvec isa EnzymeVJP
paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
paramjac_config = (paramjac_config...,Enzyme.make_zero(pf))
elseif DiffEqBase.has_paramjac(f) || quad || !(autojacvec isa Bool) ||
autojacvec isa EnzymeVJP
paramjac_config = nothing
Expand Down
11 changes: 5 additions & 6 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,

prob = getprob(S)

_tmp1, tmp2, _tmp3, _tmp4, _tmp5 = S.diffcache.paramjac_config
_tmp1, tmp2, _tmp3, _tmp4, _tmp5, _tmp6 = S.diffcache.paramjac_config

if _tmp1 isa FixedSizeDiffCache
tmp1 = get_tmp(_tmp1, dλ)
Expand Down Expand Up @@ -680,15 +680,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,

isautojacvec = get_jacvec(sensealg)

dup_pf = Enzyme.make_zero(S.diffcache.pf)
if inplace_sensitivity(S)
if W === nothing
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf,dup_pf), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup,
Enzyme.Const(t))
else
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf,dup_pf), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup,
Enzyme.Const(t), Enzyme.Const(W))
Expand All @@ -699,11 +698,11 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
dy !== nothing && recursive_copyto!(dy,tmp3)
else
if W === nothing
Enzyme.autodiff(Enzyme.Reverse, Duplicated(S.diffcache.pf,dup_pf), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.autodiff(Enzyme.Reverse, Duplicated(S.diffcache.pf, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup, Enzyme.Const(t))
else
Enzyme.autodiff(Enzyme.Reverse, Duplicated(S.diffcache.pf,dup_pf), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.autodiff(Enzyme.Reverse, Duplicated(S.diffcache.pf, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup, Enzyme.Const(t), Enzyme.Const(W))
end
Expand Down

0 comments on commit 404be32

Please sign in to comment.