Skip to content

Commit

Permalink
fix dual unwrapping in forwarddiff
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Nov 1, 2023
1 parent fe55054 commit 0d3f45f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
# we need the output type to avoid perturbation confusion while unwrapping nested duals
# We compute a vector-valued integral of the primal and dual simultaneously
if isinplace(cache)
y = cache.f.integrand_prototype
elt = eltype(cache.f.integrand_prototype)
DT = replace_dualvaltype(eltype(p), elt)
len = duallen(p)
Expand Down Expand Up @@ -82,8 +83,9 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, domain, rawp; kwargs...)

res = reinterpret(reshape, DT, dual.u)
out = if (cache.f isa BatchIntegralFunction && cache.f.integrand_prototype isa AbstractVector) ||
(cache.f isa IntegralFunction && !(y isa AbstractArray))
# unwrap the dual when the primal would return a scalar
out = if (cache.f isa BatchIntegralFunction && y isa AbstractVector) ||
!(y isa AbstractArray)
only(res)
else
res
Expand Down

0 comments on commit 0d3f45f

Please sign in to comment.