diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index b30eca7af..41d2d911d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2088,9 +2088,10 @@ out-of-place handling. IntegralFunction{iip,specialize}(f, [integrand_prototype]) ``` -Note that only `f` is required, and in the case of inplace integrands a mutable container +Note that only `f` is required, and in the case of inplace integrands a mutable array `integrand_prototype` to store the result of the integrand. If `integrand_prototype` is -present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place. +present for either in-place or out-of-place integrands it is used to infer the return type +of the integrand. ## iip: In-Place vs Out-Of-Place @@ -2148,14 +2149,14 @@ BatchIntegralFunction{iip,specialize}(bf, [integrand_prototype]; max_batch=typemax(Int)) ``` Note that only `bf` is required, and in the case of inplace integrands a mutable -container `integrand_prototype` to store a batch of integrand evaluations, with +array `integrand_prototype` to store a batch of integrand evaluations, with a last "batching" dimension. The keyword `max_batch` is used to set a soft limit on the number of points to batch at the same time so that memory usage is controlled. -If `integrand_prototype` is present, `bf` is interpreted as in-place, and -otherwise `bf` is assumed to be out-of-place. +If `integrand_prototype` is present for either in-place or out-of-place integrands it is +used to infer the return type of the integrand. ## iip: In-Place vs Out-Of-Place @@ -3194,7 +3195,8 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {iip, + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where { + iip, specialize } if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) @@ -3890,10 +3892,7 @@ function IntegralFunction(f) end function IntegralFunction(f, integrand_prototype) calculated_iip = isinplace(f, 3, "integral", true) - if !calculated_iip - throw(IntegrandMismatchFunctionError(calculated_iip, true)) - end - IntegralFunction{true}(f, integrand_prototype) + IntegralFunction{calculated_iip}(f, integrand_prototype) end function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; @@ -3926,10 +3925,7 @@ function BatchIntegralFunction(f; kwargs...) end function BatchIntegralFunction(f, integrand_prototype; kwargs...) calculated_iip = isinplace(f, 3, "batchintegral", true) - if !calculated_iip - throw(IntegrandMismatchFunctionError(calculated_iip, true)) - end - BatchIntegralFunction{true}(f, integrand_prototype; kwargs...) + BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...) end ########## Utility functions diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 7f0ba65ab..daf1c8565 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -474,6 +474,7 @@ intfiip(y, u, p) = y .= 1.0 for (f, kws, iip) in ( (intf, (;), false), (IntegralFunction(intf), (;), false), + (IntegralFunction(intf, 1.0), (;), false), (intfiip, (; nout = 3), true), (IntegralFunction(intfiip, zeros(3)), (;), true) ), domain in (((0.0, 1.0),), (([0.0], [1.0]),), (0.0, 1.0), ([0.0], [1.0])) @@ -648,9 +649,9 @@ i1(u) = u itoo(y, u, p, a) = y .= u * p IntegralFunction(ioop) +IntegralFunction(ioop, 0.0) IntegralFunction(iiip, Float64[]) -@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[]) @test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(iiip) @test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1) @test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo) @@ -665,10 +666,11 @@ bitoo(y, u, p, a) = y .= p .* u BatchIntegralFunction(boop) BatchIntegralFunction(boop, max_batch = 20) +BatchIntegralFunction(boop, Float64[]) +BatchIntegralFunction(boop, Float64[], max_batch = 20) BatchIntegralFunction(biip, Float64[]) BatchIntegralFunction(biip, Float64[], max_batch = 20) -@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[]) @test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip) @test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1) @test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo)