Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ForwardDiff directly on all non-C solvers #239

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The in-place interface allows evaluating vector-valued integrands without
allocating an output array. This can be beneficial for reducing allocations when
integrating many functions simultaneously or to make use of existing in-place
code. However, note that not all algorithms use in-place operations under the
hood, i.e. `HCubatureJL()`, and may still allocate.
hood, i.e. [`HCubatureJL`](@ref), and may still allocate.

You can construct an `IntegralFunction(f, prototype)`, where `f` is of the form
`f(y, u, p)` where `prototype` is of the desired type and shape of `y`.
Expand All @@ -22,16 +22,17 @@ different points, which maximizes the parallelism for a given algorithm.
You can construct an out-of-place `BatchIntegralFunction(bf)` where `bf` is of
the form `bf(u, p) = stack(x -> f(x, p), eachslice(u; dims=ndims(u)))`, where
`f` is the (unbatched) integrand.
For interoperability with as many algorithms as possible, it is important that your out-of-place batch integrand accept an **empty** array of quadrature points and still return an output with a size and type consistent with the non-empty case.

You can construct an in-place `BatchIntegralFunction(bf, prototype)`, where `bf`
is of the form `bf(y, u, p) = foreach((y,x) -> f(y,x,p), eachslice(y, dims=ndims(y)), eachslice(x, dims=ndims(x)))`.

Note that not all algorithms use in-place batched operations under the hood,
i.e. `QuadGKJL()`.
i.e. [`QuadGKJL`](@ref).

## What should I do if my solution is not converged?

Certain algorithms, such as `QuadratureRule` used a fixed number of points to
Certain algorithms, such as [`QuadratureRule`](@ref) used a fixed number of points to
calculate an integral and cannot provide an error estimate. In this case, you
have to increase the number of points and check the convergence yourself, which
will depend on the accuracy of the rule you choose.
Expand All @@ -47,7 +48,7 @@ precision arithmetic may help.

## How can I integrate arbitrarily-spaced data?

See `SampledIntegralProblem`.
See [`SampledIntegralProblem`](@ref).

## How can I integrate on arbitrary geometries?

Expand All @@ -59,6 +60,13 @@ because that is what lower-level packages implement.
Fixed quadrature rules from other packages can be used with `QuadratureRule`.
Otherwise, feel free to open an issue or pull request.

## My integrand works with algorithm X but fails on algorithm Y

While bugs are not out of the question, certain algorithms, especially those implemented in C, are not compatible with arbitrary Julia types and have to return specific numeric types or arrays thereof.
In some cases, such as [`ArblibJL`](@ref), it is also expected that the integrand work with a custom quadrature point type.
Moreover, some algorithms, such as [`VEGAS`](@ref), only support scalar integrands.
For more details see the [solver page](@ref solvers).

## Can I take derivatives with respect to the limits of integration?

Currently this is not implemented.
41 changes: 26 additions & 15 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,37 @@ using Integrals
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
### Forward-Mode AD Intercepts

#= Direct AD on solvers with QuadGK and HCubature
# incompatible with iip since types must change
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
# Default to direct AD on solvers
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
if isinplace(cache.f)
prototype = cache.f.integrand_prototype
elt = eltype(prototype)
ForwardDiff.can_dual(elt) || throw(ArgumentError("ForwardDiff of in-place integrands only supports prototypes with real elements"))
dprototype = similar(prototype, replace_dualvaltype(D, elt))
df = if cache.f isa BatchIntegralFunction
BatchIntegralFunction{true}(cache.f.f, dprototype)
else
IntegralFunction{true}(cache.f.f, dprototype)
end
prob = Integrals.build_problem(cache)
dprob = remake(prob, f = df)
dcache = init(dprob, alg; sensealg = sensealg, do_inf_transformation=Val(false), kwargs...)
Integrals.__solvebp_call(dcache, alg, sensealg, domain, p; kwargs...)
else
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
end
=#


# TODO: add the pushforward for derivative w.r.t lb, and ub (and then combinations?)

# Manually split for the pushforward
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::Union{D, AbstractArray{<:D}};
kwargs...) where {T, V, P, D <: ForwardDiff.Dual{T, V, P}}
function Integrals.__solvebp(cache, alg::Integrals.AbstractIntegralCExtensionAlgorithm, sensealg, domain,
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

# we need the output type to avoid perturbation confusion while unwrapping nested duals
# We compute a vector-valued integral of the primal and dual simultaneously
Expand Down Expand Up @@ -73,6 +83,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
end
end

DT <: Real || throw(ArgumentError("differentiating algorithms in C"))
ForwardDiff.can_dual(elt) || ForwardDiff.throw_cannot_dual(elt)
rawp = p isa D ? reinterpret(V, [p]) : copy(reinterpret(V, vec(p)))

Expand Down
7 changes: 4 additions & 3 deletions src/algorithms_extension.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
## Extension Algorithms

abstract type AbstractIntegralExtensionAlgorithm <: SciMLBase.AbstractIntegralAlgorithm end
abstract type AbstractIntegralCExtensionAlgorithm <: AbstractIntegralExtensionAlgorithm end

abstract type AbstractCubaAlgorithm <: AbstractIntegralExtensionAlgorithm end
abstract type AbstractCubaAlgorithm <: AbstractIntegralCExtensionAlgorithm end

"""
CubaVegas()
Expand Down Expand Up @@ -152,7 +153,7 @@ function CubaCuhre(; flags = 0, minevals = 0, key = 0)
return CubaCuhre(flags, minevals, key)
end

abstract type AbstractCubatureJLAlgorithm <: AbstractIntegralExtensionAlgorithm end
abstract type AbstractCubatureJLAlgorithm <: AbstractIntegralCExtensionAlgorithm end

"""
CubatureJLh(; error_norm=Cubature.INDIVIDUAL)
Expand Down Expand Up @@ -219,7 +220,7 @@ documentation for additional details the algorithm arguments and on implementing
high-precision integrands. Additionally, the error estimate is included in the return value
of the integral, representing a ball.
"""
struct ArblibJL{O} <: AbstractIntegralExtensionAlgorithm
struct ArblibJL{O} <: AbstractIntegralCExtensionAlgorithm
check_analytic::Bool
take_prec::Bool
warn_on_no_convergence::Bool
Expand Down
Loading