From f56d6545b88f47aa55925f82fc6f31000e8f0718 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 19 Sep 2023 02:51:33 -0400 Subject: [PATCH] canonicalize --- src/scimlfunctions.jl | 263 +++++++++++++++++------------------------- 1 file changed, 106 insertions(+), 157 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index a24cd5238..e24533809 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -203,6 +203,48 @@ function Base.showerror(io::IO, e::NonconformingFunctionsError) printstyled(io, e.nonconforming; bold = true, color = :red) end +const INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE = """ + Nonconforming functions detected. If an integrand function `f` is defined + as out-of-place (`f(u,p)`), then no integrand_prototype can be passed into the + function constructor. Likewise if `f` is defined as in-place (`f(out,u,p)`), then + an integrand_prototype is required. Either change the use of the function + constructor or define the appropriate dispatch for `f`. + """ + +struct IntegrandMismatchFunctionError <: Exception + iip::Bool + integrand_passed::Bool +end + +function Base.showerror(io::IO, e::IntegrandMismatchFunctionError) + println(io, INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE) + print(io, "Mismatch: IIP=") + printstyled(io, e.iip; bold = true, color = :red) + print(io, ", Integrand passed=") + printstyled(io, e.integrand_passed; bold = true, color = :red) +end + +const BATCH_INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE = """ + Nonconforming functions detected. If an integrand function `f` is defined + as out-of-place (`f(u,p)`), then no integrand_prototype can be passed into the + function constructor. Likewise if `f` is defined as in-place (`f(out,u,p)`), then + an integrand_prototype is required. Either change the use of the function + constructor or define the appropriate dispatch for `f`. + """ + +struct BatchedIntegrandMismatchFunctionError <: Exception + iip::Bool + integrand_passed::Bool +end + +function Base.showerror(io::IO, e::BatchedIntegrandMismatchFunctionError) + println(io, BATCH_INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE) + print(io, "Mismatch: IIP=") + printstyled(io, e.iip; bold = true, color = :red) + print(io, ", Integrand passed=") + printstyled(io, e.integrand_passed; bold = true, color = :red) +end + """ $(TYPEDEF) """ @@ -2262,7 +2304,7 @@ end TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2 @doc doc""" - IntegralFunction{iip,specialize,F,T,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <: AbstractIntegralFunction{iip} + IntegralFunction{iip,specialize,F} <: AbstractIntegralFunction{iip} A representation of an integrand `f` defined by: @@ -2270,53 +2312,26 @@ A representation of an integrand `f` defined by: f(u, p) ``` -and its related functions, such as its Jacobian and gradient with respect to parameters. For -an in-place form of `f` see the `iip` section below for details on in-place or out-of-place +For an in-place form of `f` see the `iip` section below for details on in-place or out-of-place handling. ```julia -IntegralFunction{iip,specialize}(f, [I]; - jac = __has_jac(f) ? f.jac : nothing, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - analytic = __has_analytic(f) ? f.analytic : nothing, - syms = __has_syms(f) ? f.syms : nothing, - jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, - sparsity = __has_sparsity(f) ? f.sparsity : nothing, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - observed = __has_observed(f) ? f.observed : nothing) +IntegralFunction{iip,specialize}(f, [integrand_prototype]) ``` Note that only `f` is required, and in the case of inplace integrands a mutable container -`I` to store the result of the integral. - -The remaining functions are optional and mainly used for accelerating the usage of `f`: -- `jac`: unused -- `paramjac`: unused -- `analytic`: unused -- `syms`: unused -- `jac_prototype`: unused -- `sparsity`: unused -- `colorvec`: unused -- `observed`: unused - -Since most arguments are unused, the following constructor provides the essential behavior: - -```julia -IntegralFunction(f, [I]; kws..) -``` - -If `I` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be -out-of-place. +`integrand_prototype` to store the result of the integral. If `integrand_prototype` is present, +`f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place. ## iip: In-Place vs Out-Of-Place Out-of-place functions must be of the form ``f(u, p)`` and in-place functions of the form ``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or -arrays), in-place functions must provide a container `I` that is of the right type for the -final result of the integral, and the result is written to this container in-place. When -in-place forms are used, in-place array operations may be used by algorithms to reduce -allocations. If `I` is not provided, `f` is assumed to be out-of-place and quadrature is -performed assuming immutable return types. +arrays), in-place functions must provide a container `integrand_prototype` that is of the +right type for the final result of the integral, and the result is written to this container +in-place. When in-place forms are used, in-place array operations may be used by algorithms +to reduce allocations. If `integrand_prototype` is not provided, `f` is assumed to be +out-of-place and quadrature is performed assuming immutable return types. ## specialize @@ -2326,18 +2341,10 @@ This field is currently unused The fields of the IntegralFunction type directly match the names of the inputs. """ -struct IntegralFunction{iip, specialize, F, T, TJ, TPJ, Ta, S, JP, SP, TCV, O} <: +struct IntegralFunction{iip, specialize, F, T} <: AbstractIntegralFunction{iip} f::F - I::T - jac::TJ - paramjac::TPJ - analytic::Ta - syms::S - jac_prototype::JP - sparsity::SP - colorvec::TCV - observed::O + integrand_prototype::T end TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2 @@ -2352,25 +2359,13 @@ using threads, the gpu, or distributed memory defined by: f(y, u, p) ``` -and its related functions, such as its Jacobian and gradient with respect to parameters. For -an in-place form of `f` see the `iip` section below for details on in-place or out-of-place -handling. - ``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose output must be returned in the corresponding entries of ``y``. In general, the integration algorithm is allowed to vary the number of evaluation points between subsequent calls to `f` ```julia -BatchIntegralFunction{iip,specialize}(f, y, [I]; - max_batch = typemax(Int), - jac = __has_jac(f) ? f.jac : nothing, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - analytic = __has_analytic(f) ? f.analytic : nothing, - syms = __has_syms(f) ? f.syms : nothing, - jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, - sparsity = __has_sparsity(f) ? f.sparsity : nothing, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - observed = __has_observed(f) ? f.observed : nothing) +BatchIntegralFunction{iip,specialize}(f, y, [integrand_prototype]; + max_batch=typemax(Int)) ``` Note that `f` is required and a `resize`-able buffer `y` to store the output, or range of @@ -2381,37 +2376,27 @@ allocations. The keyword `max_batch` is used to set a soft limit on the number of points to batch at the same time so that memory usage is controlled. -The remaining functions are optional and mainly used for accelerating the usage of `f`: -- `jac`: unused -- `paramjac`: unused -- `analytic`: unused -- `syms`: unused -- `jac_prototype`: unused -- `sparsity`: unused -- `colorvec`: unused -- `observed`: unused - Since most arguments are unused, the following constructor provides the essential behavior: ```julia -BatchIntegralFunction(f, y, [I]; max_batch=typemax(Int), kws..) +BatchIntegralFunction(f, y, [integrand_prototype]; max_batch=typemax(Int), kws..) ``` -If `I` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be +If `integrand_prototype` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place. ## iip: In-Place vs Out-Of-Place Out-of-place and in-place functions are both of the form ``f(y, u, p)``, but differ in the element type of ``y``. Since `f` is allowed to return any type (e.g. real or complex numbers -or arrays), in-place functions must provide a container `I` that is of the right type for -the final result of the integral, and the result is written to this container in-place. When -`f` is in-place, the output buffer ``y`` is assumed to have a mutable element type, and the -last dimension of ``y`` should correspond to the batch index. For example, ``y`` would have -to be an `ElasticArray` or a `VectorOfSimilarArrays` of an `ElasticArray`. When in-place -forms are used, in-place array operations may be used by algorithms to reduce allocations. -If `I` is not provided, `f` is assumed to be out-of-place and quadrature is performed -assuming ``y`` is an `AbstractVector` with an immutable element type. +or arrays), in-place functions must provide a container `integrand_prototype` that is of the +right type for the final result of the integral, and the result is written to this container +in-place. When `f` is in-place, the output buffer ``y`` is assumed to have a mutable element +type, and the last dimension of ``y`` should correspond to the batch index. For example, +``y`` would have to be an `ElasticArray` or a `VectorOfSimilarArrays` of an `ElasticArray`. +When in-place forms are used, in-place array operations may be used by algorithms to reduce +allocations. If `integrand_prototype` is not provided, `f` is assumed to be out-of-place and +quadrature is performed assuming ``y`` is an `AbstractVector` with an immutable element type. ## specialize @@ -2421,20 +2406,12 @@ This field is currently unused The fields of the BatchIntegralFunction type directly match the names of the inputs. """ -struct BatchIntegralFunction{iip, specialize, F, Y, T, TJ, TPJ, Ta, S, JP, SP, TCV, O} <: +struct BatchIntegralFunction{iip, specialize, F, Y, T} <: AbstractIntegralFunction{iip} f::F y::Y - I::T + integrand_prototype::T max_batch::Int - jac::TJ - paramjac::TPJ - analytic::Ta - syms::S - jac_prototype::JP - sparsity::SP - colorvec::TCV - observed::O end TruncatedStacktraces.@truncate_stacktrace BatchIntegralFunction 1 2 @@ -4133,79 +4110,51 @@ function BVPFunction(f, bc; kwargs...) end BVPFunction(f::BVPFunction; kwargs...) = f -function IntegralFunction{iip, specialize}(f, I; - jac = __has_jac(f) ? f.jac : nothing, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - analytic = __has_analytic(f) ? f.analytic : nothing, - syms = __has_syms(f) ? f.syms : nothing, - jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, - sparsity = __has_sparsity(f) ? f.sparsity : nothing, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - observed = __has_observed(f) ? f.observed : nothing) where {iip, specialize} - IntegralFunction{ - iip, - specialize, - typeof(f), - typeof(I), - typeof(jac), - typeof(paramjac), - typeof(analytic), - typeof(syms), - typeof(jac_prototype), - typeof(sparsity), - typeof(colorvec), - typeof(observed), - }(f, - I, - jac, - paramjac, - analytic, - syms, - jac_prototype, - sparsity, - colorvec, - observed) +function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} + IntegralFunction{iip,specialize,typeof(f),typeof(I)}(f,integrand_prototype) end -function IntegralFunction{iip}(f, I; kws...) where {iip} - return IntegralFunction{iip, FullSpecialize}(f, I; kws...) +function IntegralFunction{iip}(f, integrand_prototype) where {iip} + return IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) +end +function IntegralFunction(f) + calcuated_iip = isinplace(f, 3, "integral", iip) + if !calcuated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, false)) + end + IntegralFunction{false}(f, nothing) +end +function IntegralFunction(f, integrand_prototype) + calcuated_iip = isinplace(f, 3, "integral", iip) + if !calcuated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, true)) + end + IntegralFunction{true}(f, integrand_prototype) end -IntegralFunction(f; kws...) = IntegralFunction{false}(f, nothing; kws...) -IntegralFunction(f, I; kws...) = IntegralFunction{true}(f, I; kws...) -function BatchIntegralFunction{iip, specialize}(f, y, I; - max_batch::Integer = typemax(Int), - jac = __has_jac(f) ? f.jac : nothing, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - analytic = __has_analytic(f) ? f.analytic : nothing, - syms = __has_syms(f) ? f.syms : nothing, - jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, - sparsity = __has_sparsity(f) ? f.sparsity : nothing, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - observed = __has_observed(f) ? f.observed : nothing) where {iip, specialize} - BatchIntegralFunction{ - iip, - specialize, - typeof(f), - typeof(y), - typeof(I), - typeof(jac), - typeof(paramjac), - typeof(analytic), - typeof(syms), - typeof(jac_prototype), - typeof(sparsity), - typeof(colorvec), - typeof(observed), - }(f, y, I, max_batch, jac, paramjac, analytic, syms, jac_prototype, sparsity, colorvec, - observed) +function BatchIntegralFunction{iip, specialize}(f, output_prototype, integrand_prototype; + max_batch::Integer = typemax(Int)) where {iip, specialize} + BatchIntegralFunction{iip,specialize,typeof(f),typeof(output_prototype),typeof(integrand_prototype) + }(f, output_prototype, integrand_prototype, max_batch) end -function BatchIntegralFunction{iip}(f, y, I; kws...) where {iip} - return BatchIntegralFunction{iip, FullSpecialize}(f, y, I; kws...) +function BatchIntegralFunction{iip}(f, output_prototype, integrand_prototype; kwargs...) where {iip} + return BatchIntegralFunction{iip, FullSpecialize}(f, output_prototype, integrand_prototype; kwargs...) +end +function BatchIntegralFunction(f, output_prototype; kwargs...) + calcuated_iip = isinplace(f, 4, "batchintegral", iip) + if !calcuated_iip + throw(BatchedIntegrandMismatchFunctionError(calculated_iip, false)) + end + BatchIntegralFunction{false}(f, output_prototype, nothing; kwargs...) +end +function BatchIntegralFunction(f, output_prototype, integrand_prototype; kwargs...) + calcuated_iip = isinplace(f, 4, "batchintegral", iip) + if !calcuated_iip + throw(BatchIntegrandMismatchFunctionError(calculated_iip, true)) + end + BatchIntegralFunction{true}(f, output_prototype, integrand_prototype; kwargs...) end -BatchIntegralFunction(f, y; kws...) = BatchIntegralFunction{false}(f, y, nothing; kws...) -BatchIntegralFunction(f, y, I; kws...) = BatchIntegralFunction{true}(f, y, I; kws...) ########## Existence Functions