Skip to content

Commit

Permalink
Merge pull request #196 from lxvm/adtests
Browse files Browse the repository at this point in the history
Test AD on more algorithms
  • Loading branch information
ChrisRackauckas authored Jan 3, 2024
2 parents cc3748f + f5648fb commit 7f71009
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 206 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ QuadGK = "2.9"
Reexport = "1.0"
SafeTestsets = "0.1"
SciMLBase = "2.6"
SciMLSensitivity = "7.41"
StaticArrays = "1"
Test = "1"
Zygote = "0.6.60"
Expand All @@ -66,10 +65,9 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Arblib", "SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"]
26 changes: 15 additions & 11 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ using Integrals
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
### Forward-Mode AD Intercepts

# Direct AD on solvers with QuadGK and HCubature
#= Direct AD on solvers with QuadGK and HCubature
# incompatible with iip since types must change
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Expand All @@ -15,11 +16,14 @@ function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, domain,
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
=#

# TODO: add the pushforward for derivative w.r.t lb, and ub (and then combinations?)

# Manually split for the pushforward
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

# we need the output type to avoid perturbation confusion while unwrapping nested duals
# We compute a vector-valued integral of the primal and dual simultaneously
Expand All @@ -32,11 +36,11 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
len,
size(cache.f.integrand_prototype)...)

dfdp_ = function (out, x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
dfdp_ = function (out, x, _p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, _p)
dout = reinterpret(reshape, DT, out)
cache.f(dout, x, dualp)
return out
cache.f(dout, x, p isa D ? only(dualp) : reshape(dualp, size(p)))
return
end
dfdp = if cache.f isa BatchIntegralFunction
BatchIntegralFunction{true}(dfdp_, dual_prototype)
Expand All @@ -55,9 +59,9 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
DT = y isa AbstractArray ? eltype(y) : typeof(y)
elt = unwrap_dualvaltype(DT)

dfdp_ = function (x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
ys = cache.f(x, dualp)
dfdp_ = function (x, _p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, _p)
ys = cache.f(x, p isa D ? only(dualp) : reshape(dualp, size(p)))
ys_ = ys isa AbstractArray ? ys : [ys]
# we need to reshape in order for batching to be consistent
return reinterpret(reshape, elt, ys_)
Expand All @@ -70,7 +74,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
end

ForwardDiff.can_dual(elt) || ForwardDiff.throw_cannot_dual(elt)
rawp = copy(reinterpret(V, p))
rawp = p isa D ? reinterpret(V, [p]) : copy(reinterpret(V, vec(p)))

prob = Integrals.build_problem(cache)
dp_prob = remake(prob, f = dfdp, p = rawp)
Expand Down
25 changes: 16 additions & 9 deletions ext/IntegralsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module IntegralsZygoteExt
using LinearAlgebra: dot
using Integrals
if isdefined(Base, :get_extension)
using Zygote
Expand All @@ -11,6 +12,7 @@ else
end
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
ChainRulesCore.@non_differentiable Integrals.isinplace(f, args...) # fixes #99
ChainRulesCore.@non_differentiable Integrals.init_cacheval(alg, prob)

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain,
p;
Expand All @@ -24,23 +26,25 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
if isinplace(cache)
# zygote doesn't support mutation, so we build an oop pullback
dx = similar(cache.f.integrand_prototype)
_f = x -> cache.f(dx, x, p)
if sensealg.vjp isa Integrals.ZygoteVJP
if cache.f isa BatchIntegralFunction
dx = similar(cache.f.integrand_prototype, size(cache.f.integrand_prototype)[begin:end-1]..., 1)
_f = x -> (cache.f(dx, x, p); dx)
# TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction
dfdp_ = function (x, p)
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
z, back = Zygote.pullback(p) do p
_dx = Zygote.Buffer(dx, size(dx)[begin:(end - 1)]..., 1)
_dx = Zygote.Buffer(dx)
cache.f(_dx, x_, p)
copy(_dx)
end
return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) :
[Δ]))[1]
Δ))[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
else
dx = similar(cache.f.integrand_prototype)
_f = x -> (cache.f(dx, x, p); dx)
dfdp_ = function (x, p)
_, back = Zygote.pullback(p) do p
_dx = Zygote.Buffer(dx)
Expand All @@ -62,8 +66,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
dfdp_ = function (x, p)
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
z, back = Zygote.pullback(p -> cache.f(x_, p), p)
return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) :
[Δ]))[1]
return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1]
end
dfdp = IntegralFunction{false}(dfdp_, nothing)
else
Expand Down Expand Up @@ -98,13 +101,15 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal

lb, ub = domain
if lb isa Number
dlb = cache.f isa BatchIntegralFunction ? -_f([lb]) : -_f(lb)
dub = cache.f isa BatchIntegralFunction ? _f([ub]) : _f(ub)
# TODO replace evaluation at endpoint (which anyone can do without Integrals.jl)
# with integration of dfdx using the same quadrature
dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb)
dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
Tangent{typeof(domain)}(dlb, dub),
Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)),
dp)
else
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
Expand All @@ -123,6 +128,8 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
out, quadrature_adjoint
end

# Remove the last axis of `x` (the one indexed by `ndims(x)`).
# `dropdims` only accepts axes of length 1, so any other length raises an error.
function batch_unwrap(x::AbstractArray)
    trailing_axis = ndims(x)
    return dropdims(x; dims = trailing_axis)
end

Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution,
::Val{:u})
sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),)
Expand Down
Loading

0 comments on commit 7f71009

Please sign in to comment.