From 59cbfd4fe11e77460341e818d6777ce24eb08c8a Mon Sep 17 00:00:00 2001 From: Lorenzo Van Munoz Date: Fri, 20 Dec 2024 07:55:59 -0800 Subject: [PATCH] initial commit --- Project.toml | 8 +- ext/IntegralsDifferentiationInterfaceExt.jl | 116 ++++++++++++++++++++ test/derivative_tests.jl | 22 +++- 3 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 ext/IntegralsDifferentiationInterfaceExt.jl diff --git a/Project.toml b/Project.toml index f7577f5..533c616 100644 --- a/Project.toml +++ b/Project.toml @@ -14,10 +14,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" [weakdeps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Arblib = "fb37089c-8514-4489-9461-98f9c8763369" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1" Cubature = "667455a9-e2ce-5579-9412-b964f529a492" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" @@ -27,18 +29,21 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" IntegralsArblibExt = "Arblib" IntegralsCubaExt = "Cuba" IntegralsCubatureExt = "Cubature" +IntegralsDifferentiationInterfaceExt = ["ADTypes", "DifferentiationInterface", "ChainRulesCore"] IntegralsFastGaussQuadratureExt = "FastGaussQuadrature" IntegralsForwardDiffExt = "ForwardDiff" IntegralsMCIntegrationExt = "MCIntegration" IntegralsZygoteExt = ["Zygote", "ChainRulesCore"] [compat] +ADTypes = "1" Aqua = "0.8" Arblib = "1" ChainRulesCore = "1.18" CommonSolve = "0.2.4" Cuba = "2.2" Cubature = "1.5" +DifferentiationInterface = "0.6" Distributions = "0.25.87" FastGaussQuadrature = "0.5,1" FiniteDiff = "2.12" @@ -63,6 +68,7 @@ Arblib = "fb37089c-8514-4489-9461-98f9c8763369" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1" Cubature = 
"667455a9-e2ce-5579-9412-b964f529a492" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -74,4 +80,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"] +test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "DifferentiationInterface"] diff --git a/ext/IntegralsDifferentiationInterfaceExt.jl b/ext/IntegralsDifferentiationInterfaceExt.jl new file mode 100644 index 0000000..4f97966 --- /dev/null +++ b/ext/IntegralsDifferentiationInterfaceExt.jl @@ -0,0 +1,116 @@ +module IntegralsDifferentiationInterfaceExt +using Integrals +using LinearAlgebra +using DifferentiationInterface +using ADTypes +using ChainRulesCore + +function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg::AbstractADType, domain, + p; + kwargs...) + # TODO: integrate the primal and dual in the same call to the quadrature library + out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...) 
+ + # the adjoint will be the integral of the input sensitivities, so it maps the + # sensitivity of the output to an object of the type of the parameters + function quadrature_adjoint(Δ) + # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes + if isinplace(cache) + # zygote doesn't support mutation, so we build an oop pullback + if cache.f isa BatchIntegralFunction + dx = similar(cache.f.integrand_prototype, + size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1) + _f = x -> (cache.f(dx, x, p); dx) + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + pullback(p -> (cache.f(dx, x_, p); dx), sensealg, p, (Δ,))[1] + # z, back = Zygote.pullback(p) do p + # _dx = Zygote.Buffer(dx) + # cache.f(_dx, x_, p) + # copy(_dx) + # end + # return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : + # Δ))[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dx = similar(cache.f.integrand_prototype) + _f = x -> (cache.f(dx, x, p); dx) + dfdp_ = function (x, p) + pullback(p -> (cache.f(dx, x, p); dx), sensealg, p, (Δ,))[1] + # _, back = Zygote.pullback(p) do p + # _dx = Zygote.Buffer(dx) + # cache.f(_dx, x, p) + # copy(_dx) + # end + # back(Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + else + _f = x -> cache.f(x, p) + if cache.f isa BatchIntegralFunction + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + pullback(p -> cache.f(x_, p), sensealg, p, (Δ,))[1] + # z, back = Zygote.pullback(p -> cache.f(x_, p), p) + # return back(Δ isa AbstractArray ? 
reshape(Δ, size(Δ)..., 1) : [Δ])[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dfdp_ = function (x, p) + pullback(p -> cache.f(x, p), sensealg, p, (Δ,))[1] + # z, back = Zygote.pullback(p -> cache.f(x, p), p) + # back(z isa Number ? only(Δ) : Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + end + + prob = Integrals.build_problem(cache) + # dp_prob = remake(prob, f = dfdp) # fails because we change iip + dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg = sensealg, + cache.kwargs...) + + project_p = ProjectTo(p) + dp = project_p(solve!(dp_cache).u) + + lb, ub = domain + if lb isa Number + # TODO replace evaluation at endpoint (which anyone can do without Integrals.jl) + # with integration of dfdx using the same quadrature + dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb) + dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub) + return (NoTangent(), + NoTangent(), + NoTangent(), + NoTangent(), + Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)), + dp) + else + # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we + # can see from writing the multidimensional integral as an iterated integral + # alternatively we can use Stokes' theorem to replace the integral on the + # boundary with a volume integral of the flux of the integrand + # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the + # dimensionality of the integral or the quadrature used (such as quadratures + # that don't evaluate points on the boundaries) and it could be generalized to + # other kinds of domains. The only question is to determine ω in terms of f and + # the deformation of the surface (e.g. consider integral over an ellipse and + # asking for the derivative of the result w.r.t.
the semiaxes of the ellipse) + return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp) + end + end + out, quadrature_adjoint + + +end +batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x)) + +end \ No newline at end of file diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 124c2ac..603c2a9 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -1,7 +1,9 @@ -using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity +using Integrals using Cuba, Cubature using FastGaussQuadrature using Test +using DifferentiationInterface +import Zygote, FiniteDiff, ForwardDiff max_dim_test = 2 max_nout_test = 2 @@ -95,9 +97,9 @@ end # helper function / test runner do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) - testf = function (lb, ub, p) + testf = function (lb, ub, p; kws...) prob = IntegralProblem(f, (lb, ub), p) - scalarize(solve(prob, alg; reltol, abstol)) + scalarize(solve(prob, alg; reltol, abstol, kws...)) end testf(lb, ub, p) @@ -135,6 +137,20 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) @test dp1≈dp2 atol=abstol rtol=reltol @test dp2≈dp3 atol=abstol rtol=reltol + # DI tests + for sensealg in [AutoForwardDiff(), AutoZygote()] + dlb_di, dub_di, dp_di = let sensealg=sensealg + Zygote.gradient((args...) -> testf(args...; sensealg), lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p) + end + if lb isa Number + @test dlb1≈dlb_di atol=abstol rtol=reltol + @test dub1≈dub_di atol=abstol rtol=reltol + else # TODO: implement multivariate limit derivatives in IntegralsDifferentiationInterfaceExt + @test_broken dlb1≈dlb_di atol=abstol rtol=reltol + @test_broken dub1≈dub_di atol=abstol rtol=reltol + end + @test dp1≈dp_di atol=abstol rtol=reltol + end return end