Skip to content

Commit

Permalink
Merge pull request #636 from SciML/symbolicify
Browse files Browse the repository at this point in the history
[WIP] Keep symbolic expressions from mtk, make it moi compatible later
  • Loading branch information
Vaibhavdixit02 authored Dec 19, 2023
2 parents e4df1c8 + 75d197b commit a99f82c
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 128 deletions.
1 change: 1 addition & 0 deletions lib/OptimizationMOI/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
Expand Down
106 changes: 104 additions & 2 deletions lib/OptimizationMOI/src/OptimizationMOI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ using Reexport
@reexport using Optimization
using MathOptInterface
using Optimization.SciMLBase
using SymbolicIndexingInterface
using SparseArrays
import ModelingToolkit: parameters, states, varmap_to_vars, mergedefaults, toexpr
import ModelingToolkit
using ModelingToolkit: parameters, states, varmap_to_vars, mergedefaults
const MTK = ModelingToolkit
import Symbolics
using Symbolics

const MOI = MathOptInterface

Expand Down Expand Up @@ -116,6 +117,107 @@ function __moi_status_to_ReturnCode(status::MOI.TerminationStatusCode)
end
end

_get_variable_index_from_expr(expr::T) where {T} = throw(MalformedExprException("$expr"))
function _get_variable_index_from_expr(expr::Expr)
_is_var_ref_expr(expr)
return MOI.VariableIndex(expr.args[2])
end

function _is_var_ref_expr(expr::Expr)
expr.head == :ref || throw(MalformedExprException("$expr")) # x[i]
expr.args[1] == :x || throw(MalformedExprException("$expr"))
return true
end

function is_eq(expr::Expr)
expr.head == :call || throw(MalformedExprException("$expr"))
expr.args[1] in [:(==), :(=)]
end

function is_leq(expr::Expr)
expr.head == :call || throw(MalformedExprException("$expr"))
expr.args[1] == :(<=)
end

"""
rep_pars_vals!(expr::T, expr_map)
Replaces variable expressions of the form `:some_variable` or `:(getindex, :some_variable, j)` with
`x[i]` were `i` is the corresponding index in the state vector. Same for the parameters. The
variable/parameter pairs are provided via the `expr_map`.
Expects only expressions where the variables and parameters are of the form `:some_variable`
or `:(getindex, :some_variable, j)` or :(some_variable[j]).
"""
rep_pars_vals!(expr::T, expr_map) where {T} = expr
function rep_pars_vals!(expr::Symbol, expr_map)
for (f, n) in expr_map
isequal(f, expr) && return n
end
return expr
end
function rep_pars_vals!(expr::Expr, expr_map)
if (expr.head == :call && expr.args[1] == getindex) || (expr.head == :ref)
for (f, n) in expr_map
isequal(f, expr) && return n
end
end
Threads.@sync for i in eachindex(expr.args)
i == 1 && expr.head == :call && continue # first arg is the operator
Threads.@spawn expr.args[i] = rep_pars_vals!(expr.args[i], expr_map)
end
return expr
end

"""
symbolify!(e)
Ensures that a given expression is fully symbolic, e.g. no function calls.
"""
symbolify!(e) = e
function symbolify!(e::Expr)
if !(e.args[1] isa Symbol)
e.args[1] = Symbol(e.args[1])
end
symbolify!.(e.args)
return e
end

"""
convert_to_expr(eq, sys; expand_expr = false, pairs_arr = expr_map(sys))
Converts the given symbolic expression to a Julia `Expr` and replaces all symbols, i.e. states and
parameters with `x[i]` and `p[i]`.
# Arguments:
- `eq`: Expression to convert
- `sys`: Reference to the system holding the parameters and states
- `expand_expr=false`: If `true` the symbolic expression is expanded first.
"""
function convert_to_expr(eq, expr_map; expand_expr = false)
if expand_expr
eq = try
Symbolics.expand(eq) # PolyForm sometimes errors
catch e
Symbolics.expand(eq)
end
end
expr = ModelingToolkit.toexpr(eq)

expr = rep_pars_vals!(expr, expr_map)
expr = symbolify!(expr)
return expr
end

function get_expr_map(sys)
dvs = ModelingToolkit.states(sys)
ps = ModelingToolkit.parameters(sys)
return vcat([ModelingToolkit.toexpr(_s) => Expr(:ref, :x, i)
for (i, _s) in enumerate(dvs)],
[ModelingToolkit.toexpr(_p) => Expr(:ref, :p, i)
for (i, _p) in enumerate(ps)])
end

