Skip to content

Commit

Permalink
fix forwarddiff bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Jan 1, 2024
1 parent bca4a16 commit 22dabde
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 57 deletions.
26 changes: 15 additions & 11 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ using Integrals
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
### Forward-Mode AD Intercepts

# Direct AD on solvers with QuadGK and HCubature
#= Direct AD on solvers with QuadGK and HCubature
# incompatible with iip since types must change
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Expand All @@ -15,11 +16,14 @@ function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, domain,
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
=#

# TODO: add the pushforward for derivative w.r.t lb, and ub (and then combinations?)

# Manually split for the pushforward
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

# we need the output type to avoid perturbation confusion while unwrapping nested duals
# We compute a vector-valued integral of the primal and dual simultaneously
Expand All @@ -32,11 +36,11 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
len,
size(cache.f.integrand_prototype)...)

dfdp_ = function (out, x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
dfdp_ = function (out, x, _p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, _p)
dout = reinterpret(reshape, DT, out)
cache.f(dout, x, dualp)
return out
cache.f(dout, x, p isa D ? only(dualp) : reshape(dualp, size(p)))
return
end
dfdp = if cache.f isa BatchIntegralFunction
BatchIntegralFunction{true}(dfdp_, dual_prototype)
Expand All @@ -55,9 +59,9 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
DT = y isa AbstractArray ? eltype(y) : typeof(y)
elt = unwrap_dualvaltype(DT)

dfdp_ = function (x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
ys = cache.f(x, dualp)
dfdp_ = function (x, _p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, _p)
ys = cache.f(x, p isa D ? only(dualp) : reshape(dualp, size(p)))
ys_ = ys isa AbstractArray ? ys : [ys]
# we need to reshape in order for batching to be consistent
return reinterpret(reshape, elt, ys_)
Expand All @@ -70,7 +74,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
end

ForwardDiff.can_dual(elt) || ForwardDiff.throw_cannot_dual(elt)
rawp = copy(reinterpret(V, p))
rawp = p isa D ? reinterpret(V, [p]) : copy(reinterpret(V, vec(p)))

prob = Integrals.build_problem(cache)
dp_prob = remake(prob, f = dfdp, p = rawp)
Expand Down
120 changes: 74 additions & 46 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ alg_req = Dict(
max_dim = Inf, allows_iip = true),
# VEGAS() => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf,
# allows_iip = true),
# CubatureJLh() => (nout = Inf, allows_batch = true, min_dim = 1,
# max_dim = Inf, allows_iip = true),
# CubatureJLp() => (nout = Inf, allows_batch = true, min_dim = 1,
# max_dim = Inf, allows_iip = true),
CubatureJLh() => (nout = Inf, allows_batch = true, min_dim = 1,
max_dim = Inf, allows_iip = true),
CubatureJLp() => (nout = Inf, allows_batch = true, min_dim = 1,
max_dim = Inf, allows_iip = true),
# CubaVegas() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
# allows_iip = true),
# CubaSUAVE() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
# allows_iip = true),
# CubaDivonne() => (nout = Inf, allows_batch = true, min_dim = 2,
# max_dim = Inf, allows_iip = true),
# CubaCuhre() => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf,
# allows_iip = true),
CubaSUAVE() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
allows_iip = true),
CubaDivonne() => (nout = Inf, allows_batch = true, min_dim = 2,
max_dim = Inf, allows_iip = true),
CubaCuhre() => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf,
allows_iip = true),
)
# helper function / test runner
scalarize_solution = (
Expand All @@ -41,7 +41,7 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
end
testf(lb, ub, p)

dlb1, dub1, dp1 = Zygote.gradient(testf, lb, ub, p)
# dlb1, dub1, dp1 = Zygote.gradient(testf, lb, ub, p)

f_lb = lb -> testf(lb, ub, p)
f_ub = ub -> testf(lb, ub, p)
Expand All @@ -52,13 +52,13 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
dlb2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dlb))(f_lb, lb)
dub2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dub))(f_ub, ub)

if lb isa Number
@test dlb1 dlb2 atol=abstol rtol=reltol
@test dub1 dub2 atol=abstol rtol=reltol
else # TODO: implement multivariate limit derivatives in ZygoteExt
@test_broken dlb1 dlb2 atol=abstol rtol=reltol
@test_broken dub1 dub2 atol=abstol rtol=reltol
end
# if lb isa Number
# @test dlb1 ≈ dlb2 atol=abstol rtol=reltol
# @test dub1 ≈ dub2 atol=abstol rtol=reltol
# else # TODO: implement multivariate limit derivatives in ZygoteExt
# @test_broken dlb1 ≈ dlb2 atol=abstol rtol=reltol
# @test_broken dub1 ≈ dub2 atol=abstol rtol=reltol
# end

