Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BVProblem with constraints #3323

Merged
merged 60 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
9733460
init
vyudu Nov 22, 2024
d95e4a7
Merge remote-tracking branch 'origin/master' into MTK
vyudu Nov 22, 2024
b3da813
up
vyudu Dec 1, 2024
86c82ce
Merge remote-tracking branch 'origin/master' into MTK
vyudu Dec 1, 2024
4affeac
up
vyudu Dec 1, 2024
a3429ea
up
vyudu Dec 1, 2024
f751fbb
up
vyudu Dec 2, 2024
a9fdfd6
up
vyudu Dec 3, 2024
a9f2106
up
vyudu Dec 4, 2024
18fdd5f
up
vyudu Dec 13, 2024
9d65a33
fixing create_array
vyudu Dec 16, 2024
999ec30
revert Project.toml
vyudu Dec 16, 2024
9226ad6
Up
vyudu Dec 16, 2024
0cb4893
Merge remote-tracking branch 'origin/master' into MTK
vyudu Dec 16, 2024
67d8164
formatting
vyudu Dec 16, 2024
25988f3
up
vyudu Dec 17, 2024
bb28d4f
up
vyudu Dec 17, 2024
b2bf7c0
fix
vyudu Dec 17, 2024
3751c2a
up
vyudu Dec 20, 2024
ef1f089
up
vyudu Jan 8, 2025
d23d6f7
Merge remote-tracking branch 'origin/master' into MTK
vyudu Jan 8, 2025
2a25200
extend BVProblem for constraint equations
vyudu Jan 9, 2025
50504ab
adding tests
vyudu Jan 11, 2025
5d082ab
up
vyudu Jan 11, 2025
b83e003
refactor the bc creation function
vyudu Jan 14, 2025
db5eb66
up
vyudu Jan 14, 2025
e802946
test update
vyudu Jan 15, 2025
e74e047
fix
vyudu Jan 15, 2025
86d4144
test more solvers:
vyudu Jan 17, 2025
ec386fe
Refactor constraints
vyudu Jan 28, 2025
90ce80d
refactor tests
vyudu Jan 28, 2025
a15c670
fix sym validation
vyudu Jan 28, 2025
c6ef04a
remove file
vyudu Jan 28, 2025
7878225
up
vyudu Jan 28, 2025
5bcfdff
up
vyudu Jan 28, 2025
0493b5d
remove lines
vyudu Jan 28, 2025
1d32b6e
up
vyudu Jan 28, 2025
2b3ca96
up
vyudu Jan 28, 2025
0324522
fix typo
vyudu Jan 28, 2025
2a079be
Fix setter
vyudu Jan 28, 2025
d70a470
fix
vyudu Jan 28, 2025
37092f1
lower tol
vyudu Jan 29, 2025
e5eb8bd
fix Project.toml
vyudu Jan 29, 2025
2ae79ae
revert to OrdinaryDiffEq
vyudu Jan 30, 2025
8ae2803
merge master
vyudu Feb 3, 2025
13a242c
update to use updated codegen
vyudu Feb 3, 2025
2fcb9c9
up
vyudu Feb 3, 2025
25b56d7
working codegen
vyudu Feb 4, 2025
c35b797
revert to OrdinaryDiffEqDefault
vyudu Feb 4, 2025
25e84db
use MIRK
vyudu Feb 4, 2025
e6a6932
up
vyudu Feb 4, 2025
5e5c24c
revert to OrdinaryDiffEq
vyudu Feb 4, 2025
5338d4f
tests passing
vyudu Feb 4, 2025
810d4fa
remove problematic tests, codegen assumes MTKParameters
vyudu Feb 4, 2025
6740b8c
test fix
vyudu Feb 4, 2025
603c894
Update src/systems/diffeqs/odesystem.jl
ChrisRackauckas Feb 10, 2025
b10a4a6
Merge branch 'master' into BVP-with-constraints
ChrisRackauckas Feb 10, 2025
9b492cd
Merge remote-tracking branch 'vyudu/BVP-with-constraints' into BVP-wi…
vyudu Feb 11, 2025
dce19c7
Merge remote-tracking branch 'origin' into BVP-with-constraints
vyudu Feb 17, 2025
3642e1b
Merge branch 'master' into BVP-with-constraints
vyudu Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEqAscher = "1.1.0"
BoundaryValueDiffEqMIRK = "1.4.0"
ChainRulesCore = "1"
Combinatorics = "1"
CommonSolve = "0.2.4"
Expand Down Expand Up @@ -157,6 +159,8 @@ julia = "1.9"
[extras]
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand Down Expand Up @@ -189,4 +193,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging"]
8 changes: 4 additions & 4 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ include("systems/codegen_utils.jl")
include("systems/problem_utils.jl")
include("linearization.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")

include("systems/nonlinear/nonlinearsystem.jl")
include("systems/nonlinear/homotopy_continuation.jl")
include("systems/diffeqs/odesystem.jl")
Expand All @@ -170,10 +174,6 @@ include("systems/discrete_system/discrete_system.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")

include("systems/pde/pdesystem.jl")

include("systems/sparsematrixclil.jl")
Expand Down
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@
"""
Initial(x)

The `Initial` operator. Used by initializaton to store constant constraints on variables

Check warning on line 625 in src/systems/abstractsystem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"initializaton" should be "initialization".
of a system. See the documentation section on initialization for more information.
"""
struct Initial <: Symbolics.Operator end
Expand Down Expand Up @@ -823,6 +823,7 @@
:structure
:op
:constraints
:constraintsystem
:controls
:loss
:bcs
Expand Down
164 changes: 164 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end

if !isnothing(get_constraintsystem(sys))
error("An ODESystem with constraints cannot be used to construct a regular ODEProblem.
Consider a BVProblem instead.")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
Expand All @@ -757,6 +763,164 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = nothing,
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```

Create a boundary value problem from the [`ODESystem`](@ref).

`u0map` is used to specify fixed initial values for the states. Every variable
must have either an initial guess supplied using `guesses` or a fixed initial
value specified using `u0map`.

Boundary value conditions are supplied to ODESystems
in the form of a ConstraintsSystem. These equations
should specify values that state variables should
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
specified as one of the equations used to build the `ODESystem`.

If an ODESystem without `constraints` is specified, it will be treated as an initial value problem.

```julia
@parameters g t_c = 0.5
@variables x(..) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x(t))) ~ λ * x(t)
D(D(y)) ~ λ * y - g
x(t)^2 + y^2 ~ 1]
cstr = [x(0.5) ~ 1]
@named cstrs = ConstraintsSystem(cstr, t)
@mtkbuild pend = ODESystem(eqs, t)

tspan = (0.0, 1.5)
u0map = [x(t) => 0.6, y => 0.8]
parammap = [g => 1]
guesses = [λ => 1]
constraints = [x(0.5) ~ 1]

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
```

If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
guesses = Dict(),
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

sts = unknowns(sys)
ps = parameters(sys)
constraintsys = get_constraintsystem(sys)

if !isnothing(constraintsys)
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
end

# ODESystems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan, guesses,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]

fns = generate_function_bc(sys, u0, u0_idxs, tspan)
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
bc(sol, p, t) = bc_oop(sol, p, t)
bc(resid, u, p, t) = bc_iip(resid, u, p, t)

return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

"""
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan)

Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
Expression uses the constraints and the provided initial conditions.
"""
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
iv = get_iv(sys)
sts = unknowns(sys)
ps = parameters(sys)
np = length(ps)
ns = length(sts)
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])

@variables sol(..)[1:ns]

conssys = get_constraintsystem(sys)
cons = Any[]
if !isnothing(conssys)
cons = [con.lhs - con.rhs for con in constraints(conssys)]

for st in get_unknowns(conssys)
x = operation(st)
t = only(arguments(st))
idx = stidxmap[x(iv)]

cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
end
end

init_conds = Any[]
for i in u0_idxs
expr = sol(tspan[1])[i] - u0[i]
push!(init_conds, expr)
end

exprs = vcat(init_conds, cons)
_p = reorder_parameters(sys, ps)

build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
end

"""
```julia
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
Expand Down
77 changes: 71 additions & 6 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ struct ODESystem <: AbstractODESystem
ctrls::Vector
"""Observed variables."""
observed::Vector{Equation}
"""System of constraints that must be satisfied by the solution to the system."""
constraintsystem::Union{Nothing, ConstraintsSystem}
"""
Time-derivative matrix. Note: this field will not be defined until
[`calculate_tgrad`](@ref) is called on the system.
Expand Down Expand Up @@ -191,7 +193,7 @@ struct ODESystem <: AbstractODESystem
"""
parent::Any

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
torn_matching, initializesystem, initialization_eqs, schedule,
connector_type, preface, cevents,
Expand All @@ -212,7 +214,7 @@ struct ODESystem <: AbstractODESystem
u = __get_unit_type(dvs, ps, iv)
check_units(u, deqs)
end
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, assertions, metadata,
Expand All @@ -224,6 +226,7 @@ end
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Equation[],
constraintsystem = nothing,
systems = ODESystem[],
tspan = nothing,
name = nothing,
Expand Down Expand Up @@ -297,17 +300,29 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end

if !isempty(systems) && !isnothing(constraintsystem)
conssystems = ConstraintsSystem[]
for sys in systems
cons = get_constraintsystem(sys)
cons !== nothing && push!(conssystems, cons)
end
@show conssystems
@set! constraintsystem.systems = conssystems
end

assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)

ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, nothing, initializesystem,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies, assertions,
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; kwargs...)
function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
Expand Down Expand Up @@ -339,8 +354,22 @@ function ODESystem(eqs, iv; kwargs...)
end
algevars = setdiff(allunknowns, diffvars)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
collect(new_ps); kwargs...)
consvars = OrderedSet()
constraintsystem = nothing
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
for st in get_unknowns(constraintsystem)
iscall(st) ?
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
!in(st, allunknowns) && push!(consvars, st)
end
for p in parameters(constraintsystem)
!in(p, new_ps) && push!(new_ps, p)
end
end

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
collect(new_ps); constraintsystem, kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand Down Expand Up @@ -668,3 +697,39 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,

return nothing
end

# Validate that all the variables in the BVP constraints are well-formed states or parameters.
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
isempty(constraints) && return nothing

constraintsts = OrderedSet()
constraintps = OrderedSet()

for cons in constraints
collect_vars!(constraintsts, constraintps, cons, iv)
end

# Validate the states.
for var in constraintsts
if !iscall(var)
occursin(iv, var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
elseif length(arguments(var)) > 1
throw(ArgumentError("Too many arguments for variable $var."))
elseif length(arguments(var)) == 1
arg = only(arguments(var))
operation(var)(iv) ∈ sts ||
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))

isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat ||
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

isparameter(arg) && push!(constraintps, arg)
else
var ∈ sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
end
end

ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
end
2 changes: 1 addition & 1 deletion src/systems/optimization/constraints_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ function ConstraintsSystem(constraints, unknowns, ps;
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))

cstr = value.(Symbolics.canonical_form.(scalarize(constraints)))
cstr = value.(Symbolics.canonical_form.(vcat(scalarize(constraints)...)))
unknowns′ = value.(scalarize(unknowns))
ps′ = value.(ps)

Expand Down
Loading
Loading