"""
Replaces every expression `:x[i]` with `:x[MOI.VariableIndex(i)]`
"""
Expand Down
136 changes: 16 additions & 120 deletions lib/OptimizationMOI/src/moi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,29 @@ struct MOIOptimizationCache{F <: OptimizationFunction, RC, LB, UB, I, S, EX,
end

function MOIOptimizationCache(prob::OptimizationProblem, opt; kwargs...)
isnothing(prob.f.sys) &&
throw(ArgumentError("Expected an `OptimizationProblem` that was setup via an `OptimizationSystem`, consider `modelingtoolkitize(prob).`"))
f = prob.f
reinit_cache = Optimization.ReInitCache(prob.u0, prob.p)
if isnothing(f.sys)
if f.adtype isa Optimization.AutoModelingToolkit
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
f = Optimization.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons)
else
throw(ArgumentError("Expected an `OptimizationProblem` that was setup via an `OptimizationSystem`, or AutoModelingToolkit ad choice"))
end
end

# TODO: check if the problem is at most bilinear, i.e. affine and or quadratic terms in two variables
expr_map = get_expr_map(prob.f.sys)
expr = repl_getindex!(convert_to_expr(MTK.subs_constants(MTK.objective(prob.f.sys)),
prob.f.sys; expand_expr = false, expr_map))

cons = MTK.constraints(prob.f.sys)
expr = convert_to_expr(f.expr, expr_map; expand_expr = false)
expr = repl_getindex!(expr)
cons = MTK.constraints(f.sys)
cons_expr = Vector{Expr}(undef, length(cons))
Threads.@sync for i in eachindex(cons)
Threads.@spawn cons_expr[i] = repl_getindex!(convert_to_expr(Symbolics.canonical_form(MTK.subs_constants(cons[i])),
prob.f.sys;
expand_expr = false,
expr_map))
Threads.@spawn cons_expr[i] = repl_getindex!(convert_to_expr(f.cons_expr[i], expr_map; expand_expr = false))
end