# TODO: implement limit derivatives in ForwardDiffExt
@test_broken dlb2 getproperty(ForwardDiff, dlb)(dfdlb, lb) atol=abstol rtol=reltol
Expand All @@ -71,14 +71,49 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol)
dp2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dp))(f_p, p)
dp3 = getproperty(ForwardDiff, dp)(f_p, p)

@test dp1 dp2 atol=abstol rtol=reltol
@test dp1 dp3 atol=abstol rtol=reltol
# @test dp1 ≈ dp2 atol=abstol rtol=reltol
@test dp2 dp3 atol=abstol rtol=reltol

return
end

### One Dimensional
f_1d_scalar = (x, p) -> sum(q -> sin(q*x), p)
f_1d_nout = (x, p) -> map(q -> q*x, p)
f_nd_scalar = (x, p) -> prod(y -> f_1d_scalar(y, p), x)
f_nd_nout = (x, p) -> mapreduce(y -> f_1d_nout(y, p), +, x)

f_1d_scalar_iip = (y, x, p) -> y .= f_1d_scalar(x, p)
f_1d_nout_iip = (y, x, p) -> y .= f_1d_nout(x, p)
f_nd_scalar_iip = (y, x, p) -> y .= f_nd_scalar(x, p)
f_nd_nout_iip = (y, x, p) -> y .= f_nd_nout(x, p)

bf_helper = (f, x, p) -> begin
elt = typeof(zero(eltype(x))*zero(eltype(p))) # output type of above functions
if p isa AbstractArray
# p and f_*_nout are of size nout
# this is like a call to stack that should also work for empty arrays
out = similar(p, elt, size(p)..., size(x, ndims(x)))
for (v,y) in zip(eachslice(out; dims=ndims(out)), eachslice(x; dims=ndims(x)))
v .= f(x isa AbstractVector ? only(y) : y, p)
end
out
else
elt[f(x isa AbstractVector ? only(y) : y, p) for y in eachslice(x; dims=ndims(x))]
end
end

bf_1d_scalar = (x, p) -> bf_helper(f_1d_scalar, x, p)
bf_1d_nout = (x, p) -> bf_helper(f_1d_nout, x, p)
bf_nd_scalar = (x, p) -> bf_helper(f_nd_scalar, x, p)
bf_nd_nout = (x, p) -> bf_helper(f_nd_nout, x, p)

bf_1d_nout_iip = (y, x, p) -> y .= bf_1d_nout(x, p)
bf_1d_scalar_iip = (y, x, p) -> y .= bf_1d_scalar(x, p)
bf_nd_scalar_iip = (y, x, p) -> y .= bf_nd_scalar(x, p)
bf_nd_nout_iip = (y, x, p) -> y .= bf_nd_nout(x, p)


### One Dimensional
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution)
req.nout > 1 || continue
req.min_dim > 0 || continue
Expand All @@ -88,7 +123,6 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

## One-dimensional nout
f_1d_nout = (x, p) -> map(q -> q*x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), nout in 1:max_nout_test
req.nout > 1 || continue
req.min_dim > 0 || continue
Expand All @@ -98,17 +132,15 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

### N-dimensional
f_nd_scalar = (x, p) -> prod(y -> f_1d_scalar(y, p), x)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

@info "Multi-dimensional, scalar, oop derivative test" alg scalarize=i, dim
@info "Multi-dimensional, scalar, oop derivative test" alg scalarize=i dim
do_tests(; f=f_nd_scalar, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol)
end

### N-dimensional nout
f_nd_nout = (x, p) -> mapreduce(y -> f_1d_nout(y, p), +, x)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test, nout in 1:max_nout_test
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue
Expand All @@ -118,7 +150,6 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

### One Dimensional
f_1d_scalar_iip = (y, x, p) -> y .= f_1d_scalar(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution)
req.nout > 1 || continue
req.min_dim > 0 || continue
Expand All @@ -128,38 +159,35 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

## One-dimensional nout
f_1d_nout_iip = (y, x, p) -> y .= f_1d_nout(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), nout in 1:max_nout_test
req.nout > 1 || continue
req.min_dim > 0 || continue

@info "One-dimensional, multivariate, iip derivative test" alg scalarize=i nout
do_tests(; f=IntegralFunction(f_1d_nout, zeros(nout)), scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol)
do_tests(; f=IntegralFunction(f_1d_nout_iip, zeros(nout)), scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol)
end

