-
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
142 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
module IntegralsDifferentiationInterfaceExt | ||
using Integrals | ||
using LinearAlgebra | ||
using DifferentiationInterface | ||
using ADTypes | ||
using ChainRulesCore | ||
|
||
# Reverse-mode rule for `Integrals.__solvebp` when `sensealg` is an ADTypes backend that
# is driven through DifferentiationInterface.
#
# The primal result comes from `Integrals.__solvebp_call`. The returned pullback maps the
# sensitivity `Δ` of the integral to:
#   * a cotangent for `p`, obtained by integrating the vector-Jacobian product of the
#     integrand with respect to `p` over the same domain with the same algorithm;
#   * for scalar bounds, cotangents for the two endpoints, obtained by evaluating the
#     integrand at the endpoints (negated at the lower bound) and contracting with `Δ`.
# For non-scalar bounds no domain cotangent is computed and `NoTangent()` is returned.
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg,
        sensealg::AbstractADType, domain, p; kwargs...)
    # TODO: integrate the primal and dual in the same call to the quadrature library
    out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)

    # the adjoint will be the integral of the input sensitivities, so it maps the
    # sensitivity of the output to an object of the type of the parameters
    function quadrature_adjoint(Δ)
        # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
        isbatch = cache.f isa BatchIntegralFunction

        if isinplace(cache)
            # The integrand is in-place, so it is wrapped in an out-of-place closure that
            # fills the buffer `dx` and returns it.
            # NOTE(review): the closure still mutates `dx`; backends that cannot
            # differentiate through mutation will presumably fail here -- confirm which
            # backends are meant to be supported for in-place integrands.
            if isbatch
                # buffer for a batch holding a single evaluation point
                dx = similar(cache.f.integrand_prototype,
                    size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1)
                # the batched output carries a trailing batch dimension of length 1, so
                # the cotangent seeded into the pullback must have that dimension too
                Δ_ = Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ]
                # TODO: let the user pass a batched jacobian so we can return a
                # BatchIntegralFunction
                dfdp_ = function (x, p)
                    x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
                    return pullback(
                        p -> (cache.f(dx, x_, p); dx), sensealg, p, (Δ_,))[1]
                end
            else
                dx = similar(cache.f.integrand_prototype)
                dfdp_ = function (x, p)
                    return pullback(
                        p -> (cache.f(dx, x, p); dx), sensealg, p, (Δ,))[1]
                end
            end
            # out-of-place view of the integrand, used below for the endpoint terms
            _f = x -> (cache.f(dx, x, p); dx)
        else
            _f = x -> cache.f(x, p)
            if isbatch
                # same trailing batch dimension as the batched integrand output
                Δ_ = Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ]
                # TODO: let the user pass a batched jacobian so we can return a
                # BatchIntegralFunction
                dfdp_ = function (x, p)
                    x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
                    return pullback(p -> cache.f(x_, p), sensealg, p, (Δ_,))[1]
                end
            else
                dfdp_ = function (x, p)
                    return pullback(p -> cache.f(x, p), sensealg, p, (Δ,))[1]
                end
            end
        end
        # the parameter-sensitivity integrand is always out-of-place and non-batched
        dfdp = IntegralFunction{false}(dfdp_, nothing)

        prob = Integrals.build_problem(cache)
        # `remake(prob, f = dfdp)` fails here because we change iip
        dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...)
        # the infinity transformation was already applied to f so we don't apply it to dfdp
        dp_cache = init(dp_prob, alg; sensealg = sensealg, cache.kwargs...)

        project_p = ProjectTo(p)
        dp = project_p(solve!(dp_cache).u)

        lb, ub = domain
        if lb isa Number
            # TODO replace evaluation at endpoint (which anyone can do without
            # Integrals.jl) with integration of dfdx using the same quadrature
            # `dlb` must be computed (and thereby copied by the negation) before `dub`,
            # because an in-place `_f` returns the same buffer on every call
            dlb = isbatch ? -batch_unwrap(_f([lb])) : -_f(lb)
            dub = isbatch ? batch_unwrap(_f([ub])) : _f(ub)
            return (NoTangent(),
                NoTangent(),
                NoTangent(),
                NoTangent(),
                Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)),
                dp)
        else
            # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
            # can see from writing the multidimensional integral as an iterated integral
            # alternatively we can use Stokes' theorem to replace the integral on the
            # boundary with a volume integral of the flux of the integrand
            # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
            # dimensionality of the integral or the quadrature used (such as quadratures
            # that don't evaluate points on the boundaries) and it could be generalized to
            # other kinds of domains. The only question is to determine ω in terms of f and
            # the deformation of the surface (e.g. consider integral over an ellipse and
            # asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
            return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp)
        end
    end

    return out, quadrature_adjoint
end
"""
    batch_unwrap(x::AbstractArray)

Return `x` with its last dimension dropped.

The last dimension must have length 1, as required by `dropdims`. A one-dimensional
input therefore yields a zero-dimensional array.
"""
function batch_unwrap(x::AbstractArray)
    last_dim = ndims(x)
    return dropdims(x; dims = last_dim)
end
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters