Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DifferentiationInterface support #260

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Arblib = "fb37089c-8514-4489-9461-98f9c8763369"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
Expand All @@ -27,18 +29,21 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
IntegralsArblibExt = "Arblib"
IntegralsCubaExt = "Cuba"
IntegralsCubatureExt = "Cubature"
IntegralsDifferentiationInterfaceExt = ["ADTypes", "DifferentiationInterface", "ChainRulesCore"]
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsMCIntegrationExt = "MCIntegration"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
ADTypes = "1"
Aqua = "0.8"
Arblib = "1"
ChainRulesCore = "1.18"
CommonSolve = "0.2.4"
Cuba = "2.2"
Cubature = "1.5"
DifferentiationInterface = "0.6"
Distributions = "0.25.87"
FastGaussQuadrature = "0.5,1"
FiniteDiff = "2.12"
Expand All @@ -63,6 +68,7 @@ Arblib = "fb37089c-8514-4489-9461-98f9c8763369"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -74,4 +80,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "DifferentiationInterface"]
116 changes: 116 additions & 0 deletions ext/IntegralsDifferentiationInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
module IntegralsDifferentiationInterfaceExt
using Integrals
using LinearAlgebra
using DifferentiationInterface
using ADTypes
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg::AbstractADType, domain,
p;
kwargs...)
# TODO: integrate the primal and dual in the same call to the quadrature library
out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)

# the adjoint will be the integral of the input sensitivities, so it maps the
# sensitivity of the output to an object of the type of the parameters
function quadrature_adjoint(Δ)
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
if isinplace(cache)
# zygote doesn't support mutation, so we build an oop pullback
if cache.f isa BatchIntegralFunction
dx = similar(cache.f.integrand_prototype,
size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1)
_f = x -> (cache.f(dx, x, p); dx)
# TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction
dfdp_ = function (x, p)
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
pullback(p -> (cache.f(dx, x_, p); dx), sensealg, p, (Δ,))[1]
# z, back = Zygote.pullback(p) do p
# _dx = Zygote.Buffer(dx)
# cache.f(_dx, x_, p)
# copy(_dx)
# end
# return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) :
# Δ))[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
else
dx = similar(cache.f.integrand_prototype)
_f = x -> (cache.f(dx, x, p); dx)
dfdp_ = function (x, p)
pullback(p -> (cache.f(dx, x, p); dx), sensealg, p, (Δ,))[1]
# _, back = Zygote.pullback(p) do p
# _dx = Zygote.Buffer(dx)
# cache.f(_dx, x, p)
# copy(_dx)
# end
# back(Δ)[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
end
else
_f = x -> cache.f(x, p)
if cache.f isa BatchIntegralFunction
# TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction
dfdp_ = function (x, p)
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
pullback(p -> cache.f(x_, p), sensealg, p, (Δ,))[1]
# z, back = Zygote.pullback(p -> cache.f(x_, p), p)
# return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
else
dfdp_ = function (x, p)
pullback(p -> cache.f(x, p), sensealg, p, (Δ,))[1]
# z, back = Zygote.pullback(p -> cache.f(x, p), p)
# back(z isa Number ? only(Δ) : Δ)[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
end
end

prob = Integrals.build_problem(cache)
# dp_prob = remake(prob, f = dfdp) # fails because we change iip
dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...)
# the infinity transformation was already applied to f so we don't apply it to dfdp
dp_cache = init(dp_prob,
alg;
sensealg = sensealg,
cache.kwargs...)

project_p = ProjectTo(p)
dp = project_p(solve!(dp_cache).u)

lb, ub = domain
if lb isa Number
# TODO replace evaluation at endpoint (which anyone can do without Integrals.jl)
# with integration of dfdx uing the same quadrature
dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb)
dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)),
dp)
else
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
# can see from writing the multidimensional integral as an iterated integral
# alternatively we can use Stokes' theorem to replace the integral on the
# boundary with a volume integral of the flux of the integrand
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
# dimensionality of the integral or the quadrature used (such as quadratures
# that don't evaluate points on the boundaries) and it could be generalized to
# other kinds of domains. The only question is to determine ω in terms of f and
# the deformation of the surface (e.g. consider integral over an ellipse and
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp)
end
end
out, quadrature_adjoint


end
batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x))

end
22 changes: 19 additions & 3 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity
using Integrals
using Cuba, Cubature
using FastGaussQuadrature
using Test
using DifferentiationInterface
import Zygote, FiniteDiff, ForwardDiff

max_dim_test = 2
max_nout_test = 2
Expand Down Expand Up @@ -95,9 +97,9 @@ end

# helper function / test runner
do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
testf = function (lb, ub, p)
testf = function (lb, ub, p; kws...)
prob = IntegralProblem(f, (lb, ub), p)
scalarize(solve(prob, alg; reltol, abstol))
scalarize(solve(prob, alg; reltol, abstol, kws...))
end
testf(lb, ub, p)

Expand Down Expand Up @@ -135,6 +137,20 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
@test dp1≈dp2 atol=abstol rtol=reltol
@test dp2≈dp3 atol=abstol rtol=reltol

# DI tests
for sensealg in [AutoForwardDiff(), AutoZygote()]
dlb_di, dub_di, dp_di = let sensealg=sensealg
Zygote.gradient((args...) -> testf(args...; sensealg), lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p)
end
if lb isa Number
@test dlb1≈dlb_di atol=abstol rtol=reltol
@test dub1≈dub_di atol=abstol rtol=reltol
else # TODO: implement multivariate limit derivatives in ZygoteExt
@test_broken dlb1≈dlb_di atol=abstol rtol=reltol
@test_broken dub1≈dub_di atol=abstol rtol=reltol
end
@test dp1≈dp_di atol=abstol rtol=reltol
end
return
end

Expand Down
Loading