Skip to content

Commit

Permalink
Merge pull request #175 from lxvm/autodiff
Browse files Browse the repository at this point in the history
Fix AD for parameters
  • Loading branch information
ChrisRackauckas authored Sep 17, 2023
2 parents bb6a2a9 + 2d3ee8e commit 823046c
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 92 deletions.
5 changes: 3 additions & 2 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
dfdp = function (out, x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
if cache.batch > 0
dx = similar(dualp, cache.nout, size(x, 2))
dx = cache.nout == 1 ? similar(dualp, size(x, ndims(x))) :
similar(dualp, cache.nout, size(x, ndims(x)))
else
dx = similar(dualp, cache.nout)
end
Expand All @@ -49,7 +50,7 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
ys = cache.f(x, dualp)
if cache.batch > 0
out = similar(p, V, nout, size(x, 2))
out = similar(p, V, nout, size(x, ndims(x)))
else
out = similar(p, V, nout)
end
Expand Down
80 changes: 55 additions & 25 deletions ext/IntegralsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,47 @@ using Integrals
if isdefined(Base, :get_extension)
using Zygote
import ChainRulesCore
import ChainRulesCore: NoTangent
import ChainRulesCore: NoTangent, ProjectTo
else
using ..Zygote
import ..Zygote.ChainRulesCore
import ..Zygote.ChainRulesCore: NoTangent
import ..Zygote.ChainRulesCore: NoTangent, ProjectTo
end
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
ChainRulesCore.@non_differentiable Integrals.isinplace(f, n) # fixes #99

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub,
p;
kwargs...)
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)

# the adjoint will be the integral of the input sensitivities, so it maps the
# sensitivity of the output to an object of the type of the parameters
function quadrature_adjoint(Δ)
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
y = cache.nout == 1 ? Δ[1] : Δ # interpret the output as scalar
# this will not be type-stable, but I believe it is unavoidable due to two ambiguities:
# 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the
# output of the algorithm must be a scalar or a vector of length 1
# 2. when nout = 1 the integrand can either be a scalar or a vector of length 1
if isinplace(cache)
dx = zeros(cache.nout)
_f = x -> cache.f(dx, x, p)
if sensealg.vjp isa Integrals.ZygoteVJP
dfdp = function (dx, x, p)
_, back = Zygote.pullback(p) do p
_dx = Zygote.Buffer(x, cache.nout, size(x, 2))
z, back = Zygote.pullback(p) do p
_dx = cache.nout == 1 ?
Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) :
Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x)))
cache.f(_dx, x, p)
copy(_dx)
end

z = zeros(size(x, 2))
for idx in 1:size(x, 2)
z[1] = 1
dx[:, idx] = back(z)[1]
z[idx] = 0
z .= zero(eltype(z))
for idx in 1:size(x, ndims(x))
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
dx[:, idx] .= back(z)[1]
z isa Vector ? (z[idx] = zero(eltype(z))) :
(z[:, idx] .= zero(eltype(z)))
end
end
elseif sensealg.vjp isa Integrals.ReverseDiffVJP
Expand All @@ -44,14 +54,21 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
if sensealg.vjp isa Integrals.ZygoteVJP
if cache.batch > 0
dfdp = function (x, p)
_, back = Zygote.pullback(p -> cache.f(x, p), p)
z, back = Zygote.pullback(p -> cache.f(x, p), p)
# messy, there are 4 cases, some better in forward mode than reverse
# 1: length(y) == 1 and length(p) == 1
# 2: length(y) > 1 and length(p) == 1
# 3: length(y) == 1 and length(p) > 1
# 4: length(y) > 1 and length(p) > 1

out = zeros(length(p), size(x, 2))
z = zeros(size(x, 2))
for idx in 1:size(x, 2)
z[idx] = 1
out[:, idx] = back(z)[1]
z[idx] = 0
z .= zero(eltype(z))
out = zeros(eltype(p), size(p)..., size(x, ndims(x)))
for idx in 1:size(x, ndims(x))
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
out isa Vector ? (out[idx] = back(z)[1]) :
(out[:, idx] .= back(z)[1])
z isa Vector ? (z[idx] = zero(y)) :
(z[:, idx] .= zero(eltype(y)))
end
out
end
Expand All @@ -76,17 +93,30 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
do_inf_transformation = Val(false),
cache.kwargs...)

if p isa Number
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]
else
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u
end
project_p = ProjectTo(p)
dp = project_p(Integrals.__solvebp_call(dp_cache,
alg,
sensealg,
lb,
ub,
p;
kwargs...).u)

