Skip to content

Commit

Permalink
Fix adjoint for ChangeOfVariables (#246)
Browse files Browse the repository at this point in the history
* better buffer docstring

* test for ChangeOfVariables rrule

* fix ChangeOfVariables rrule

* typo

* rephrase

* test an affine change of variables

* format
  • Loading branch information
lxvm authored Mar 3, 2024
1 parent 973a5e4 commit 5033655
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 39 deletions.
45 changes: 20 additions & 25 deletions ext/IntegralsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,20 @@ end
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
ChainRulesCore.@non_differentiable Integrals.isinplace(f, args...) # fixes #99
ChainRulesCore.@non_differentiable Integrals.init_cacheval(alg, prob)
ChainRulesCore.@non_differentiable Integrals.substitute_f(args...) # use ∂f/∂p instead
ChainRulesCore.@non_differentiable Integrals.substitute_v(args...) # TODO for ∂f/∂u
ChainRulesCore.@non_differentiable Integrals.substitute_bv(args...) # TODO for ∂f/∂u

function ChainRulesCore.rrule(::typeof(Integrals.__solve), cache::Integrals.IntegralCache,
alg::Integrals.ChangeOfVariables, sensealg, udomain, p;
kwargs...)
_cache, vdomain = Integrals._change_variables(cache, alg, sensealg, udomain, p)
sol, back = Zygote.pullback((args...) -> Integrals.__solve(args...; kwargs...),
_cache, alg.alg, sensealg, vdomain, p)
function change_of_variables_pullback(Δ)
return (NoTangent(), back(Δ)...)
# TODO move this adjoint to SciMLBase
function ChainRulesCore.rrule(
::typeof(SciMLBase.build_solution), prob::IntegralProblem, alg, u, resid; kwargs...)
function build_integral_solution_pullback(Δ)
return NoTangent(), NoTangent(), NoTangent(), Δ, NoTangent()
end
prob = Integrals.build_problem(cache)
_sol = SciMLBase.build_solution(
prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats)
return _sol, change_of_variables_pullback
return SciMLBase.build_solution(prob, alg, u, resid; kwargs...),
build_integral_solution_pullback
end

# we will need to implement the following adjoints when we compute ∂f/∂u
function ChainRulesCore.rrule(::typeof(Integrals.substitute_v), args...)
function substitute_v_pullback(_)
return NoTangent(), ntuple(_ -> NoTangent(), length(args))...
end
return Integrals.substitute_v(args...), substitute_v_pullback
end
function ChainRulesCore.rrule(::typeof(Integrals.substitute_bv), args...)
function substitute_bv_pullback(_)
return NoTangent(), ntuple(_ -> NoTangent(), length(args))...
end
return Integrals.substitute_bv(args...), substitute_bv_pullback
end
function ChainRulesCore.rrule(::typeof(Integrals._evaluate!), f, y, u, p)
out, back = Zygote.pullback(y, u, p) do y, u, p
b = Zygote.Buffer(y)
Expand All @@ -51,6 +36,16 @@ function ChainRulesCore.rrule(::typeof(Integrals._evaluate!), f, y, u, p)
out, Δ -> (NoTangent(), NoTangent(), back(Δ)...)
end

function ChainRulesCore.rrule(::typeof(Integrals.u2t), lb, ub)
tlb, tub = out = Integrals.u2t(lb, ub)
function u2t_pullback(Δ)
_, lbjac = Integrals.t2ujac(tlb, lb, ub)
_, ubjac = Integrals.t2ujac(tub, lb, ub)
return NoTangent(), Δ[1] / lbjac, Δ[2] / ubjac
end
return out, u2t_pullback
end

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain,
p;
kwargs...)
Expand Down
13 changes: 4 additions & 9 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ end

function __solve(cache::IntegralCache, alg::ChangeOfVariables, sensealg, udomain, p;
kwargs...)
_cache, vdomain = _change_variables(cache, alg, sensealg, udomain, p)
sol = __solve(_cache, alg.alg, sensealg, vdomain, p; kwargs...)
prob = build_problem(cache)
return SciMLBase.build_solution(
prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats)
end

