From 5033655b5298f52ce94ab1d5d3f29f135072fa44 Mon Sep 17 00:00:00 2001 From: Lorenzo Van Munoz <66997677+lxvm@users.noreply.github.com> Date: Sun, 3 Mar 2024 17:34:13 -0500 Subject: [PATCH] Fix adjoint for ChangeOfVariables (#246) * better buffer docstring * test for ChangeOfVariables rrule * fix ChangeOfVariables rrule * typo * rephrase * test an affine change of variables * format --- ext/IntegralsZygoteExt.jl | 45 +++++++++++++++++---------------------- src/Integrals.jl | 13 ++++------- src/algorithms.jl | 12 +++++++---- src/infinity_handling.jl | 2 +- test/derivative_tests.jl | 27 +++++++++++++++++++++++ 5 files changed, 60 insertions(+), 39 deletions(-) diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index bb5d457..ea4c55b 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -13,35 +13,20 @@ end ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) ChainRulesCore.@non_differentiable Integrals.isinplace(f, args...) # fixes #99 ChainRulesCore.@non_differentiable Integrals.init_cacheval(alg, prob) +ChainRulesCore.@non_differentiable Integrals.substitute_f(args...) # use ∂f/∂p instead +ChainRulesCore.@non_differentiable Integrals.substitute_v(args...) # TODO for ∂f/∂u +ChainRulesCore.@non_differentiable Integrals.substitute_bv(args...) # TODO for ∂f/∂u -function ChainRulesCore.rrule(::typeof(Integrals.__solve), cache::Integrals.IntegralCache, - alg::Integrals.ChangeOfVariables, sensealg, udomain, p; - kwargs...) - _cache, vdomain = Integrals._change_variables(cache, alg, sensealg, udomain, p) - sol, back = Zygote.pullback((args...) -> Integrals.__solve(args...; kwargs...), - _cache, alg.alg, sensealg, vdomain, p) - function change_of_variables_pullback(Δ) - return (NoTangent(), back(Δ)...) +# TODO move this adjoint to SciMLBase +function ChainRulesCore.rrule( + ::typeof(SciMLBase.build_solution), prob::IntegralProblem, alg, u, resid; kwargs...) + function build_integral_solution_pullback(Δ) + return NoTangent(), NoTangent(), NoTangent(), Δ, NoTangent() end - prob = Integrals.build_problem(cache) - _sol = SciMLBase.build_solution( - prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats) - return _sol, change_of_variables_pullback + return SciMLBase.build_solution(prob, alg, u, resid; kwargs...), + build_integral_solution_pullback end -# we will need to implement the following adjoints when we compute ∂f/∂u -function ChainRulesCore.rrule(::typeof(Integrals.substitute_v), args...) - function substitute_v_pullback(_) - return NoTangent(), ntuple(_ -> NoTangent(), length(args))... - end - return Integrals.substitute_v(args...), substitute_v_pullback -end -function ChainRulesCore.rrule(::typeof(Integrals.substitute_bv), args...) - function substitute_bv_pullback(_) - return NoTangent(), ntuple(_ -> NoTangent(), length(args))... - end - return Integrals.substitute_bv(args...), substitute_bv_pullback -end function ChainRulesCore.rrule(::typeof(Integrals._evaluate!), f, y, u, p) out, back = Zygote.pullback(y, u, p) do y, u, p b = Zygote.Buffer(y) @@ -51,6 +36,16 @@ function ChainRulesCore.rrule(::typeof(Integrals._evaluate!), f, y, u, p) out, Δ -> (NoTangent(), NoTangent(), back(Δ)...) end +function ChainRulesCore.rrule(::typeof(Integrals.u2t), lb, ub) + tlb, tub = out = Integrals.u2t(lb, ub) + function u2t_pullback(Δ) + _, lbjac = Integrals.t2ujac(tlb, lb, ub) + _, ubjac = Integrals.t2ujac(tub, lb, ub) + return NoTangent(), Δ[1] / lbjac, Δ[2] / ubjac + end + return out, u2t_pullback +end + function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, p; kwargs...) diff --git a/src/Integrals.jl b/src/Integrals.jl index 5edee6d..d11fab9 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -78,14 +78,6 @@ end function __solve(cache::IntegralCache, alg::ChangeOfVariables, sensealg, udomain, p; kwargs...) - _cache, vdomain = _change_variables(cache, alg, sensealg, udomain, p) - sol = __solve(_cache, alg.alg, sensealg, vdomain, p; kwargs...) - prob = build_problem(cache) - return SciMLBase.build_solution( - prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats) -end - -function _change_variables(cache, alg, sensealg, udomain, p) cacheval = cache.cacheval.alg g, vdomain = alg.fu2gv(cache.f, udomain) _cache = IntegralCache(Val(isinplace(g)), @@ -97,7 +89,10 @@ function _change_variables(cache, alg, sensealg, udomain, p) sensealg, cache.kwargs, cacheval) - return _cache, vdomain + sol = __solve(_cache, alg.alg, sensealg, vdomain, p; kwargs...) + prob = build_problem(cache) + return SciMLBase.build_solution( + prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats) end function get_prototype(prob::IntegralProblem) diff --git a/src/algorithms.jl b/src/algorithms.jl index 4841f0a..bbffa12 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -5,8 +5,10 @@ One-dimensional Gauss-Kronrod integration from QuadGK.jl. This method also takes the optional arguments `order` and `norm`. Which are the order of the integration rule and the norm for calculating the error, respectively. -Lastly, the `buffer` keyword, if set, will allocate a buffer to reuse -for multiple integrals, which may require evaluating the integrand. +Lastly, the `buffer` keyword, if set (e.g. `buffer=true`), will allocate a buffer to reuse +for multiple integrals and may require evaluating the integrand unless an +`integrand_prototype` is provided. Unlike the `segbuf` keyword to `quadgk`, you do not +allocate the buffer as this is handled automatically. ## References @@ -35,8 +37,10 @@ This method also takes the optional arguments `initdiv` and `norm`. Which are the initial number of segments each dimension of the integration domain is divided into, and the norm for calculating the error, respectively. -Lastly, the `buffer` keyword, if set, will allocate a buffer to reuse -for multiple integrals, which may require evaluating the integrand. +Lastly, the `buffer` keyword, if set (e.g. `buffer=true`), will allocate a buffer to reuse +for multiple integrals and may require evaluating the integrand unless an +`integrand_prototype` is provided. Unlike the `buffer` keyword to `hcubature/hquadrature`, +you do not allocate the buffer as this is handled automatically. ## References diff --git a/src/infinity_handling.jl b/src/infinity_handling.jl index b822362..1abef46 100644 --- a/src/infinity_handling.jl +++ b/src/infinity_handling.jl @@ -29,7 +29,7 @@ function substitute_f(f::IntegralFunction{true}, v2ujac, lb, ub) vol = prod((ub - lb) / 2) # just to get the type of the jacobian determinant IntegralFunction{true}(prototype * vol) do y, v, p u, jac = substitute_v(v2ujac, v, lb, ub) - _y = _evaluate!(f, prototype, u, p) + _y = _evaluate!(_f, prototype, u, p) y .= _y .* jac return end diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 9f0cd35..124c2ac 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -356,3 +356,30 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), do_tests(; f = bfiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol) end + +@testset "ChangeOfVariables rrule" begin + alg = QuadGKJL() + # test a simple u-substitution of x = 2.7u + 1.3 + talg = Integrals.ChangeOfVariables(alg) do f, domain + if f isa IntegralFunction{false} + IntegralFunction((x, p) -> f((x - 1.3) / 2.7, p) / 2.7), + map(x -> 1.3 + 2.7x, domain) + else + error("not implemented") + end + end + testf = (f, lb, ub, p, alg) -> begin + prob = IntegralProblem(f, (lb, ub), p) + solve(prob, alg; abstol, reltol).u + end + _testf = (x, p) -> x^2 * p + lb, ub, p = 1.0, 5.0, 2.0 + sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p) + tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p) + @test sol.val ≈ tsol.val + # Fundamental theorem of Calculus part 1 + @test sol.grad[1] ≈ tsol.grad[1] ≈ -_testf(lb, p) + @test sol.grad[2] ≈ tsol.grad[2] ≈ _testf(ub, p) + # This is to check ∂p + @test sol.grad[3] ≈ tsol.grad[3] +end