Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Dec 20, 2024
1 parent d7ac00a commit 59cbfd4
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 4 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Arblib = "fb37089c-8514-4489-9461-98f9c8763369"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
Expand All @@ -27,18 +29,21 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
IntegralsArblibExt = "Arblib"
IntegralsCubaExt = "Cuba"
IntegralsCubatureExt = "Cubature"
IntegralsDifferentiationInterfaceExt = ["ADTypes", "DifferentiationInterface", "ChainRulesCore"]
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsMCIntegrationExt = "MCIntegration"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
ADTypes = "1"
Aqua = "0.8"
Arblib = "1"
ChainRulesCore = "1.18"
CommonSolve = "0.2.4"
Cuba = "2.2"
Cubature = "1.5"
DifferentiationInterface = "0.6"
Distributions = "0.25.87"
FastGaussQuadrature = "0.5,1"
FiniteDiff = "2.12"
Expand All @@ -63,6 +68,7 @@ Arblib = "fb37089c-8514-4489-9461-98f9c8763369"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -74,4 +80,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "DifferentiationInterface"]
116 changes: 116 additions & 0 deletions ext/IntegralsDifferentiationInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
module IntegralsDifferentiationInterfaceExt
using Integrals
using LinearAlgebra
using DifferentiationInterface
using ADTypes
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg::AbstractADType, domain,
p;
kwargs...)
# TODO: integrate the primal and dual in the same call to the quadrature library
out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)

# the adjoint will be the integral of the input sensitivities, so it maps the
# sensitivity of the output to an object of the type of the parameters
function quadrature_adjoint(Δ)
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
if isinplace(cache)
# zygote doesn't support mutation, so we build an oop pullback
if cache.f isa BatchIntegralFunction
dx = similar(cache.f.integrand_prototype,
size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1)
_f = x -> (cache.f(dx, x, p); dx)
# TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction
dfdp_ = function (x, p)
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
pullback(p -> (cache.f(dx, x_, p); dx), sensealg, p, (Δ,))[1]
# z, back = Zygote.pullback(p) do p
# _dx = Zygote.Buffer(dx)
# cache.f(_dx, x_, p)
# copy(_dx)
# end
# return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) :
# Δ))[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
else
dx = similar(cache.f.integrand_prototype)
_f = x -> (cache.f(dx, x, p); dx)
dfdp_ = function (x, p)
pullback(p -> (cache.f(dx, x, p); dx), sensealg, p, (Δ,))[1]
# _, back = Zygote.pullback(p) do p
# _dx = Zygote.Buffer(dx)
# cache.f(_dx, x, p)
# copy(_dx)
# end
# back(Δ)[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
end
else
_f = x -> cache.f(x, p)
if cache.f isa BatchIntegralFunction
# TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction
dfdp_ = function (x, p)
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
pullback(p -> cache.f(x_, p), sensealg, p, (Δ,))[1]
# z, back = Zygote.pullback(p -> cache.f(x_, p), p)
# return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
else
dfdp_ = function (x, p)
pullback(p -> cache.f(x, p), sensealg, p, (Δ,))[1]
# z, back = Zygote.pullback(p -> cache.f(x, p), p)
# back(z isa Number ? only(Δ) : Δ)[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
end
end

prob = Integrals.build_problem(cache)
# dp_prob = remake(prob, f = dfdp) # fails because we change iip
dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...)
# the infinity transformation was already applied to f so we don't apply it to dfdp
dp_cache = init(dp_prob,
alg;
sensealg = sensealg,
cache.kwargs...)

project_p = ProjectTo(p)
dp = project_p(solve!(dp_cache).u)

lb, ub = domain
if lb isa Number
# TODO replace evaluation at endpoint (which anyone can do without Integrals.jl)
# with integration of dfdx uing the same quadrature
dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb)
dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)),
dp)
else
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
# can see from writing the multidimensional integral as an iterated integral
# alternatively we can use Stokes' theorem to replace the integral on the
# boundary with a volume integral of the flux of the integrand
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
# dimensionality of the integral or the quadrature used (such as quadratures
# that don't evaluate points on the boundaries) and it could be generalized to
# other kinds of domains. The only question is to determine ω in terms of f and
# the deformation of the surface (e.g. consider integral over an ellipse and
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp)
end
end
out, quadrature_adjoint


end
batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x))

end
22 changes: 19 additions & 3 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity
using Integrals
using Cuba, Cubature
using FastGaussQuadrature
using Test
using DifferentiationInterface
import Zygote, FiniteDiff, ForwardDiff

max_dim_test = 2
max_nout_test = 2
Expand Down Expand Up @@ -95,9 +97,9 @@ end

# helper function / test runner
do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
testf = function (lb, ub, p)
testf = function (lb, ub, p; kws...)
prob = IntegralProblem(f, (lb, ub), p)
scalarize(solve(prob, alg; reltol, abstol))
scalarize(solve(prob, alg; reltol, abstol, kws...))
end
testf(lb, ub, p)

Expand Down Expand Up @@ -135,6 +137,20 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
@test dp1dp2 atol=abstol rtol=reltol
@test dp2dp3 atol=abstol rtol=reltol

# DI tests
for sensealg in [AutoForwardDiff(), AutoZygote()]
dlb_di, dub_di, dp_di = let sensealg=sensealg
Zygote.gradient((args...) -> testf(args...; sensealg), lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p)
end
if lb isa Number
@test dlb1dlb_di atol=abstol rtol=reltol
@test dub1dub_di atol=abstol rtol=reltol
else # TODO: implement multivariate limit derivatives in ZygoteExt
@test_broken dlb1dlb_di atol=abstol rtol=reltol
@test_broken dub1dub_di atol=abstol rtol=reltol
end
@test dp1dp_di atol=abstol rtol=reltol
end
return
end

Expand Down

0 comments on commit 59cbfd4

Please sign in to comment.