Skip to content

Commit

Permalink
remove output_prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Sep 21, 2023
1 parent e27965d commit 5be7d7a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 48 deletions.
78 changes: 40 additions & 38 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2283,7 +2283,7 @@ end
TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2

@doc doc"""
IntegralFunction{iip,specialize,F} <: AbstractIntegralFunction{iip}
IntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip}
A representation of an integrand `f` defined by:
Expand All @@ -2304,7 +2304,7 @@ present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-
## iip: In-Place vs Out-Of-Place
Out-of-place functions must be of the form ``f(u, p)`` and in-place functions of the form
Out-of-place functions must be of the form ``y = f(u, p)`` and in-place functions of the form
``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or
arrays), in-place functions must provide a container `integrand_prototype` that is of the
right type for the variable ``y``, and the result is written to this container in-place.
Expand All @@ -2329,30 +2329,30 @@ end
TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2

@doc doc"""
BatchIntegralFunction{iip,specialize,F,T,Y,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <:
AbstractIntegralFunction{iip}
BatchIntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip}
A representation of an integrand `f` that can be evaluated at multiple points simultaneously
using threads, the gpu, or distributed memory defined by:
```math
f(y, u, p)
y = f(u, p)
```
``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose
output must be returned in the corresponding entries of ``y``. In general, the integration
algorithm is allowed to vary the number of evaluation points between subsequent calls to `f`
output must be returned as an array whose last "batching" dimension corresponds to integrand
evaluations at the different points in ``u``. In general, the integration algorithm is
allowed to vary the number of evaluation points between subsequent calls to `f`.
For an in-place form of `f` see the `iip` section below for details on in-place or
out-of-place handling.
```julia
BatchIntegralFunction{iip,specialize}(f, output_prototype, [integrand_prototype];
BatchIntegralFunction{iip,specialize}(f, [integrand_prototype];
max_batch=typemax(Int))
```
Note that `f` is required and a `resize`-able buffer `output_prototype` to store the output,
or range of `f`, consisting of multiple integrand evaluations, and in the case of inplace
integrands a mutable container `integrand_prototype` to store the result of one integrand
evaluation. These buffers can be reused across multiple compatible integrals to reduce
allocations.
Note that only `f` is required, and in the case of inplace integrands a mutable container
`integrand_prototype` to store the result of the integrand of one integrand, without a last
"batching" dimension.
The keyword `max_batch` is used to set a soft limit on the number of points to batch at the
same time so that memory usage is controlled.
Expand All @@ -2362,17 +2362,17 @@ assumed to be out-of-place.
## iip: In-Place vs Out-Of-Place
Out-of-place and in-place functions are both of the form ``f(y, u, p)``, but differ in the
type of ``y``. Since `f` is allowed to return any type (e.g. real or complex numbers or
arrays), in-place functions must provide a container `output_prototype` of the right type
for ``y`` that stores multiple integrand evaluations. Typically, this means
`output_prototype` is a vector in the out-of-place case, or a matrix/array whose last
dimension is a batch index in the in-place case. For the in-place case, an
`integrand_prototype` is required and must be a container of the right type for a single
integrand evaluation. When `f` is in-place, the buffer `output_prototype` should be of the
type used by the integrand. When in-place forms are used, in-place array operations may be
used by algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is
assumed to be out-of-place.
Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form
``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or
arrays), in-place functions must provide a container `integrand_prototype` of the right type
for a single integrand evaluation. The integration algorithm will then allocate a ``y``
array with the same element type as `integrand_prototype` and an additional last "batching"
dimension to store multiple integrand evaluations. In the out-of-place case, the algorithm
may infer the type of ``y`` by passing `f` an empty array of input points. This means ``y``
is a vector in the out-of-place case, or a matrix/array in the in-place case. The number of
batched points may vary between subsequent calls to `f`. When in-place forms are used,
in-place array operations may be used by algorithms to reduce allocations. If
`integrand_prototype` is not provided, `f` is assumed to be out-of-place.
## specialize
Expand All @@ -2382,10 +2382,9 @@ This field is currently unused
The fields of the BatchIntegralFunction type directly match the names of the inputs.
"""
struct BatchIntegralFunction{iip, specialize, F, Y, T} <:
struct BatchIntegralFunction{iip, specialize, F, T} <:
AbstractIntegralFunction{iip}
f::F
output_prototype::Y
integrand_prototype::T
max_batch::Int
end
Expand Down Expand Up @@ -4111,36 +4110,39 @@ function IntegralFunction(f, integrand_prototype)
IntegralFunction{true}(f, integrand_prototype)
end

function BatchIntegralFunction{iip, specialize}(f, output_prototype, integrand_prototype;
function BatchIntegralFunction{iip, specialize}(f, integrand_prototype;
max_batch::Integer = typemax(Int)) where {iip, specialize}
BatchIntegralFunction{
iip,
specialize,
typeof(f),
typeof(output_prototype),
typeof(integrand_prototype),
}(f,
output_prototype,
integrand_prototype,
max_batch)
end

function BatchIntegralFunction{iip}(f,
output_prototype,
integrand_prototype;
kwargs...) where {iip}
return BatchIntegralFunction{iip, FullSpecialize}(f,
output_prototype,
integrand_prototype;
kwargs...)
end
function BatchIntegralFunction(f, output_prototype; kwargs...)
calcuated_iip = isinplace(f, 3, "batchintegral", true; has_two_dispatches = false)
BatchIntegralFunction{false}(f, output_prototype, nothing; kwargs...)

function BatchIntegralFunction(f; kwargs...)
calculated_iip = isinplace(f, 3, "batchintegral", true)
if calculated_iip
throw(IntegrandMismatchFunctionError(calculated_iip, false))
end
BatchIntegralFunction{false}(f, nothing; kwargs...)
end
function BatchIntegralFunction(f, output_prototype, integrand_prototype; kwargs...)
calcuated_iip = isinplace(f, 3, "batchintegral", true; has_two_dispatches = false)
BatchIntegralFunction{true}(f, output_prototype, integrand_prototype; kwargs...)
function BatchIntegralFunction(f, integrand_prototype; kwargs...)
calculated_iip = isinplace(f, 3, "batchintegral", true)
if !calculated_iip
throw(IntegrandMismatchFunctionError(calculated_iip, true))
end
BatchIntegralFunction{true}(f, integrand_prototype; kwargs...)
end

########## Existence Functions
Expand Down
22 changes: 12 additions & 10 deletions test/function_building_error_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -620,16 +620,18 @@ IntegralFunction(iiip, Float64[])

# BatchIntegralFunction

boop(y, u, p) = y .= p .* u
biip(y, u, p) = y .= p .* u # this example is not realistic
bi1(y, u) = y .= p .* u
boop(u, p) = p .* u
biip(y, u, p) = y .= p .* u
bi1(u) = u
bitoo(y, u, p, a) = y .= p .* u

BatchIntegralFunction(boop, Float64[])
BatchIntegralFunction(boop, Float64[], max_batch = 20)
BatchIntegralFunction(biip, Float64[], Float64[]) # the 2nd argument should be an ElasticArray
@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1, Float64[])
@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo,
Float64[],
Float64[])
BatchIntegralFunction(boop)
BatchIntegralFunction(boop, max_batch = 20)
BatchIntegralFunction(biip, Float64[])
BatchIntegralFunction(biip, Float64[], max_batch = 20)

@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[])
@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip)
@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1)
@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo)
@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo, Float64[])

0 comments on commit 5be7d7a

Please sign in to comment.