From 5be7d7afbca32c2b84a688389c311d9efc83fed1 Mon Sep 17 00:00:00 2001 From: lxvm Date: Thu, 21 Sep 2023 14:35:12 -0400 Subject: [PATCH] remove output_prototype --- src/scimlfunctions.jl | 78 ++++++++++++------------ test/function_building_error_messages.jl | 22 ++++--- 2 files changed, 52 insertions(+), 48 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c9aef6d0a..71da8be5a 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2283,7 +2283,7 @@ end TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2 @doc doc""" - IntegralFunction{iip,specialize,F} <: AbstractIntegralFunction{iip} + IntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} A representation of an integrand `f` defined by: @@ -2304,7 +2304,7 @@ present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out- ## iip: In-Place vs Out-Of-Place -Out-of-place functions must be of the form ``f(u, p)`` and in-place functions of the form +Out-of-place functions must be of the form ``y = f(u, p)`` and in-place functions of the form ``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or arrays), in-place functions must provide a container `integrand_prototype` that is of the right type for the variable ``y``, and the result is written to this container in-place. @@ -2329,30 +2329,30 @@ end TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2 @doc doc""" -BatchIntegralFunction{iip,specialize,F,T,Y,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <: -AbstractIntegralFunction{iip} +BatchIntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} A representation of an integrand `f` that can be evaluated at multiple points simultaneously using threads, the gpu, or distributed memory defined by: ```math -f(y, u, p) +y = f(u, p) ``` ``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose -output must be returned in the corresponding entries of ``y``. In general, the integration -algorithm is allowed to vary the number of evaluation points between subsequent calls to `f` +output must be returned as an array whose last "batching" dimension corresponds to integrand +evaluations at the different points in ``u``. In general, the integration algorithm is +allowed to vary the number of evaluation points between subsequent calls to `f`. + +For an in-place form of `f` see the `iip` section below for details on in-place or +out-of-place handling. ```julia -BatchIntegralFunction{iip,specialize}(f, output_prototype, [integrand_prototype]; +BatchIntegralFunction{iip,specialize}(f, [integrand_prototype]; max_batch=typemax(Int)) ``` - -Note that `f` is required and a `resize`-able buffer `output_prototype` to store the output, -or range of `f`, consisting of multiple integrand evaluations, and in the case of inplace -integrands a mutable container `integrand_prototype` to store the result of one integrand -evaluation. These buffers can be reused across multiple compatible integrals to reduce -allocations. +Note that only `f` is required, and in the case of inplace integrands a mutable container +`integrand_prototype` to store the result of the integrand of one integrand, without a last +"batching" dimension. The keyword `max_batch` is used to set a soft limit on the number of points to batch at the same time so that memory usage is controlled. @@ -2362,17 +2362,17 @@ assumed to be out-of-place. ## iip: In-Place vs Out-Of-Place -Out-of-place and in-place functions are both of the form ``f(y, u, p)``, but differ in the -type of ``y``. Since `f` is allowed to return any type (e.g. real or complex numbers or -arrays), in-place functions must provide a container `output_prototype` of the right type -for ``y`` that stores multiple integrand evaluations. Typically, this means -`output_prototype` is a vector in the out-of-place case, or a matrix/array whose last -dimension is a batch index in the in-place case. For the in-place case, an -`integrand_prototype` is required and must be a container of the right type for a single -integrand evaluation. When `f` is in-place, the buffer `output_prototype` should be of the -type used by the integrand. When in-place forms are used, in-place array operations may be -used by algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is -assumed to be out-of-place. +Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form +``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +arrays), in-place functions must provide a container `integrand_prototype` of the right type +for a single integrand evaluation. The integration algorithm will then allocate a ``y`` +array with the same element type as `integrand_prototype` and an additional last "batching" +dimension to store multiple integrand evaluations. In the out-of-place case, the algorithm +may infer the type of ``y`` by passing `f` an empty array of input points. This means ``y`` +is a vector in the out-of-place case, or a matrix/array in the in-place case. The number of +batched points may vary between subsequent calls to `f`. When in-place forms are used, +in-place array operations may be used by algorithms to reduce allocations. If +`integrand_prototype` is not provided, `f` is assumed to be out-of-place. ## specialize @@ -2382,10 +2382,9 @@ This field is currently unused The fields of the BatchIntegralFunction type directly match the names of the inputs. """ -struct BatchIntegralFunction{iip, specialize, F, Y, T} <: +struct BatchIntegralFunction{iip, specialize, F, T} <: AbstractIntegralFunction{iip} f::F - output_prototype::Y integrand_prototype::T max_batch::Int end @@ -4111,36 +4110,39 @@ function IntegralFunction(f, integrand_prototype) IntegralFunction{true}(f, integrand_prototype) end -function BatchIntegralFunction{iip, specialize}(f, output_prototype, integrand_prototype; +function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; max_batch::Integer = typemax(Int)) where {iip, specialize} BatchIntegralFunction{ iip, specialize, typeof(f), - typeof(output_prototype), typeof(integrand_prototype), }(f, - output_prototype, integrand_prototype, max_batch) end function BatchIntegralFunction{iip}(f, - output_prototype, integrand_prototype; kwargs...) where {iip} return BatchIntegralFunction{iip, FullSpecialize}(f, - output_prototype, integrand_prototype; kwargs...) end -function BatchIntegralFunction(f, output_prototype; kwargs...) - calcuated_iip = isinplace(f, 3, "batchintegral", true; has_two_dispatches = false) - BatchIntegralFunction{false}(f, output_prototype, nothing; kwargs...) + +function BatchIntegralFunction(f; kwargs...) + calculated_iip = isinplace(f, 3, "batchintegral", true) + if calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, false)) + end + BatchIntegralFunction{false}(f, nothing; kwargs...) end -function BatchIntegralFunction(f, output_prototype, integrand_prototype; kwargs...) - calcuated_iip = isinplace(f, 3, "batchintegral", true; has_two_dispatches = false) - BatchIntegralFunction{true}(f, output_prototype, integrand_prototype; kwargs...) +function BatchIntegralFunction(f, integrand_prototype; kwargs...) + calculated_iip = isinplace(f, 3, "batchintegral", true) + if !calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, true)) + end + BatchIntegralFunction{true}(f, integrand_prototype; kwargs...) end ########## Existence Functions diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 008758ddb..52b2054cc 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -620,16 +620,18 @@ IntegralFunction(iiip, Float64[]) # BatchIntegralFunction -boop(y, u, p) = y .= p .* u -biip(y, u, p) = y .= p .* u # this example is not realistic -bi1(y, u) = y .= p .* u +boop(u, p) = p .* u +biip(y, u, p) = y .= p .* u +bi1(u) = u bitoo(y, u, p, a) = y .= p .* u -BatchIntegralFunction(boop, Float64[]) -BatchIntegralFunction(boop, Float64[], max_batch = 20) -BatchIntegralFunction(biip, Float64[], Float64[]) # the 2nd argument should be an ElasticArray -@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1, Float64[]) -@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo, - Float64[], - Float64[]) +BatchIntegralFunction(boop) +BatchIntegralFunction(boop, max_batch = 20) +BatchIntegralFunction(biip, Float64[]) +BatchIntegralFunction(biip, Float64[], max_batch = 20) + +@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[]) +@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip) +@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1) +@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo) @test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo, Float64[])