Skip to content

Commit

Permalink
fix AD
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Aug 9, 2023
1 parent 5026be1 commit 28a7789
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
12 changes: 9 additions & 3 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module IntegralsForwardDiffExt
using Integrals
using Integrals: set_f, set_p, build_problem
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
### Forward-Mode AD Intercepts

Expand Down Expand Up @@ -68,7 +67,14 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,

rawp = copy(reinterpret(V, p))

dp_cache = set_p(set_f(cache, dfdp, nout), rawp)
prob = Integrals.build_problem(cache)
dp_prob = remake(prob, f = dfdp, nout = nout, p = rawp)
# the infinity transformation was already applied to f so we don't apply it to dfdp
dp_cache = init(dp_prob,
alg;
sensealg = sensealg,
do_inf_transformation = Val(false),
cache.kwargs...)
dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, rawp; kwargs...)

res = similar(p, cache.nout)
Expand All @@ -79,6 +85,6 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
if primal.u isa Number
res = first(res)
end
SciMLBase.build_solution(build_problem(cache), alg, res, primal.resid)
SciMLBase.build_solution(prob, alg, res, primal.resid)
end
end
10 changes: 8 additions & 2 deletions ext/IntegralsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module IntegralsZygoteExt
using Integrals
using Integrals: set_f
if isdefined(Base, :get_extension)
using Zygote
import ChainRulesCore
Expand Down Expand Up @@ -68,7 +67,14 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
end
end

dp_cache = set_f(cache, dfdp, length(p))
prob = Integrals.build_problem(cache)
dp_prob = remake(prob, f = dfdp, nout = length(p))
# the infinity transformation was already applied to f so we don't apply it to dfdp
dp_cache = init(dp_prob,
alg;
sensealg = sensealg,
do_inf_transformation = Val(false),
cache.kwargs...)

if p isa Number
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]
Expand Down

0 comments on commit 28a7789

Please sign in to comment.