function _change_variables(cache, alg, sensealg, udomain, p)
cacheval = cache.cacheval.alg
g, vdomain = alg.fu2gv(cache.f, udomain)
_cache = IntegralCache(Val(isinplace(g)),
Expand All @@ -97,7 +89,10 @@ function _change_variables(cache, alg, sensealg, udomain, p)
sensealg,
cache.kwargs,
cacheval)
return _cache, vdomain
sol = __solve(_cache, alg.alg, sensealg, vdomain, p; kwargs...)
prob = build_problem(cache)
return SciMLBase.build_solution(
prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats)
end

function get_prototype(prob::IntegralProblem)
Expand Down
12 changes: 8 additions & 4 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ One-dimensional Gauss-Kronrod integration from QuadGK.jl.
This method also takes the optional arguments `order` and `norm`.
Which are the order of the integration rule
and the norm for calculating the error, respectively.
Lastly, the `buffer` keyword, if set, will allocate a buffer to reuse
for multiple integrals, which may require evaluating the integrand.
Lastly, the `buffer` keyword, if set (e.g. `buffer=true`), will allocate a buffer to reuse
for multiple integrals and may require evaluating the integrand unless an
`integrand_prototype` is provided. Unlike the `segbuf` keyword to `quadgk`, you do not
allocate the buffer as this is handled automatically.
## References
Expand Down Expand Up @@ -35,8 +37,10 @@ This method also takes the optional arguments `initdiv` and `norm`.
Which are the initial number of segments
each dimension of the integration domain is divided into,
and the norm for calculating the error, respectively.
Lastly, the `buffer` keyword, if set, will allocate a buffer to reuse
for multiple integrals, which may require evaluating the integrand.
Lastly, the `buffer` keyword, if set (e.g. `buffer=true`), will allocate a buffer to reuse
for multiple integrals and may require evaluating the integrand unless an
`integrand_prototype` is provided. Unlike the `buffer` keyword to `hcubature/hquadrature`,
you do not allocate the buffer as this is handled automatically.
## References
Expand Down
2 changes: 1 addition & 1 deletion src/infinity_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function substitute_f(f::IntegralFunction{true}, v2ujac, lb, ub)
vol = prod((ub - lb) / 2) # just to get the type of the jacobian determinant
IntegralFunction{true}(prototype * vol) do y, v, p
u, jac = substitute_v(v2ujac, v, lb, ub)
_y = _evaluate!(f, prototype, u, p)
_y = _evaluate!(_f, prototype, u, p)
y .= _y .* jac
return
end
Expand Down
27 changes: 27 additions & 0 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,30 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands),
do_tests(; f = bfiip, scalarize, lb = ones(dim), ub = 3ones(dim),
p = [2.0i for i in 1:nout], alg, abstol, reltol)
end

@testset "ChangeOfVariables rrule" begin
alg = QuadGKJL()
# test a simple u-substitution of x = 2.7u + 1.3
talg = Integrals.ChangeOfVariables(alg) do f, domain
if f isa IntegralFunction{false}
IntegralFunction((x, p) -> f((x - 1.3) / 2.7, p) / 2.7),
map(x -> 1.3 + 2.7x, domain)
else
error("not implemented")
end
end
testf = (f, lb, ub, p, alg) -> begin
prob = IntegralProblem(f, (lb, ub), p)
solve(prob, alg; abstol, reltol).u
end
_testf = (x, p) -> x^2 * p
lb, ub, p = 1.0, 5.0, 2.0
sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p)
tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p)
@test sol.val tsol.val
# Fundamental theorem of Calculus part 1
@test sol.grad[1] tsol.grad[1] -_testf(lb, p)
@test sol.grad[2] tsol.grad[2] _testf(ub, p)
# This is to check ∂p
@test sol.grad[3] tsol.grad[3]
end

0 comments on commit 5033655

Please sign in to comment.