Skip to content

Commit

Permalink
Improves OptimizationSystem (#1787)
Browse files Browse the repository at this point in the history
* improves nested system handling for OptimizationSystem

Co-authored-by: Yingbo Ma <[email protected]>
  • Loading branch information
ValentinKaisermayer and YingboMa authored Aug 31, 2022
1 parent a457b3f commit d8e4810
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Format check
run: |
julia -e '
out = Cmd(`git diff --name-only`) |> read |> String
out = Cmd(`git diff`) |> read |> String
if out == ""
exit(0)
else
Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ export JumpSystem
export ODEProblem, SDEProblem
export NonlinearFunction, NonlinearFunctionExpr
export NonlinearProblem, BlockNonlinearProblem, NonlinearProblemExpr
export OptimizationProblem, OptimizationProblemExpr
export OptimizationProblem, OptimizationProblemExpr, constraints
export AutoModelingToolkit
export SteadyStateProblem, SteadyStateProblemExpr
export JumpProblem, DiscreteProblem
Expand Down
6 changes: 3 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ for prop in [:eqs
:systems
:structure
:op
:equality_constraints
:inequality_constraints
:constraints
:controls
:loss
:bcs
Expand Down Expand Up @@ -1227,7 +1226,8 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
(; A, B, C, D)
end

function linearize(sys, inputs, outputs; op = Dict(), t = 0.0, allow_input_derivatives = false,
function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
allow_input_derivatives = false,
kwargs...)
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
Expand Down
79 changes: 56 additions & 23 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ $(FIELDS)
```julia
@variables x y z
@parameters σ ρ β
@parameters a b c
op = σ*(y-x) + x*(ρ-z)-y + x*y - β*z
@named os = OptimizationSystem(op, [x,y,z],[σ,ρ,β])
op = a*(y-x) + x*(b-z)-y + x*y - c*z
@named os = OptimizationSystem(op, [x,y,z], [a,b,c])
```
"""
struct OptimizationSystem <: AbstractTimeIndependentSystem
"""Vector of equations defining the system."""
"""Objective function of the system."""
op::Any
"""Unknown variables."""
states::Vector
Expand All @@ -26,18 +26,15 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
"""Array variables."""
var_to_name::Any
observed::Vector{Equation}
constraints::Vector
"""
Name: the name of the system. These are required to have unique names.
"""
"""List of constraint equations of the system."""
constraints::Vector # {Union{Equation,Inequality}}
"""The unique name of the system."""
name::Symbol
"""
systems: The internal systems
"""
"""The internal systems."""
systems::Vector{OptimizationSystem}
"""
defaults: The default values to use when initial conditions and/or
parameters are not supplied in `ODEProblem`.
The default values to use when initial guess and/or
parameters are not supplied in `OptimizationProblem`.
"""
defaults::Dict
"""
Expand All @@ -48,7 +45,7 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
constraints, name, systems, defaults, metadata = nothing;
checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckUnits) > 0
check_units(op)
unwrap(op) isa Symbolic && check_units(op)
check_units(observed)
all_dimensionless([states; ps]) || check_units(constraints)
end
Expand All @@ -69,6 +66,11 @@ function OptimizationSystem(op, states, ps;
metadata = nothing)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))

constraints = value.(scalarize(constraints))
states′ = value.(scalarize(states))
ps′ = value.(scalarize(ps))

if !(isempty(default_u0) && isempty(default_p))
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:OptimizationSystem, force = true)
Expand All @@ -80,12 +82,12 @@ function OptimizationSystem(op, states, ps;
defaults = todict(defaults)
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))

