Skip to content

Commit

Permalink
Merge pull request #535 from lxvm/batchintegrand
Browse files Browse the repository at this point in the history
BatchIntegralFunction revisions
  • Loading branch information
ChrisRackauckas authored Nov 1, 2023
2 parents 6b0a385 + 10307a5 commit dee9a98
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
37 changes: 28 additions & 9 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,10 @@ which are `Number`s or `AbstractVector`s with the same geometry as `u`.
### Constructors
```
IntegralProblem(f,domain,p=NullParameters(); kwargs...)
IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...)
IntegralProblem(f::AbstractIntegralFunction,domain,p=NullParameters(); kwargs...)
IntegralProblem(f::AbstractIntegralFunction,lb,ub,p=NullParameters(); kwargs...)
IntegralProblem(f,domain,p=NullParameters(); nout=nothing, batch=nothing, kwargs...)
IntegralProblem(f,lb,ub,p=NullParameters(); nout=nothing, batch=nothing, kwargs...)
```
- f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an
Expand All @@ -424,6 +426,10 @@ IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...)
- lb: Either a number or vector of lower bounds.
- ub: Either a number or vector of upper bounds.
- p: The parameters associated with the problem.
- nout: DEPRECATED (see `IntegralFunction`): length of the vector output of the integrand
(by default the integrand is assumed to be scalar)
- batch: DEPRECATED (see `BatchIntegralFunction`): number of points the integrand can
evaluate simultaneously (by default there is no batching)
- kwargs: Keyword arguments copied to the solvers.
Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at
Expand Down Expand Up @@ -461,32 +467,45 @@ function IntegralProblem(f::AbstractIntegralFunction,
ub::B,
p = NullParameters();
kwargs...) where {B}
IntegralProblem(f, (lb, ub), p; kwargs...)
IntegralProblem{isinplace(f)}(f, (lb, ub), p; kwargs...)
end

function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...)
if nout !== nothing || batch !== nothing
@warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details."
end

max_batch = batch === nothing ? 0 : batch
g = if isinplace(f, 3)
output_prototype = Vector{Float64}(undef, nout === nothing ? 1 : nout)
if max_batch == 0
if batch === nothing
output_prototype = nout === nothing ? Array{Float64, 0}(undef) : Vector{Float64}(undef, nout)
IntegralFunction(f, output_prototype)
else
BatchIntegralFunction(f, output_prototype, max_batch=max_batch)
output_prototype = nout === nothing ? Float64[] : Matrix{Float64}(undef, nout, 0)
BatchIntegralFunction(f, output_prototype, max_batch=batch)
end
else
if max_batch == 0
if batch === nothing
IntegralFunction(f)
else
BatchIntegralFunction(f, max_batch=max_batch)
BatchIntegralFunction(f, max_batch=batch)
end
end
IntegralProblem(g, args...; kwargs...)
end

function Base.getproperty(prob::IntegralProblem, name::Symbol)
if name === :lb
domain = getfield(prob, :domain)
lb, ub = domain
return lb
elseif name === :ub
domain = getfield(prob, :domain)
lb, ub = domain
return ub
end
return Base.getfield(prob, name)
end

struct QuadratureProblem end
@deprecate QuadratureProblem(args...; kwargs...) IntegralProblem(args...; kwargs...)

Expand Down
17 changes: 9 additions & 8 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2361,8 +2361,8 @@ BatchIntegralFunction{iip,specialize}(f, [integrand_prototype];
max_batch=typemax(Int))
```
Note that only `f` is required, and in the case of inplace integrands a mutable container
`integrand_prototype` to store the result of the integrand of one integrand, without a last
"batching" dimension.
`integrand_prototype` to store a batch of integrand evaluations, with a last "batching"
dimension.
The keyword `max_batch` is used to set a soft limit on the number of points to batch at the
same time so that memory usage is controlled.
Expand All @@ -2375,12 +2375,13 @@ assumed to be out-of-place.
Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form
``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or
arrays), in-place functions must provide a container `integrand_prototype` of the right type
for a single integrand evaluation. The integration algorithm will then allocate a ``y``
array with the same element type as `integrand_prototype` and an additional last "batching"
dimension to store multiple integrand evaluations. In the out-of-place case, the algorithm
may infer the type of ``y`` by passing `f` an empty array of input points. This means ``y``
is a vector in the out-of-place case, or a matrix/array in the in-place case. The number of
batched points may vary between subsequent calls to `f`. When in-place forms are used,
for ``y``. The only assumption that is enforced is that the last axes of `the `y`` and ``u``
arrays are the same length and correspond to distinct batched points. The algorithm will
then allocate arrays `similar` to ``y`` to pass to the integrand. Since the algorithm may
vary the number of points to batch, the length of the batching dimension of ``y`` may vary
between subsequent calls to `f`. To reduce allocations, views of ``y`` may also be passed to
the integrand. In the out-of-place case, the algorithm may infer the type
of ``y`` by passing `f` an empty array of input points. When in-place forms are used,
in-place array operations may be used by algorithms to reduce allocations. If
`integrand_prototype` is not provided, `f` is assumed to be out-of-place.
Expand Down

0 comments on commit dee9a98

Please sign in to comment.