### N-dimensional
f_nd_scalar_iip = (y, x, p) -> y .= f_nd_scalar(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

@info "Multi-dimensional, scalar, iip derivative test" alg scalarize=i, dim
@info "Multi-dimensional, scalar, iip derivative test" alg scalarize=i dim
do_tests(; f=IntegralFunction(f_nd_scalar_iip, zeros(1)), scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol)
end

### N-dimensional nout iip
f_nd_nout_iip = (y, x, p) -> y .= f_nd_nout(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test, nout in 1:max_nout_test
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

@info "Multi-dimensional, multivariate, iip derivative test" alg scalarize=i dim nout
do_tests(; f=IntegralFunction(f_nd_nout, zeros(nout)), scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol)
do_tests(; f=IntegralFunction(f_nd_nout_iip, zeros(nout)), scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol)
end

### Batch, One Dimensional
bf_1d_scalar = (x, p) -> map(y -> f_1d_scalar(y, p), x)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution)
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim > 0 || continue

Expand All @@ -168,8 +196,8 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

## Batch, One-dimensional nout
bf_1d_nout = (x, p) -> stack(y -> f_1d_nout(y, p), x)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), nout in 1:max_nout_test
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim > 0 || continue

Expand All @@ -178,18 +206,18 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

### Batch, N-dimensional
bf_nd_scalar = (x, p) -> map(y -> f_nd_scalar(y, p), eachslice(x; dims=ndims(x)))
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

@info "Batched, multi-dimensional, scalar, oop derivative test" alg scalarize=i, dim
@info "Batched, multi-dimensional, scalar, oop derivative test" alg scalarize=i dim
do_tests(; f=BatchIntegralFunction(bf_nd_scalar), scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol)
end

### Batch, N-dimensional nout
bf_nd_nout = (x, p) -> stack(y -> f_nd_nout(y, p), eachslice(x; dims=ndims(x)))
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test, nout in 1:max_nout_test
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

Expand All @@ -198,41 +226,41 @@ for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution
end

### Batch, one-dimensional
bf_1d_scalar_iip = (y, x, p) -> y .= bf_1d_scalar(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution)
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim > 0 || continue

@info "Batched, one-dimensional, scalar, iip derivative test" alg scalarize=i
do_tests(; f=BatchIntegralFunction(bf_1d_scalar_iip, zeros(1, 0)), scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol)
do_tests(; f=BatchIntegralFunction(bf_1d_scalar_iip, zeros(0)), scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol)
end

## Batch, one-dimensional nout
bf_1d_nout_iip = (y, x, p) -> y .= bf_1d_nout(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), nout in 1:max_nout_test
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim > 0 || continue

@info "Batched, one-dimensional, multivariate, iip derivative test" alg scalarize=i nout
do_tests(; f=BatchIntegralFunction(bf_1d_nout, zeros(nout, 0)), scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol)
do_tests(; f=BatchIntegralFunction(bf_1d_nout_iip, zeros(nout, 0)), scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol)
end

### Batch, N-dimensional
bf_nd_scalar_iip = (y, x, p) -> y .= bf_nd_scalar(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

@info "Batched, multi-dimensional, scalar, iip derivative test" alg scalarize=i, dim
do_tests(; f=BatchIntegralFunction(bf_nd_scalar_iip, zeros(1, 0)), scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol)
@info "Batched, multi-dimensional, scalar, iip derivative test" alg scalarize=i dim
do_tests(; f=BatchIntegralFunction(bf_nd_scalar_iip, zeros(0)), scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol)
end

### Batch, N-dimensional nout iip
bf_nd_nout_iip = (y, x, p) -> y .= bf_nd_nout(x, p)
for (alg, req) in pairs(alg_req), (i, scalarize) in enumerate(scalarize_solution), dim in 1:max_dim_test, nout in 1:max_nout_test
req.allows_batch || continue
req.nout > 1 || continue
req.min_dim <= dim <= req.max_dim || continue

@info "Batched, multi-dimensional, multivariate, iip derivative test" alg scalarize=i dim nout
do_tests(; f=BatchIntegralFunction(bf_nd_nout, zeros(nout, 0)), scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol)
do_tests(; f=BatchIntegralFunction(bf_nd_nout_iip, zeros(nout, 0)), scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol)
end

0 comments on commit 22dabde

Please sign in to comment.