states, ps = value.(states), value.(ps)
var_to_name = Dict()
process_variables!(var_to_name, defaults, states)
process_variables!(var_to_name, defaults, ps)
process_variables!(var_to_name, defaults, states)
process_variables!(var_to_name, defaults, ps)
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
OptimizationSystem(value(op), states, ps, var_to_name,

OptimizationSystem(value(op), states′, ps′, var_to_name,
observed,
constraints,
name, systems, defaults, metadata; checks = checks)
Expand Down Expand Up @@ -124,10 +126,38 @@ function generate_function(sys::OptimizationSystem, vs = states(sys), ps = param
end

function equations(sys::OptimizationSystem)
isempty(get_systems(sys)) ? get_op(sys) :
get_op(sys) + reduce(+, namespace_expr.(get_systems(sys)))
op = get_op(sys)
systems = get_systems(sys)
if isempty(systems)
op
else
op + reduce(+, map(sys_ -> namespace_expr(get_op(sys_), sys_), systems))
end
end

namespace_constraint(eq::Equation, sys) = namespace_equation(eq, sys)

# namespace_constraint(ineq::Inequality, sys) = namespace_inequality(ineq, sys)

# function namespace_inequality(ineq::Inequality, sys, n = nameof(sys))
# _lhs = namespace_expr(ineq.lhs, sys, n)
# _rhs = namespace_expr(ineq.rhs, sys, n)
# Inequality(
# namespace_expr(_lhs, sys, n),
# namespace_expr(_rhs, sys, n),
# ineq.relational_op,
# )
# end

function namespace_constraints(sys::OptimizationSystem)
namespace_constraint.(get_constraints(sys), Ref(sys))
end

function constraints(sys::OptimizationSystem)
cs = get_constraints(sys)
systems = get_systems(sys)
isempty(systems) ? cs : [cs; reduce(vcat, namespace_constraints.(systems))]
end
namespace_expr(sys::OptimizationSystem) = namespace_expr(get_op(sys), sys)

hessian_sparsity(sys::OptimizationSystem) = hessian_sparsity(get_op(sys), states(sys))

Expand Down Expand Up @@ -168,6 +198,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
kwargs...) where {iip}
dvs = states(sys)
ps = parameters(sys)
cstr = constraints(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
Expand Down Expand Up @@ -216,8 +247,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

if length(sys.constraints) > 0
@named cons_sys = NonlinearSystem(sys.constraints, dvs, ps)
if length(cstr) > 0
@named cons_sys = NonlinearSystem(cstr, dvs, ps)
cons = generate_function(cons_sys, checkbounds = checkbounds,
linenumbers = linenumbers,
expression = Val{false})[2]
Expand All @@ -237,6 +268,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,

_f = DiffEqBase.OptimizationFunction{iip}(f,
sys = sys,
syms = nameof.(states(sys)),
SciMLBase.NoAD();
grad = _grad,
hess = _hess,
Expand All @@ -251,6 +283,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
else
_f = DiffEqBase.OptimizationFunction{iip}(f,
sys = sys,
syms = nameof.(states(sys)),
SciMLBase.NoAD();
grad = _grad,
hess = _hess,
Expand Down
22 changes: 22 additions & 0 deletions test/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,25 @@ OBS2 = OBS
@test isequal(OBS2, @nonamespace sys2.OBS)
@unpack OBS = sys2
@test isequal(OBS2, OBS)

# nested constraints
@testset "nested systems" begin
@variables x y
o1 = (x - 1)^2
o2 = (y - 1 / 2)^2
c1 = [
x ~ 1,
]
c2 = [
y ~ 1,
]
sys1 = OptimizationSystem(o1, [x], [], name = :sys1, constraints = c1)
sys2 = OptimizationSystem(o2, [y], [], name = :sys2, constraints = c2)
sys = OptimizationSystem(0, [], []; name = :sys, systems = [sys1, sys2],
constraints = [sys1.x + sys2.y ~ 2], checks = false)
prob = OptimizationProblem(sys, [0.0, 0.0])

@test isequal(constraints(sys), vcat(sys1.x + sys2.y ~ 2, sys1.x ~ 1, sys2.y ~ 1))
@test isequal(equations(sys), (sys1.x - 1)^2 + (sys2.y - 1 / 2)^2)
@test isequal(states(sys), [sys1.x, sys2.y])
end

0 comments on commit d8e4810

Please sign in to comment.