Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a nlls trait to BVProblem #567

Merged
merged 7 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ $(TYPEDEF)

Base for types which define BVP problems.
"""
abstract type AbstractBVProblem{uType, tType, isinplace} <:
abstract type AbstractBVProblem{uType, tType, isinplace, nlls} <:
AbstractODEProblem{uType, tType, isinplace} end

"""
Expand Down
66 changes: 56 additions & 10 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@doc doc"""

Defines an BVP problem.
Documentation Page: https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/
Documentation Page: [https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/](https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/)

## Mathematical Specification of a BVP Problem

Expand Down Expand Up @@ -41,16 +41,16 @@
### Constructors

```julia
TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f, bc, u0, tspan, p=NullParameters(); kwargs...)
BVProblem{isinplace}(f, bc, u0, tspan, p=NullParameters(); kwargs...)
```

or if we have an initial guess function `initialGuess(t)` for the given BVP,
or if we have an initial guess function `initialGuess(p, t)` for the given BVP,
we can pass the initial guess to the problem constructors:

```julia
TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f, bc, initialGuess, tspan, p=NullParameters(); kwargs...)
BVProblem{isinplace}(f, bc, initialGuess, tspan, p=NullParameters(); kwargs...)
```

For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be
Expand Down Expand Up @@ -104,9 +104,17 @@
* `tspan`: The timespan for the problem.
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.

### Special Keyword Arguments

- `nlls`: Specify that the BVP is a nonlinear least squares problem. Use `Val(true)` or
`Val(false)` for type stability. By default this is automatically inferred based on the
size of the input and outputs, however this is type unstable for any array type that
doesn't store array size as part of type information. If we can't reliably infer this,
we set it to `Nothing`. Downstreams solvers must be setup to deal with this case.
"""
struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace, nlls}
f::F
u0::uType
tspan::tType
Expand All @@ -115,18 +123,56 @@
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type = nothing, kwargs...) where {iip, TP}
p = NullParameters(); problem_type = nothing, nlls = nothing,
kwargs...) where {iip, TP}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem()

# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
else
@assert prob_type===problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f),

if nlls === nothing
if !hasmethod(length, Tuple{typeof(_u0)})
# If _u0 is a function for initial guess we won't be able to infer
__u0 = _u0 isa Function ?

Check warning on line 143 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L143

Added line #L143 was not covered by tests
(hasmethod(_u0, Tuple{typeof(p), typeof(first(_tspan))}) ?
_u0(p, first(_tspan)) : _u0(first(_tspan))) : nothing
else
__u0 = _u0
end
# Try to infer it
if __u0 isa Nothing
_nlls = Nothing

Check warning on line 151 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L151

Added line #L151 was not covered by tests
elseif problem_type isa TwoPointBVProblem
if f.bcresid_prototype !== nothing
l1, l2 = length(f.bcresid_prototype[1]), length(f.bcresid_prototype[2])
_nlls = l1 + l2 != length(__u0)
else
# iip without bcresid_prototype is not possible
if !iip
l1 = length(f.bc[1](u0, p))
l2 = length(f.bc[2](u0, p))
_nlls = l1 + l2 != length(__u0)

Check warning on line 161 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L158-L161

Added lines #L158 - L161 were not covered by tests
end
end
else
if f.bcresid_prototype !== nothing
_nlls = length(f.bcresid_prototype) != length(__u0)
else
_nlls = Nothing # Cannot reliably infer
end
end
else
_nlls = _unwrap_val(nlls)
end

return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
end

Expand Down
29 changes: 16 additions & 13 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...)
function remake(prob::AbstractSciMLProblem; u0 = missing,

Check warning on line 48 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L48

Added line #L48 was not covered by tests
p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., u0, p)
end
Expand All @@ -54,7 +55,8 @@
_remake_internal(prob; kwargs...)
end

function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
function remake(

Check warning on line 58 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L58

Added line #L58 was not covered by tests
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
p = updated_p(prob, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., p)
end
Expand Down Expand Up @@ -128,16 +130,15 @@

Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...)
function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing,
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
interpret_symbolicmap = true, _kwargs...) where {uType, tType, iip, nlls}
if tspan === missing
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

if problem_type === missing
problem_type = prob.problem_type
end
Expand Down Expand Up @@ -170,9 +171,11 @@
end

if kwargs === missing
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...)
BVProblem{iip}(
_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), prob.kwargs...,
_kwargs...)
else
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...)
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), kwargs...)

Check warning on line 178 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L178

Added line #L178 was not covered by tests
end
end

Expand Down Expand Up @@ -254,7 +257,6 @@
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
Expand Down Expand Up @@ -393,10 +395,11 @@
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
has_sys(prob.f) ||
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p
end
Expand Down
64 changes: 41 additions & 23 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1929,35 +1929,46 @@
A representation of a BVP function `f`, defined by:

```math
\frac{du}{dt}=f(u,p,t)
\frac{du}{dt} = f(u, p, t)
```

and the constraints:

```math
\frac{du}{dt}=g(u,p,t)
g(u, p, t) = 0
```

If the size of `g(u, p, t)` is different from the size of `u`, then the constraints are
interpreted as a least squares problem, i.e. the objective function is:

```math
\min_{u} \| g_i(u, p, t) \|^2
```

and all of its related functions, such as the Jacobian of `f`, its gradient
with respect to time, and more. For all cases, `u0` is the initial condition,
`p` are the parameters, and `t` is the independent variable.

```julia
BVPFunction{iip,specialize}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
bcjac = __has_jac(bc) ? bc.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing)
BVPFunction{iip, specialize}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
bcjac = __has_jac(bc) ? bc.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = nothing,
indepsym= nothing,
paramsyms = nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
twopoint::Union{Val, Bool} = Val(false)
```

Note that both the function `f` and boundary condition `bc` are required. `f` should
Expand Down Expand Up @@ -1985,7 +1996,7 @@
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
- `bcjac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
as the prototype and integrators will specialize on this structure where possible. Non-structured
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
Expand All @@ -2003,6 +2014,11 @@
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.

Additional Options:

- `twopoint`: Specify that the BVP is a two-point boundary value problem. Use `Val(true)` or
`Val(false)` for type stability.

## iip: In-Place vs Out-Of-Place

For more details on this argument, see the ODEFunction documentation.
Expand Down Expand Up @@ -3801,7 +3817,7 @@

_f = prepare_function(f)

sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))

if specialize === NoSpecialize
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
Expand All @@ -3813,9 +3829,9 @@
sparsity, Wfact, Wfact_t, paramjac, observed,
_colorvec, _bccolorvec, sys)
else
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(mass_matrix),
typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp),
typeof(vjp), typeof(jac_prototype),
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc),
typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac),
typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(
Expand Down Expand Up @@ -3897,7 +3913,9 @@
function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
if sys === nothing &&
(syms !== nothing || paramsyms !== nothing || indepsym !== nothing)
Base.depwarn("The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.", :syms)
Base.depwarn(

Check warning on line 3916 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L3916

Added line #L3916 was not covered by tests
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.",
:syms)
sys = SymbolCache(syms, paramsyms, indepsym)
end
return sys
Expand Down
Loading
Loading