diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 0854c0e04..4afc74b32 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -43,6 +43,11 @@ jobs: - {user: SciML, repo: StochasticDelayDiffEq.jl, group: All} - {user: SciML, repo: SimpleNonlinearSolve.jl, group: All} - {user: SciML, repo: SimpleDiffEq.jl, group: All} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core1} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core2} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core3} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core4} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core5} steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index e52bc477a..079bca008 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "1.97.1" +version = "2.0.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -45,6 +46,7 @@ CommonSolve = "0.2.4" ConstructionBase = "1" DocStringExtensions = "0.8, 0.9" EnumX = "1" +FillArrays = "1.6" FunctionWrappersWrappers = "0.1.3" IteratorInterfaceExtensions = "^0.1, ^1" PrecompileTools = "1" diff --git a/README.md b/README.md index 87749abf7..b294f98c8 100644 --- a/README.md +++ b/README.md @@ -12,3 +12,16 @@ SciMLBase.jl is the core interface definition of the SciML ecosystem. It is a low dependency library made to be depended on by the downstream libraries to supply the common interface and allow for interexchange of mathematical problems. + +## v2.0 Breaking Changes + +The breaking changes in v2.0 are: + +* `IntegralProblem` has moved to an interface with `IntegralFunction` and `BatchedIntegralFunction` which requires specifying `prototype`s for the values to be modified + instead of `nout` and `batch`. https://github.com/SciML/SciMLBase.jl/pull/497 +* `ODEProblem` was made temporarily into a `mutable struct` to allow for EnzymeRules support. Using the mutation throws a warning that this is only experimental and should not be relied on. + https://github.com/SciML/SciMLBase.jl/pull/501 +* `BVProblem` now has a new interface for `TwoPointBVProblem` which splits the bc terms for the two sides, forcing a true two-point BVProblem to allow for further specializations and to allow + for wrapping Fortran solvers in the interface. https://github.com/SciML/SciMLBase.jl/pull/477 +* `SDEProblem` constructor was changed to remove an anti-pattern which required passing the diffusion function `g` twice, i.e. `SDEProblem(SDEFunction(f,g),g, ...)`. + Now this is simply `SDEProblem(SDEFunction(f,g),...)`. https://github.com/SciML/SciMLBase.jl/pull/489 diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index e4d1e371f..d206ca07d 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -23,6 +23,7 @@ import TruncatedStacktraces import ADTypes: AbstractADType import ChainRulesCore import ZygoteRules: @adjoint +import FillArrays using Reexport using SciMLOperators @@ -588,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <: """ $(TYPEDEF) +Base for types defining integrand functions. +""" +abstract type AbstractIntegralFunction{iip} <: + AbstractSciMLFunction{iip} end + +""" +$(TYPEDEF) + Base for types defining optimization functions. """ abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end @@ -658,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize}, RODEFunction{iip, specialize}, NonlinearFunction{iip, specialize}, OptimizationFunction{iip, specialize}, - BVPFunction{iip, specialize}}) where {iip, + BVPFunction{iip, specialize}, + IntegralFunction{iip, specialize}, + BatchIntegralFunction{iip, specialize}}) where {iip, specialize} specialize end @@ -760,9 +771,7 @@ export solve, solve!, init, discretize, symbolic_discretize export LinearProblem, NonlinearProblem, IntervalNonlinearProblem, - IntegralProblem, OptimizationProblem - -export IntegralProblem + IntegralProblem, SampledIntegralProblem, OptimizationProblem export DiscreteProblem, ImplicitDiscreteProblem export SteadyStateProblem, SteadyStateSolution @@ -788,7 +797,8 @@ export remake export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction, DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, - IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction + IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction, + IntegralFunction, BatchIntegralFunction export OptimizationFunction diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index 6cb329290..aece5fa7b 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -334,25 +334,17 @@ which are `Number`s or `AbstractVector`s with the same geometry as `u`. ### Constructors -IntegralProblem{iip}(f,lb,ub,p=NullParameters(); - nout=1, batch = 0, kwargs...) +``` +IntegralProblem(f,domain,p=NullParameters(); kwargs...) +IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...) +``` -- f: the integrand, `y = f(u,p)` for out-of-place or `f(y,u,p)` for in-place. +- f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an + `IntegralFunction` or `BatchIntegralFunction` for inplace and batching optimizations. +- domain: an object representing an integration domain, i.e. the tuple `(lb, ub)`. - lb: Either a number or vector of lower bounds. - ub: Either a number or vector of upper bounds. - p: The parameters associated with the problem. -- nout: The output size of the function f. Defaults to 1, i.e., a scalar valued function. - If `nout > 1` f is a vector valued function . -- batch: The preferred number of points to batch. This allows user-side parallelization - of the integrand. If `batch == 0` no batching is performed. - If `batch > 0` both `u` and `y` get an additional dimension added to it. - This means that: - if `f` is a multi variable function each `u[:,i]` is a different point to evaluate `f` at, - if `f` is a single variable function each `u[i]` is a different point to evaluate `f` at, - if `f` is a vector valued function each `y[:,i]` is the evaluation of `f` at a different point, - if `f` is a scalar valued function `y[i]` is the evaluation of `f` at a different point. - Note that batch is a suggestion for the number of points, - and it is not necessarily true that batch is the same as batchsize in all algorithms. - kwargs: Keyword arguments copied to the solvers. Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at @@ -362,28 +354,58 @@ compile time whether the integrator function is in-place. The fields match the names of the constructor arguments. """ -struct IntegralProblem{isinplace, P, F, B, K} <: AbstractIntegralProblem{isinplace} +struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinplace} f::F - lb::B - ub::B - nout::Int + domain::T p::P - batch::Int kwargs::K - @add_kwonly function IntegralProblem{iip}(f, lb, ub, p = NullParameters(); - nout = 1, - batch = 0, kwargs...) where {iip} - @assert typeof(lb)==typeof(ub) "Type of lower and upper bound must match" + @add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain, + p = NullParameters(); + kwargs...) where {iip} warn_paramtype(p) - new{iip, typeof(p), typeof(f), typeof(lb), typeof(kwargs)}(f, lb, ub, nout, p, - batch, kwargs) + new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f, + domain, p, kwargs) end end TruncatedStacktraces.@truncate_stacktrace IntegralProblem 1 4 -function IntegralProblem(f, lb, ub, args...; kwargs...) - IntegralProblem{isinplace(f, 3)}(f, lb, ub, args...; kwargs...) +function IntegralProblem(f::AbstractIntegralFunction, + domain, + p = NullParameters(); + kwargs...) + IntegralProblem{isinplace(f)}(f, domain, p; kwargs...) +end + +function IntegralProblem(f::AbstractIntegralFunction, + lb::B, + ub::B, + p = NullParameters(); + kwargs...) where {B} + IntegralProblem(f, (lb, ub), p; kwargs...) +end + +function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...) + if nout !== nothing || batch !== nothing + @warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details." + end + + max_batch = batch === nothing ? 0 : batch + g = if isinplace(f, 3) + output_prototype = Vector{Float64}(undef, nout === nothing ? 1 : nout) + if max_batch == 0 + IntegralFunction(f, output_prototype) + else + BatchIntegralFunction(f, output_prototype, max_batch=max_batch) + end + else + if max_batch == 0 + IntegralFunction(f) + else + BatchIntegralFunction(f, max_batch=max_batch) + end + end + IntegralProblem(g, args...; kwargs...) end struct QuadratureProblem end @@ -391,6 +413,56 @@ struct QuadratureProblem end @doc doc""" +Defines a integral problem over pre-sampled data. +Documentation Page: https://docs.sciml.ai/Integrals/stable/ + +## Mathematical Specification of a data Integral Problem + +Sampled integral problems are defined as: + +```math +\sum_i w_i y_i +``` +where `y_i` are sampled values of the integrand, and `w_i` are weights +assigned by a quadrature rule, which depend on sampling points `x`. + +## Problem Type + +### Constructors + +``` +SampledIntegralProblem(y::AbstractArray, x::AbstractVector; dim=ndims(y), kwargs...) +``` +- y: The sampled integrand, must be a subtype of `AbstractArray`. + It is assumed that the values of `y` along dimension `dim` + correspond to the integrand evaluated at sampling points `x` +- x: Sampling points, must be a subtype of `AbstractVector`. +- dim: Dimension along which to integrate. Defaults to the last dimension of `y`. +- kwargs: Keyword arguments copied to the solvers. + +### Fields + +The fields match the names of the constructor arguments. +""" +struct SampledIntegralProblem{Y, X, K} <: AbstractIntegralProblem{false} + y::Y + x::X + dim::Int + kwargs::K + @add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector; + dim = ndims(y), + kwargs...) + @assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`" + @assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension." + @assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension." + new{typeof(y), typeof(x), typeof(kwargs)}(y, x, dim, kwargs) + end +end + +TruncatedStacktraces.@truncate_stacktrace SampledIntegralProblem 1 4 + +@doc doc""" + Defines an optimization problem. Documentation Page: https://docs.sciml.ai/Optimization/stable/API/optimization_problem/ diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 3beec8702..2e9f1feab 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -3,6 +3,11 @@ $(TYPEDEF) """ struct StandardBVProblem end +""" +$(TYPEDEF) +""" +struct TwoPointBVProblem end + @doc doc""" Defines an BVP problem. @@ -17,7 +22,7 @@ condition ``u_0`` which define an ODE: \frac{du}{dt} = f(u,p,t) ``` -along with an implicit function `bc!` which defines the residual equation, where +along with an implicit function `bc` which defines the residual equation, where ```math bc(u,p,t) = 0 @@ -36,22 +41,27 @@ u(t_f) = b ### Constructors ```julia -TwoPointBVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...) -BVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...) +TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...) +BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...) ``` or if we have an initial guess function `initialGuess(t)` for the given BVP, we can pass the initial guess to the problem constructors: ```julia -TwoPointBVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...) -BVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...) +TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...) +BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...) ``` -For any BVP problem type, `bc!` is the inplace function: +For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be +out-of-place. + +If the bvp is a StandardBVProblem (also known as a Multi-Point BV Problem) it must define +either of the following functions ```julia bc!(residual, u, p, t) +residual = bc(u, p, t) ``` where `residual` computed from the current `u`. `u` is an array of solution values @@ -61,6 +71,16 @@ time points, and for shooting type methods `u=sol` the ODE solution. Note that all features of the `ODESolution` are present in this form. In both cases, the size of the residual matches the size of the initial condition. +If the bvp is a TwoPointBVProblem it must define either of the following functions + +```julia +bc!((resid_a, resid_b), (u_a, u_b), p) +resid_a, resid_b = bc((u_a, u_b), p) +``` + +where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are +the solution values at the two endpoints, and `p` are the parameters. + Parameters are optional, and if not given, then a `NullParameters()` singleton will be used which will throw nice errors if you try to index non-existent parameters. Any extra keyword arguments are passed on to the solvers. For example, @@ -88,16 +108,20 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: problem_type::PT kwargs::K - @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan, - p = NullParameters(), - problem_type = StandardBVProblem(); - kwargs...) where {iip} + @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan, + p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP} _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p), - typeof(f), typeof(f.bc), - typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p, - problem_type, kwargs) + prob_type = TP ? TwoPointBVProblem() : StandardBVProblem() + # Needed to ensure that `problem_type` doesn't get passed in kwargs + if problem_type === nothing + problem_type = prob_type + else + @assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use." + end + return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc), + typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type, + kwargs) end function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} @@ -107,52 +131,43 @@ end TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 -function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...) - BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...) +function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) + iip = isinplace(f, 4) + return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...) end -function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) - BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...) +function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...) + return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...) end -""" -$(TYPEDEF) -""" -struct TwoPointBVPFunction{bF} - bc::bF +# This is mostly a fake stuct and isn't used anywhere +# But we need it for function calls like TwoPointBVProblem{iip}(...) = ... +struct TwoPointBVPFunction{iip} end + +@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true) +@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip} + return BVPFunction{iip}(args...; kwargs..., twopoint=true) end -TwoPointBVPFunction(; bc = error("No argument bc")) = TwoPointBVPFunction(bc) -(f::TwoPointBVPFunction)(residual, ua, ub, p) = f.bc(residual, ua, ub, p) -(f::TwoPointBVPFunction)(residual, u, p) = f.bc(residual, u[1], u[end], p) -""" -$(TYPEDEF) -""" -struct TwoPointBVProblem{iip} end -function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) - iip = isinplace(f, 4) - TwoPointBVProblem{iip}(f, bc, u0, tspan, p; kwargs...) +function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); + bcresid_prototype=nothing, kwargs...) + return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p; + kwargs...) end -function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); - kwargs...) where {iip} - BVProblem{iip}(f, TwoPointBVPFunction(bc), u0, tspan, p; kwargs...) +function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan, + p = NullParameters(); kwargs...) where {iip, twopoint} + @assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`." + return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...) end # Allow previous timeseries solution -function TwoPointBVProblem(f::AbstractODEFunction, - bc, - sol::T, - tspan::Tuple, - p = NullParameters()) where {T <: AbstractTimeseriesSolution} - TwoPointBVProblem(f, bc, sol.u, tspan, p) +function TwoPointBVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple, + p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution} + return TwoPointBVProblem(f, bc, sol.u, tspan, p; kwargs...) end # Allow initial guess function for the initial guess -function TwoPointBVProblem(f::AbstractODEFunction, - bc, - initialGuess, - tspan::AbstractVector, - p = NullParameters(); - kwargs...) +function TwoPointBVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector, + p = NullParameters(); kwargs...) u0 = [initialGuess(i) for i in tspan] - TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p) + return TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p; kwargs...) end diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 20949d644..ce4425db3 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -94,7 +94,7 @@ prob = ODEProblemLibrary.prob_ode_linear sol = solve(prob) ``` """ -struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: +mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: AbstractODEProblem{uType, tType, isinplace} """The ODE is `du = f(u,p,t)` for out-of-place and f(du,u,p,t) for in-place.""" f::F @@ -162,6 +162,16 @@ struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: end TruncatedStacktraces.@truncate_stacktrace ODEProblem 3 1 2 +function Base.setproperty!(prob::ODEProblem, s::Symbol, v) + @warn "Mutation of ODEProblem detected. SciMLBase v2.0 has made ODEProblem temporarily mutable in order to allow for interfacing with EnzymeRules due to a current limitation in the rule system. This change is only intended to be temporary and ODEProblem will return to being a struct in a later non-breaking release. Do not rely on this behavior, use with caution." + Base.setfield!(prob, s, v) +end + +function Base.setproperty!(prob::ODEProblem, s::Symbol, v, order::Symbol) + @warn "Mutation of ODEProblem detected. SciMLBase v2.0 has made ODEProblem temporarily mutable in order to allow for interfacing with EnzymeRules due to a current limitation in the rule system. This change is only intended to be temporary and ODEProblem will return to being a struct in a later non-breaking release. Do not rely on this behavior, use with caution." + Base.setfield!(prob, s, v, order) +end + """ ODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet()) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 1867e1d58..c74808df2 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -203,6 +203,27 @@ function Base.showerror(io::IO, e::NonconformingFunctionsError) printstyled(io, e.nonconforming; bold = true, color = :red) end +const INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE = """ + Nonconforming functions detected. If an integrand function `f` is defined + as out-of-place (`f(u,p)`), then no integrand_prototype can be passed into the + function constructor. Likewise if `f` is defined as in-place (`f(out,u,p)`), then + an integrand_prototype is required. Either change the use of the function + constructor or define the appropriate dispatch for `f`. + """ + +struct IntegrandMismatchFunctionError <: Exception + iip::Bool + integrand_passed::Bool +end + +function Base.showerror(io::IO, e::IntegrandMismatchFunctionError) + println(io, INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE) + print(io, "Mismatch: IIP=") + printstyled(io, e.iip; bold = true, color = :red) + print(io, ", Integrand passed=") + printstyled(io, e.integrand_passed; bold = true, color = :red) +end + """ $(TYPEDEF) """ @@ -440,20 +461,20 @@ and exponential integrators. ```julia SplitFunction{iip,specialize}(f1,f2; - mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, - analytic = __has_analytic(f) ? f.analytic : nothing, - tgrad= __has_tgrad(f) ? f.tgrad : nothing, - jac = __has_jac(f) ? f.jac : nothing, - jvp = __has_jvp(f) ? f.jvp : nothing, - vjp = __has_vjp(f) ? f.vjp : nothing, - jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, - sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - syms = __has_syms(f) ? f.syms : nothing, - indepsym= __has_indepsym(f) ? f.indepsym : nothing, - paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) + mass_matrix = __has_mass_matrix(f1) ? f1.mass_matrix : I, + analytic = __has_analytic(f1) ? f1.analytic : nothing, + tgrad= __has_tgrad(f1) ? f1.tgrad : nothing, + jac = __has_jac(f1) ? f1.jac : nothing, + jvp = __has_jvp(f1) ? f1.jvp : nothing, + vjp = __has_vjp(f1) ? f1.vjp : nothing, + jac_prototype = __has_jac_prototype(f1) ? f1.jac_prototype : nothing, + sparsity = __has_sparsity(f1) ? f1.sparsity : jac_prototype, + paramjac = __has_paramjac(f1) ? f1.paramjac : nothing, + syms = __has_syms(f1) ? f1.syms : nothing, + indepsym= __has_indepsym(f1) ? f1.indepsym : nothing, + paramsyms = __has_paramsyms(f1) ? f1.paramsyms : nothing, + colorvec = __has_colorvec(f1) ? f1.colorvec : nothing, + sys = __has_sys(f1) ? f1.sys : nothing) ``` Note that only the functions `f_i` themselves are required. These functions should @@ -461,7 +482,7 @@ be given as `f_i!(du,u,p,t)` or `du = f_i(u,p,t)`. See the section on `iip` for more details on in-place vs out-of-place handling. All of the remaining functions are optional for improving or accelerating -the usage of `f`. These include: +the usage of the `SplitFunction`. These include: - `mass_matrix`: the mass matrix `M` represented in the ODE function. Can be used to determine that the equation is actually a differential-algebraic equation (DAE) @@ -2124,8 +2145,7 @@ TruncatedStacktraces.@truncate_stacktrace OptimizationFunction 1 2 """ $(TYPEDEF) """ -abstract type AbstractBVPFunction{iip} <: - AbstractDiffEqFunction{iip} end +abstract type AbstractBVPFunction{iip, twopoint} <: AbstractDiffEqFunction{iip} end @doc doc""" BVPFunction{iip,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip,specialize} @@ -2230,11 +2250,9 @@ For more details on this argument, see the ODEFunction documentation. The fields of the BVPFunction type directly match the names of the inputs. """ -struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, - BCJP, SP, TW, TWt, - TPJ, - S, S2, S3, O, TCV, BCTCV, - SYS} <: AbstractBVPFunction{iip} +struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, + JP, BCJP, BCRP, SP, TW, TWt, TPJ, S, S2, S3, O, TCV, BCTCV, + SYS} <: AbstractBVPFunction{iip, twopoint} f::F bc::BF mass_matrix::TMM @@ -2246,6 +2264,7 @@ struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, vjp::VJP jac_prototype::JP bcjac_prototype::BCJP + bcresid_prototype::BCRP sparsity::SP Wfact::TW Wfact_t::TWt @@ -2261,11 +2280,122 @@ end TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2 +@doc doc""" + IntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} + +A representation of an integrand `f` defined by: + +```math +f(u, p) +``` + +For an in-place form of `f` see the `iip` section below for details on in-place or +out-of-place handling. + +```julia +IntegralFunction{iip,specialize}(f, [integrand_prototype]) +``` + +Note that only `f` is required, and in the case of inplace integrands a mutable container +`integrand_prototype` to store the result of the integrand. If `integrand_prototype` is +present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place. + +## iip: In-Place vs Out-Of-Place + +Out-of-place functions must be of the form ``y = f(u, p)`` and in-place functions of the form +``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +arrays), in-place functions must provide a container `integrand_prototype` that is of the +right type for the variable ``y``, and the result is written to this container in-place. +When in-place forms are used, in-place array operations, i.e. broadcasting, may be used by +algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is assumed +to be out-of-place and quadrature is performed assuming immutable return types. + +## specialize + +This field is currently unused + +## Fields + +The fields of the IntegralFunction type directly match the names of the inputs. +""" +struct IntegralFunction{iip, specialize, F, T} <: + AbstractIntegralFunction{iip} + f::F + integrand_prototype::T +end + +TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2 + +@doc doc""" +BatchIntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} + +A representation of an integrand `f` that can be evaluated at multiple points simultaneously +using threads, the gpu, or distributed memory defined by: + +```math +y = f(u, p) +``` + +``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose +output must be returned as an array whose last "batching" dimension corresponds to integrand +evaluations at the different points in ``u``. In general, the integration algorithm is +allowed to vary the number of evaluation points between subsequent calls to `f`. + +For an in-place form of `f` see the `iip` section below for details on in-place or +out-of-place handling. + +```julia +BatchIntegralFunction{iip,specialize}(f, [integrand_prototype]; + max_batch=typemax(Int)) +``` +Note that only `f` is required, and in the case of inplace integrands a mutable container +`integrand_prototype` to store the result of the integrand of one integrand, without a last +"batching" dimension. + +The keyword `max_batch` is used to set a soft limit on the number of points to batch at the +same time so that memory usage is controlled. + +If `integrand_prototype` is present, `f` is interpreted as in-place, and otherwise `f` is +assumed to be out-of-place. + +## iip: In-Place vs Out-Of-Place + +Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form +``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +arrays), in-place functions must provide a container `integrand_prototype` of the right type +for a single integrand evaluation. The integration algorithm will then allocate a ``y`` +array with the same element type as `integrand_prototype` and an additional last "batching" +dimension to store multiple integrand evaluations. In the out-of-place case, the algorithm +may infer the type of ``y`` by passing `f` an empty array of input points. This means ``y`` +is a vector in the out-of-place case, or a matrix/array in the in-place case. The number of +batched points may vary between subsequent calls to `f`. When in-place forms are used, +in-place array operations may be used by algorithms to reduce allocations. If +`integrand_prototype` is not provided, `f` is assumed to be out-of-place. + +## specialize + +This field is currently unused + +## Fields + +The fields of the BatchIntegralFunction type directly match the names of the inputs. +""" +struct BatchIntegralFunction{iip, specialize, F, T} <: + AbstractIntegralFunction{iip} + f::F + integrand_prototype::T + max_batch::Int +end + +TruncatedStacktraces.@truncate_stacktrace BatchIntegralFunction 1 2 + ######### Backwards Compatibility Overloads (f::ODEFunction)(args...) = f.f(args...) (f::NonlinearFunction)(args...) = f.f(args...) (f::IntervalNonlinearFunction)(args...) = f.f(args...) +(f::IntegralFunction)(args...) = f.f(args...) +(f::BatchIntegralFunction)(args...) = f.f(args...) function (f::DynamicalODEFunction)(u, p, t) ArrayPartition(f.f1(u.x[1], u.x[2], p, t), f.f2(u.x[1], u.x[2], p, t)) @@ -3648,9 +3778,8 @@ function NonlinearFunction{iip, specialize}(f; nothing, sys = __has_sys(f) ? f.sys : nothing, resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing) where { - iip, - specialize, -} + iip, specialize} + if mass_matrix === I && typeof(f) <: Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -3814,35 +3943,28 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); cons_expr, sys) end -function BVPFunction{iip, specialize}(f, bc; - mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : - I, +function BVPFunction{iip, specialize, twopoint}(f, bc; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, analytic = __has_analytic(f) ? f.analytic : nothing, tgrad = __has_tgrad(f) ? f.tgrad : nothing, jac = __has_jac(f) ? f.jac : nothing, bcjac = __has_jac(bc) ? bc.jac : nothing, jvp = __has_jvp(f) ? f.jvp : nothing, vjp = __has_vjp(f) ? f.vjp : nothing, - jac_prototype = __has_jac_prototype(f) ? - f.jac_prototype : - nothing, - bcjac_prototype = __has_jac_prototype(bc) ? - bc.jac_prototype : - nothing, - sparsity = __has_sparsity(f) ? f.sparsity : - jac_prototype, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing, + bcresid_prototype = nothing, + sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, Wfact = __has_Wfact(f) ? f.Wfact : nothing, Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, paramjac = __has_paramjac(f) ? f.paramjac : nothing, syms = __has_syms(f) ? f.syms : nothing, indepsym = __has_indepsym(f) ? f.indepsym : nothing, - paramsyms = __has_paramsyms(f) ? f.paramsyms : - nothing, - observed = __has_observed(f) ? f.observed : - DEFAULT_OBSERVED, + paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing, + observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize} + sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint} if mass_matrix === I && typeof(f) <: Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -3882,7 +4004,7 @@ function BVPFunction{iip, specialize}(f, bc; _bccolorvec = bccolorvec end - bciip = isinplace(bc, 4, "bc", iip) + bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip) jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip @@ -3892,20 +4014,25 @@ function BVPFunction{iip, specialize}(f, bc; Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip - nonconforming = (jaciip, - tgradiip, - jvpiip, - vjpiip, - Wfactiip, - Wfact_tiip, + nonconforming = (bciip, jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip, paramjaciip) .!= iip bc_nonconforming = bcjaciip .!= bciip if any(nonconforming) nonconforming = findall(nonconforming) - functions = ["jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] + functions = ["bc", "jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", + "paramjac"][nonconforming] throw(NonconformingFunctionsError(functions)) end + if twopoint + if iip && (bcresid_prototype === nothing || length(bcresid_prototype) != 2) + error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction") + end + if bcresid_prototype !== nothing && length(bcresid_prototype) == 2 + bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2]) + end + end + if any(bc_nonconforming) bc_nonconforming = findall(bc_nonconforming) functions = ["bcjac"][bc_nonconforming] @@ -3913,48 +4040,97 @@ function BVPFunction{iip, specialize}(f, bc; end if specialize === NoSpecialize - BVPFunction{iip, specialize, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any, Any, Any, Any, Any, + BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix, - analytic, - tgrad, - jac, bcjac, jvp, vjp, - jac_prototype, - bcjac_prototype, - sparsity, Wfact, - Wfact_t, - paramjac, syms, - indepsym, paramsyms, - observed, + analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype, + bcjac_prototype, bcresid_prototype, + sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, _bccolorvec, sys) else - BVPFunction{iip, specialize, typeof(f), typeof(bc), typeof(mass_matrix), - typeof(analytic), - typeof(tgrad), - typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype), - typeof(bcjac_prototype), - typeof(sparsity), typeof(Wfact), typeof(Wfact_t), - typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), + BVPFunction{iip, specialize, twopoint, typeof(f), typeof(bc), typeof(mass_matrix), + typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp), + typeof(vjp), typeof(jac_prototype), + typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity), + typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), + typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic, tgrad, jac, bcjac, jvp, vjp, - jac_prototype, bcjac_prototype, sparsity, + jac_prototype, bcjac_prototype, bcresid_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, _bccolorvec, sys) end end -function BVPFunction{iip}(f, bc; kwargs...) where {iip} - BVPFunction{iip, FullSpecialize}(f, bc; kwargs...) +function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip} + BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...) end BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f -function BVPFunction(f, bc; kwargs...) - BVPFunction{isinplace(f, 4), FullSpecialize}(f, bc; kwargs...) +function BVPFunction(f, bc; twopoint::Bool=false, kwargs...) + BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...) end BVPFunction(f::BVPFunction; kwargs...) = f +function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} + IntegralFunction{iip, specialize, typeof(f), typeof(integrand_prototype)}(f, + integrand_prototype) +end + +function IntegralFunction{iip}(f, integrand_prototype) where {iip} + return IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) +end +function IntegralFunction(f) + calculated_iip = isinplace(f, 3, "integral", true) + if calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, false)) + end + IntegralFunction{false}(f, nothing) +end +function IntegralFunction(f, integrand_prototype) + calcuated_iip = isinplace(f, 3, "integral", true) + if !calcuated_iip + throw(IntegrandMismatchFunctionError(calcuated_iip, true)) + end + IntegralFunction{true}(f, integrand_prototype) +end + +function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; + max_batch::Integer = typemax(Int)) where {iip, specialize} + BatchIntegralFunction{ + iip, + specialize, + typeof(f), + typeof(integrand_prototype), + }(f, + integrand_prototype, + max_batch) +end + +function BatchIntegralFunction{iip}(f, + integrand_prototype; + kwargs...) where {iip} + return BatchIntegralFunction{iip, FullSpecialize}(f, + integrand_prototype; + kwargs...) +end + +function BatchIntegralFunction(f; kwargs...) + calculated_iip = isinplace(f, 3, "batchintegral", true) + if calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, false)) + end + BatchIntegralFunction{false}(f, nothing; kwargs...) +end +function BatchIntegralFunction(f, integrand_prototype; kwargs...) + calculated_iip = isinplace(f, 3, "batchintegral", true) + if !calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, true)) + end + BatchIntegralFunction{true}(f, integrand_prototype; kwargs...) +end + ########## Existence Functions # Check that field/property exists (may be nothing) @@ -4064,7 +4240,9 @@ for S in [:ODEFunction :NonlinearFunction :IntervalNonlinearFunction :IncrementingODEFunction - :BVPFunction] + :BVPFunction + :IntegralFunction + :BatchIntegralFunction] @eval begin function ConstructionBase.constructorof(::Type{<:$S{iip}}) where { iip, diff --git a/src/solutions/zygote.jl b/src/solutions/zygote.jl index 08090bdbf..d41d07e0f 100644 --- a/src/solutions/zygote.jl +++ b/src/solutions/zygote.jl @@ -1,6 +1,6 @@ @adjoint function getindex(VA::ODESolution, i::Int) function ODESolution_getindex_pullback(Δ) - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] + Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] (Δ′, nothing) end diff --git a/test/downstream/ensemble_bvp.jl b/test/downstream/ensemble_bvp.jl index 4f3b2bc2a..ad08236c7 100644 --- a/test/downstream/ensemble_bvp.jl +++ b/test/downstream/ensemble_bvp.jl @@ -19,4 +19,4 @@ tspan = (0.0, pi / 2) p = [rand()] bvp = BVProblem(ode!, bc!, initial_guess, tspan, p) ensemble_prob = EnsembleProblem(bvp, prob_func = prob_func) -sim = solve(ensemble_prob, GeneralMIRK4(), trajectories = 10, dt = 0.1) +sim = solve(ensemble_prob, MIRK4(), trajectories = 10, dt = 0.1) diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index a150371d8..269985a39 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -453,7 +453,17 @@ NonlinearFunction(nfoop, vjp = nvjp) intf(u) = 1.0 @test_throws SciMLBase.TooFewArgumentsError IntegralProblem(intf, 0.0, 1.0) intf(u, p) = 1.0 +p = 2.0 + IntegralProblem(intf, 0.0, 1.0) +IntegralProblem(intf, 0.0, 1.0, p) +IntegralProblem(intf, [0.0], [1.0]) +IntegralProblem(intf, [0.0], [1.0], p) + +x = [1.0, 2.0] +y = rand(2, 2) +SampledIntegralProblem(y, x) +SampledIntegralProblem(y, x; dim = 2) # Optimization @@ -518,8 +528,14 @@ BVPFunction(bfoop, bcoop, jac = bjac) bjac(du, u, p, t) = [1.0] bcjac(du, u, p, t) = [1.0] BVPFunction(bfiip, bciip, jac = bjac, bcjac = bcjac) -BVPFunction(bfoop, bciip, jac = bjac, bcjac = bcjac) -BVPFunction(bfiip, bcoop, jac = bjac, bcjac = bcjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, + bciip, + jac = bjac, + bcjac = bcjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, + bcoop, + jac = bjac, + bcjac = bcjac) BVPFunction(bfoop, bcoop, jac = bjac, bcjac = bcjac) bWfact(u, t) = [1.0] @@ -530,10 +546,10 @@ bWfact(u, p, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = bWfact) bWfact(u, p, gamma, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, Wfact = bWfact) -BVPFunction(bfoop, bciip, Wfact = bWfact) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, Wfact = bWfact) bWfact(du, u, p, gamma, t) = [1.0] BVPFunction(bfiip, bciip, Wfact = bWfact) -BVPFunction(bfoop, bciip, Wfact = bWfact) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, Wfact = bWfact) bWfact_t(u, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = bWfact_t) @@ -545,20 +561,24 @@ bWfact_t(u, p, gamma, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, Wfact_t = bWfact_t) -BVPFunction(bfoop, bciip, Wfact_t = bWfact_t) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, + bciip, + Wfact_t = bWfact_t) bWfact_t(du, u, p, gamma, t) = [1.0] BVPFunction(bfiip, bciip, Wfact_t = bWfact_t) -BVPFunction(bfoop, bciip, Wfact_t = bWfact_t) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, + bciip, + Wfact_t = bWfact_t) btgrad(u, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, tgrad = btgrad) @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, tgrad = btgrad) btgrad(u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, tgrad = btgrad) -BVPFunction(bfoop, bciip, tgrad = btgrad) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, tgrad = btgrad) btgrad(du, u, p, t) = [1.0] BVPFunction(bfiip, bciip, tgrad = btgrad) -BVPFunction(bfoop, bciip, tgrad = btgrad) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, tgrad = btgrad) bparamjac(u, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, paramjac = bparamjac) @@ -567,27 +587,66 @@ bparamjac(u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, paramjac = bparamjac) -BVPFunction(bfoop, bciip, paramjac = bparamjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, + bciip, + paramjac = bparamjac) bparamjac(du, u, p, t) = [1.0] BVPFunction(bfiip, bciip, paramjac = bparamjac) -BVPFunction(bfoop, bciip, paramjac = bparamjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, + bciip, + paramjac = bparamjac) bjvp(u, p, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, jvp = bjvp) @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, jvp = bjvp) bjvp(u, v, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, jvp = bjvp) -BVPFunction(bfoop, bciip, jvp = bjvp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, jvp = bjvp) bjvp(du, u, v, p, t) = [1.0] BVPFunction(bfiip, bciip, jvp = bjvp) -BVPFunction(bfoop, bciip, jvp = bjvp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, jvp = bjvp) bvjp(u, p, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, vjp = bvjp) @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, vjp = bvjp) bvjp(u, v, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, vjp = bvjp) -BVPFunction(bfoop, bciip, vjp = bvjp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp) bvjp(du, u, v, p, t) = [1.0] BVPFunction(bfiip, bciip, vjp = bvjp) -BVPFunction(bfoop, bciip, vjp = bvjp) + +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp) + +# IntegralFunction + +ioop(u, p) = p * u +iiip(y, u, p) = y .= u * p +i1(u) = u +itoo(y, u, p, a) = y .= u * p + +IntegralFunction(ioop) +IntegralFunction(iiip, Float64[]) + +@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[]) +@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(iiip) +@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1) +@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo) +@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo, Float64[]) + +# BatchIntegralFunction + +boop(u, p) = p .* u +biip(y, u, p) = y .= p .* u +bi1(u) = u +bitoo(y, u, p, a) = y .= p .* u + +BatchIntegralFunction(boop) +BatchIntegralFunction(boop, max_batch = 20) +BatchIntegralFunction(biip, Float64[]) +BatchIntegralFunction(biip, Float64[], max_batch = 20) + +@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[]) +@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip) +@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1) +@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo) +@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo, Float64[])