From 0cf8a7989124b30db6fb31cbad1724ba1696f5b1 Mon Sep 17 00:00:00 2001 From: lxvm Date: Mon, 26 Feb 2024 07:54:35 -0500 Subject: [PATCH 1/4] allow integrand_prototype for non-iip integral functions --- src/scimlfunctions.jl | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index b30eca7af..df919b397 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2088,9 +2088,10 @@ out-of-place handling. IntegralFunction{iip,specialize}(f, [integrand_prototype]) ``` -Note that only `f` is required, and in the case of inplace integrands a mutable container +Note that only `f` is required, and in the case of inplace integrands a mutable array `integrand_prototype` to store the result of the integrand. If `integrand_prototype` is -present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place. +present for either in-place or out-of-place integrands it is used to infer the return type +of the integrand. ## iip: In-Place vs Out-Of-Place @@ -2148,14 +2149,14 @@ BatchIntegralFunction{iip,specialize}(bf, [integrand_prototype]; max_batch=typemax(Int)) ``` Note that only `bf` is required, and in the case of inplace integrands a mutable -container `integrand_prototype` to store a batch of integrand evaluations, with +array `integrand_prototype` to store a batch of integrand evaluations, with a last "batching" dimension. The keyword `max_batch` is used to set a soft limit on the number of points to batch at the same time so that memory usage is controlled. -If `integrand_prototype` is present, `bf` is interpreted as in-place, and -otherwise `bf` is assumed to be out-of-place. +If `integrand_prototype` is present for either in-place or out-of-place integrands it is +used to infer the return type of the integrand. ## iip: In-Place vs Out-Of-Place @@ -3890,10 +3891,7 @@ function IntegralFunction(f) end function IntegralFunction(f, integrand_prototype) calculated_iip = isinplace(f, 3, "integral", true) - if !calculated_iip - throw(IntegrandMismatchFunctionError(calculated_iip, true)) - end - IntegralFunction{true}(f, integrand_prototype) + IntegralFunction{calculated_iip}(f, integrand_prototype) end function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; @@ -3926,10 +3924,7 @@ function BatchIntegralFunction(f; kwargs...) end function BatchIntegralFunction(f, integrand_prototype; kwargs...) calculated_iip = isinplace(f, 3, "batchintegral", true) - if !calculated_iip - throw(IntegrandMismatchFunctionError(calculated_iip, true)) - end - BatchIntegralFunction{true}(f, integrand_prototype; kwargs...) + BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...) end ########## Utility functions From ab68b63b5fc536a19a8b5f5f19a7047ccea35fa3 Mon Sep 17 00:00:00 2001 From: lxvm Date: Mon, 26 Feb 2024 07:54:44 -0500 Subject: [PATCH 2/4] add test --- test/function_building_error_messages.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 7f0ba65ab..235eee0f8 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -474,6 +474,7 @@ intfiip(y, u, p) = y .= 1.0 for (f, kws, iip) in ( (intf, (;), false), (IntegralFunction(intf), (;), false), + (IntegralFunction(intf, 1.0), (;), false), (intfiip, (; nout = 3), true), (IntegralFunction(intfiip, zeros(3)), (;), true) ), domain in (((0.0, 1.0),), (([0.0], [1.0]),), (0.0, 1.0), ([0.0], [1.0])) From 8530d24150768cd94adc47e0b27c3dc606f19253 Mon Sep 17 00:00:00 2001 From: lxvm Date: Mon, 26 Feb 2024 23:29:54 -0500 Subject: [PATCH 3/4] format --- src/scimlfunctions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index df919b397..41d2d911d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -3195,7 +3195,8 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {iip, + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where { + iip, specialize } if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) From 263f91092c609fc04cc9c4fbaadaef826d27b145 Mon Sep 17 00:00:00 2001 From: lxvm Date: Mon, 26 Feb 2024 23:30:11 -0500 Subject: [PATCH 4/4] update tests --- test/function_building_error_messages.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 235eee0f8..daf1c8565 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -649,9 +649,9 @@ i1(u) = u itoo(y, u, p, a) = y .= u * p IntegralFunction(ioop) +IntegralFunction(ioop, 0.0) IntegralFunction(iiip, Float64[]) -@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[]) @test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(iiip) @test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1) @test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo) @@ -666,10 +666,11 @@ bitoo(y, u, p, a) = y .= p .* u BatchIntegralFunction(boop) BatchIntegralFunction(boop, max_batch = 20) +BatchIntegralFunction(boop, Float64[]) +BatchIntegralFunction(boop, Float64[], max_batch = 20) BatchIntegralFunction(biip, Float64[]) BatchIntegralFunction(biip, Float64[], max_batch = 20) -@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[]) @test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip) @test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1) @test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo)