From 28a77890ea5df8e741b507cb2dc75f6db0060968 Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 9 Aug 2023 14:52:54 -0400 Subject: [PATCH] fix AD --- ext/IntegralsForwardDiffExt.jl | 12 +++++++++--- ext/IntegralsZygoteExt.jl | 10 ++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/ext/IntegralsForwardDiffExt.jl b/ext/IntegralsForwardDiffExt.jl index f87b4103..6cfb3254 100644 --- a/ext/IntegralsForwardDiffExt.jl +++ b/ext/IntegralsForwardDiffExt.jl @@ -1,6 +1,5 @@ module IntegralsForwardDiffExt using Integrals -using Integrals: set_f, set_p, build_problem isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) ### Forward-Mode AD Intercepts @@ -68,7 +67,14 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, rawp = copy(reinterpret(V, p)) - dp_cache = set_p(set_f(cache, dfdp, nout), rawp) + prob = Integrals.build_problem(cache) + dp_prob = remake(prob, f = dfdp, nout = nout, p = rawp) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg = sensealg, + do_inf_transformation = Val(false), + cache.kwargs...) dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, rawp; kwargs...) res = similar(p, cache.nout) @@ -79,6 +85,6 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, if primal.u isa Number res = first(res) end - SciMLBase.build_solution(build_problem(cache), alg, res, primal.resid) + SciMLBase.build_solution(prob, alg, res, primal.resid) end end diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 3d3185b2..a5c41870 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -1,6 +1,5 @@ module IntegralsZygoteExt using Integrals -using Integrals: set_f if isdefined(Base, :get_extension) using Zygote import ChainRulesCore @@ -68,7 +67,14 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal end end - dp_cache = set_f(cache, dfdp, length(p)) + prob = Integrals.build_problem(cache) + dp_prob = remake(prob, f = dfdp, nout = length(p)) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg = sensealg, + do_inf_transformation = Val(false), + cache.kwargs...) if p isa Number dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]