return MOIOptimizationCache(prob.f,
Optimization.ReInitCache(prob.u0, prob.p),
return MOIOptimizationCache(f,
reinit_cache,
prob.lb,
prob.ub,
prob.int,
Expand Down Expand Up @@ -352,111 +356,3 @@ function collect_moi_terms!(expr::Expr, affine_terms, quadratic_terms, constant)

return
end

_get_variable_index_from_expr(expr::T) where {T} = throw(MalformedExprException("$expr"))
function _get_variable_index_from_expr(expr::Expr)
_is_var_ref_expr(expr)
return MOI.VariableIndex(expr.args[2])
end

function _is_var_ref_expr(expr::Expr)
expr.head == :ref || throw(MalformedExprException("$expr")) # x[i]
expr.args[1] == :x || throw(MalformedExprException("$expr"))
return true
end

function is_eq(expr::Expr)
expr.head == :call || throw(MalformedExprException("$expr"))
expr.args[1] in [:(==), :(=)]
end

function is_leq(expr::Expr)
expr.head == :call || throw(MalformedExprException("$expr"))
expr.args[1] == :(<=)
end

##############################################################################################
## TODO: remove if in ModelingToolkit
"""
rep_pars_vals!(expr::T, expr_map)
Replaces variable expressions of the form `:some_variable` or `:(getindex, :some_variable, j)` with
`x[i]` were `i` is the corresponding index in the state vector. Same for the parameters. The
variable/parameter pairs are provided via the `expr_map`.
Expects only expressions where the variables and parameters are of the form `:some_variable`
or `:(getindex, :some_variable, j)` or :(some_variable[j]).
"""
rep_pars_vals!(expr::T, expr_map) where {T} = expr
function rep_pars_vals!(expr::Symbol, expr_map)
for (f, n) in expr_map
isequal(f, expr) && return n
end
return expr
end
function rep_pars_vals!(expr::Expr, expr_map)
if (expr.head == :call && expr.args[1] == getindex) || (expr.head == :ref)
for (f, n) in expr_map
isequal(f, expr) && return n
end
end
Threads.@sync for i in eachindex(expr.args)
i == 1 && expr.head == :call && continue # first arg is the operator
Threads.@spawn expr.args[i] = rep_pars_vals!(expr.args[i], expr_map)
end
return expr
end

"""
symbolify!(e)
Ensures that a given expression is fully symbolic, e.g. no function calls.
"""
symbolify!(e) = e
function symbolify!(e::Expr)
if !(e.args[1] isa Symbol)
e.args[1] = Symbol(e.args[1])
end
symbolify!.(e.args)
return e
end

"""
get_expr_map(sys)
Make a map from every parameter and state of the given system to an expression indexing its position
in the state or parameter vector.
"""
function get_expr_map(sys)
dvs = ModelingToolkit.states(sys)
ps = ModelingToolkit.parameters(sys)
return vcat([ModelingToolkit.toexpr(_s) => Expr(:ref, :x, i)
for (i, _s) in enumerate(dvs)],
[ModelingToolkit.toexpr(_p) => Expr(:ref, :p, i)
for (i, _p) in enumerate(ps)])
end

"""
convert_to_expr(eq, sys; expand_expr = false, pairs_arr = expr_map(sys))
Converts the given symbolic expression to a Julia `Expr` and replaces all symbols, i.e. states and
parameters with `x[i]` and `p[i]`.
# Arguments:
- `eq`: Expression to convert
- `sys`: Reference to the system holding the parameters and states
- `expand_expr=false`: If `true` the symbolic expression is expanded first.
"""
function convert_to_expr(eq, sys; expand_expr = false, expr_map = get_expr_map(sys))
if expand_expr
eq = try
Symbolics.expand(eq) # PolyForm sometimes errors
catch e
Symbolics.expand(eq)
end
end
expr = ModelingToolkit.toexpr(eq)
expr = rep_pars_vals!(expr, expr_map)
expr = symbolify!(expr)
return expr
end
47 changes: 41 additions & 6 deletions lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ mutable struct MOIOptimizationNLPEvaluator{T, F <: OptimizationFunction, RC, LB,
H::HT
cons_H::Vector{CHT}
callback::CB
obj_expr::Union{Expr, Nothing}
cons_expr::Union{Vector{Expr}, Nothing}
end

function Base.getproperty(evaluator::MOIOptimizationNLPEvaluator, x::Symbol)
Expand Down Expand Up @@ -136,6 +138,37 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem,
lcons = prob.lcons === nothing ? fill(T(-Inf), num_cons) : prob.lcons
ucons = prob.ucons === nothing ? fill(T(Inf), num_cons) : prob.ucons

if f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}
sys = MTK.modelingtoolkitize(prob)
if !isnothing(prob.p) && !(prob.p isa SciMLBase.NullParameters)
unames = variable_symbols(sys)
pnames = parameter_symbols(sys)
us = [unames[i] => prob.u0[i] for i in 1:length(prob.u0)]
ps = [pnames[i] => prob.p[i] for i in 1:length(prob.p)]
sysprob = OptimizationProblem(sys, us, ps)
else
unames = variable_symbols(sys)
us = [unames[i] => prob.u0[i] for i in 1:length(prob.u0)]
sysprob = OptimizationProblem(sys, us)
end

obj_expr = sysprob.f.expr
cons_expr = sysprob.f.cons_expr
else
sys = f.sys
obj_expr = f.expr
cons_expr = f.cons_expr
end

expr_map = get_expr_map(sys)
expr = convert_to_expr(obj_expr, expr_map; expand_expr = false)
expr = repl_getindex!(expr)
cons = MTK.constraints(sys)
_cons_expr = Vector{Expr}(undef, length(cons))
for i in eachindex(cons)
_cons_expr[i] = repl_getindex!(convert_to_expr(cons_expr[i], expr_map; expand_expr = false))
end

evaluator = MOIOptimizationNLPEvaluator(f,
reinit_cache,
prob.lb,
Expand All @@ -147,7 +180,9 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem,
J,
H,
cons_H,
callback)
callback,
expr,
_cons_expr)
return MOIOptimizationNLPCache(evaluator, opt, NamedTuple(kwargs))
end

Expand Down Expand Up @@ -334,21 +369,21 @@ function MOI.eval_hessian_lagrangian(evaluator::MOIOptimizationNLPEvaluator{T},
end

function MOI.objective_expr(evaluator::MOIOptimizationNLPEvaluator)
expr = deepcopy(evaluator.f.expr)
expr = deepcopy(evaluator.obj_expr)
repl_getindex!(expr)
_replace_parameter_indices!(expr, evaluator.p)
_replace_variable_indices!(expr)
return expr
end

function MOI.constraint_expr(evaluator::MOIOptimizationNLPEvaluator, i)
# expr has the form f(x,p) == 0 or f(x,p) <= 0
cons_expr = deepcopy(evaluator.f.cons_expr[i].args[2])
# expr has the form f(x,p) == 0 or f(x,p) <= 0
cons_expr = deepcopy(evaluator.cons_expr[i].args[2])
compop = Symbol(evaluator.cons_expr[i].args[1])
repl_getindex!(cons_expr)
_replace_parameter_indices!(cons_expr, evaluator.p)
_replace_variable_indices!(cons_expr)
lb, ub = Float64(evaluator.lcons[i]), Float64(evaluator.ucons[i])
return :($lb <= $cons_expr <= $ub)
return Expr(:call, compop, cons_expr, 0.0)
end

function _add_moi_variables!(opt_setup, evaluator::MOIOptimizationNLPEvaluator)
Expand Down

0 comments on commit a99f82c

Please sign in to comment.