Skip to content

Commit

Permalink
Merge pull request #2323 from SciML/optsysfix
Browse files Browse the repository at this point in the history
Keep symbolic expressions as is and minor bugfixes optimization system
  • Loading branch information
Vaibhavdixit02 authored Dec 14, 2023
2 parents c10d206 + 9699028 commit ab2c452
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/systems/optimization/modelingtoolkitize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; kwargs...)
if !isnothing(prob.lcons)
for i in 1:num_cons
if !isinf(prob.lcons[i])
if prob.lcons[i] != prob.ucons[i] &&
if prob.lcons[i] != prob.ucons[i]
push!(cons, prob.lcons[i] lhs[i])
else
push!(cons, lhs[i] ~ prob.ucons[i])
Expand Down
24 changes: 5 additions & 19 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,6 @@ end

hessian_sparsity(sys::OptimizationSystem) = hessian_sparsity(get_op(sys), states(sys))

function rep_pars_vals!(e::Expr, p)
rep_pars_vals!.(e.args, Ref(p))
replace!(e.args, p...)
end

function rep_pars_vals!(e, p) end

"""
```julia
DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
Expand Down Expand Up @@ -275,14 +268,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
f = generate_function(sys, checkbounds = checkbounds, linenumbers = linenumbers,
expression = Val{false})

obj_expr = toexpr(subs_constants(objective(sys)))
pairs_arr = if p isa SciMLBase.NullParameters
[Symbol(_s) => Expr(:ref, :x, i) for (i, _s) in enumerate(dvs)]
else
vcat([Symbol(_s) => Expr(:ref, :x, i) for (i, _s) in enumerate(dvs)],
[Symbol(_p) => p[i] for (i, _p) in enumerate(ps)])
end
rep_pars_vals!(obj_expr, pairs_arr)
obj_expr = subs_constants(objective(sys))

if grad
grad_oop, grad_iip = generate_gradient(sys, checkbounds = checkbounds,
linenumbers = linenumbers,
Expand Down Expand Up @@ -342,14 +329,13 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
else
_cons_h = nothing
end
cons_expr = toexpr.(subs_constants(constraints(cons_sys)))
rep_pars_vals!.(cons_expr, Ref(pairs_arr))
cons_expr = subs_constants(constraints(cons_sys))

if !haskey(kwargs, :lcons) && !haskey(kwargs, :ucons) # use the symbolically specified bounds
lcons = lcons_
ucons = ucons_
else # use the user supplied constraints bounds
haskey(kwargs, :lcons) && haskey(kwargs, :ucons) &&
(haskey(kwargs, :lcons) haskey(kwargs, :ucons)) &&
throw(ArgumentError("Expected both `ucons` and `lcons` to be supplied"))
haskey(kwargs, :lcons) && length(kwargs[:lcons]) != length(cstr) &&
throw(ArgumentError("Expected `lcons` to be of the same length as the vector of constraints"))
Expand Down Expand Up @@ -527,7 +513,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0map,
lcons = lcons_
ucons = ucons_
else # use the user supplied constraints bounds
!haskey(kwargs, :lcons) && !haskey(kwargs, :ucons) &&
(haskey(kwargs, :lcons) haskey(kwargs, :ucons)) &&
throw(ArgumentError("Expected both `ucons` and `lcons` to be supplied"))
haskey(kwargs, :lcons) && length(kwargs[:lcons]) != length(cstr) &&
throw(ArgumentError("Expected `lcons` to be of the same length as the vector of constraints"))
Expand Down
13 changes: 13 additions & 0 deletions test/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,16 @@ end

@test sol1.u sol2.u
end

@testset "#2323 keep symbolic expressions and xor condition on constraint bounds" begin
@variables x y
@parameters a b
loss = (a - x)^2 + b * (y - x^2)^2
@named sys = OptimizationSystem(loss, [x, y], [a, b], constraints = [x^2 + y^2 0.0])
@test_throws ArgumentError OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0], lcons = [0.0])
@test_throws ArgumentError OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0], ucons = [0.0])

prob = OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0])
@test prob.f.expr isa Symbolics.Symbolic
@test all(prob.f.cons_expr[i].lhs isa Symbolics.Symbolic for i in 1:length(prob.f.cons_expr))
end

0 comments on commit ab2c452

Please sign in to comment.