if lb isa Number
dlb = -_f(lb)
dub = _f(ub)
dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb)
dub = cache.batch > 0 ? _f([ub]) : _f(ub)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp)
else
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
# can see from writing the multidimensional integral as an iterated integral
# alternatively we can use Stokes' theorem to replace the integral on the
# boundary with a volume integral of the flux of the integrand
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
# dimensionality of the integral or the quadrature used (such as quadratures
# that don't evaluate points on the boundaries) and it could be generalized to
# other kinds of domains. The only question is to determine ω in terms of f and
# the deformation of the surface (e.g. consider integral over an ellipse and
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent(), dp)
end
Expand Down
54 changes: 16 additions & 38 deletions lib/IntegralsCubature/src/IntegralsCubature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
maxiters = typemax(Int))
nout = prob.nout
if nout == 1
# the output of prob.f could be either scalar or a vector of length 1, however
# the behavior of the output of the integration routine is undefined (could differ
# across algorithms)
# Cubature will output a real number when called without nout/fdim
if prob.batch == 0
if isinplace(prob)
dx = zeros(eltype(lb), prob.nout)
Expand All @@ -63,74 +67,52 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
if lb isa Number
if alg isa CubatureJLh
_val, err = Cubature.hquadrature(f, lb, ub;
val, err = Cubature.hquadrature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pquadrature(f, lb, ub;
val, err = Cubature.pquadrature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
val = prob.f(lb, p) isa Number ? _val : [_val]
else
if alg isa CubatureJLh
_val, err = Cubature.hcubature(f, lb, ub;
val, err = Cubature.hcubature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pcubature(f, lb, ub;
val, err = Cubature.pcubature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end

if isinplace(prob) || !isa(prob.f(lb, p), Number)
val = [_val]
else
val = _val
end
end
else
if isinplace(prob)
f = (x, dx) -> prob.f(dx', x, p)
elseif lb isa Number
if prob.f([lb ub], p) isa Vector
f = (x, dx) -> (dx .= prob.f(x', p))
else
f = function (x, dx)
dx[:] = prob.f(x', p)
end
end
f = (x, dx) -> prob.f(dx, x, p)
else
if prob.f([lb ub], p) isa Vector
f = (x, dx) -> (dx .= prob.f(x, p))
else
f = function (x, dx)
dx .= prob.f(x, p)[:]
end
end
f = (x, dx) -> (dx .= prob.f(x, p))
end
if lb isa Number
if alg isa CubatureJLh
_val, err = Cubature.hquadrature_v(f, lb, ub;
val, err = Cubature.hquadrature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pquadrature_v(f, lb, ub;
val, err = Cubature.pquadrature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
else
if alg isa CubatureJLh
_val, err = Cubature.hcubature_v(f, lb, ub;
val, err = Cubature.hcubature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pcubature_v(f, lb, ub;
val, err = Cubature.pcubature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
end
val = _val isa Number ? [_val] : _val
end
else
if prob.batch == 0
Expand Down Expand Up @@ -166,13 +148,9 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
else
if isinplace(prob)
f = (x, dx) -> prob.f(dx, x, p)
f = (x, dx) -> (prob.f(dx, x, p); dx)
else
if lb isa Number
f = (x, dx) -> (dx .= prob.f(x', p))
else
f = (x, dx) -> (dx .= prob.f(x, p))
end
f = (x, dx) -> (dx .= prob.f(x, p))
end

if lb isa Number
Expand Down
26 changes: 13 additions & 13 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Integrals, Zygote, FiniteDiff, ForwardDiff, SciMLSensitivity
using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity
using IntegralsCuba, IntegralsCubature
using Test

Expand Down Expand Up @@ -117,7 +117,7 @@ dp4 = ForwardDiff.gradient(p -> testf(lb, ub, p), p)
@test dp1 ≈ dp4

### Batch Single dim
f(x, p) = x * p[1] .+ p[2] * p[3]
f(x, p) = x * p[1] .+ p[2] * p[3] # scalar integrand

lb = 1.0
ub = 3.0
Expand All @@ -130,14 +130,14 @@ function testf3(lb, ub, p; f = f)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)")
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)")
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3 #passes
@test_broken dp2 ≈ dp3 #passes
@test dp2 ≈ dp3 #passes

### Batch single dim, nout
f(x, p) = (x * p[1] .+ p[2] * p[3]) .* [1; 2]
f(x, p) = (x' * p[1] .+ p[2] * p[3]) .* [1; 2]

lb = 1.0
ub = 3.0
Expand All @@ -150,11 +150,11 @@ function testf3(lb, ub, p; f = f)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3 #passes
# @test dp2 ≈ dp3 #passes
@test dp2 ≈ dp3 #passes

### Batch multi dim
f(x, p) = x[1, :] * p[1] .+ p[2] * p[3]
Expand Down Expand Up @@ -190,15 +190,15 @@ function testf3(lb, ub, p; f = f)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3
# @test dp2 ≈ dp3
@test dp2 ≈ dp3

## iip Batch mulit dim
## iip Batch multi dim
function g(dx, x, p)
dx .= sum(x * p[1] .+ p[2] * p[3], dims = 1)
dx .= dropdims(sum(x * p[1] .+ p[2] * p[3], dims = 1), dims = 1)
end

lb = [1.0, 1.0]
Expand Down Expand Up @@ -236,8 +236,8 @@ function testf3(lb, ub, p; f = g)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3
# @test dp2 ≈ dp3
@test dp2 ≈ dp3
Loading

0 comments on commit 823046c

Please sign in to comment.