Skip to content

Commit

Permalink
Merge pull request #636 from lxvm/proto
Browse files Browse the repository at this point in the history
allow `integrand_prototype` for oop integral functions
  • Loading branch information
ChrisRackauckas authored Feb 27, 2024
2 parents 14c0b32 + 263f910 commit 400f292
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
24 changes: 10 additions & 14 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2056,9 +2056,10 @@ out-of-place handling.
IntegralFunction{iip,specialize}(f, [integrand_prototype])
```
Note that only `f` is required, and in the case of inplace integrands a mutable container
Note that only `f` is required, and in the case of inplace integrands a mutable array
`integrand_prototype` to store the result of the integrand. If `integrand_prototype` is
present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place.
present for either in-place or out-of-place integrands it is used to infer the return type
of the integrand.
## iip: In-Place vs Out-Of-Place
Expand Down Expand Up @@ -2114,14 +2115,14 @@ BatchIntegralFunction{iip,specialize}(bf, [integrand_prototype];
max_batch=typemax(Int))
```
Note that only `bf` is required, and in the case of inplace integrands a mutable
container `integrand_prototype` to store a batch of integrand evaluations, with
array `integrand_prototype` to store a batch of integrand evaluations, with
a last "batching" dimension.
The keyword `max_batch` is used to set a soft limit on the number of points to
batch at the same time so that memory usage is controlled.
If `integrand_prototype` is present, `bf` is interpreted as in-place, and
otherwise `bf` is assumed to be out-of-place.
If `integrand_prototype` is present for either in-place or out-of-place integrands it is
used to infer the return type of the integrand.
## iip: In-Place vs Out-Of-Place
Expand Down Expand Up @@ -3158,7 +3159,8 @@ function DAEFunction{iip, specialize}(f;
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {iip,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {
iip,
specialize
}
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
Expand Down Expand Up @@ -3854,10 +3856,7 @@ function IntegralFunction(f)
end
function IntegralFunction(f, integrand_prototype)
calculated_iip = isinplace(f, 3, "integral", true)
if !calculated_iip
throw(IntegrandMismatchFunctionError(calculated_iip, true))
end
IntegralFunction{true}(f, integrand_prototype)
IntegralFunction{calculated_iip}(f, integrand_prototype)
end

function BatchIntegralFunction{iip, specialize}(f, integrand_prototype;
Expand Down Expand Up @@ -3890,10 +3889,7 @@ function BatchIntegralFunction(f; kwargs...)
end
function BatchIntegralFunction(f, integrand_prototype; kwargs...)
calculated_iip = isinplace(f, 3, "batchintegral", true)
if !calculated_iip
throw(IntegrandMismatchFunctionError(calculated_iip, true))
end
BatchIntegralFunction{true}(f, integrand_prototype; kwargs...)
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
end

########## Utility functions
Expand Down
6 changes: 4 additions & 2 deletions test/function_building_error_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ intfiip(y, u, p) = y .= 1.0
for (f, kws, iip) in (
(intf, (;), false),
(IntegralFunction(intf), (;), false),
(IntegralFunction(intf, 1.0), (;), false),
(intfiip, (; nout = 3), true),
(IntegralFunction(intfiip, zeros(3)), (;), true)
), domain in (((0.0, 1.0),), (([0.0], [1.0]),), (0.0, 1.0), ([0.0], [1.0]))
Expand Down Expand Up @@ -648,9 +649,9 @@ i1(u) = u
itoo(y, u, p, a) = y .= u * p

IntegralFunction(ioop)
IntegralFunction(ioop, 0.0)
IntegralFunction(iiip, Float64[])

@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[])
@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(iiip)
@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1)
@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo)
Expand All @@ -665,10 +666,11 @@ bitoo(y, u, p, a) = y .= p .* u

BatchIntegralFunction(boop)
BatchIntegralFunction(boop, max_batch = 20)
BatchIntegralFunction(boop, Float64[])
BatchIntegralFunction(boop, Float64[], max_batch = 20)
BatchIntegralFunction(biip, Float64[])
BatchIntegralFunction(biip, Float64[], max_batch = 20)

@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[])
@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip)
@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1)
@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo)
Expand Down

0 comments on commit 400f292

Please sign in to comment.