Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: do not scalarize parameters, fix some tests #2469

Merged
merged 29 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
298c29c
Array equations/variables support in `structural_simplify`
YingboMa Feb 15, 2024
75d343f
Support array states in ODEProblem/ODEFunction
YingboMa Feb 16, 2024
ae6bd49
feat: scalarize equations in ODESystem, fix vars! and namespace_expr
AayushSabharwal Feb 19, 2024
0287c1d
feat!: do not scalarize parameters, fix some tests
AayushSabharwal Feb 15, 2024
36388d0
test: fix test file paths
AayushSabharwal Feb 15, 2024
f9f7ae1
refactor: support clock systems without index caches
AayushSabharwal Feb 16, 2024
1591342
test: test callbacks without parameter splitting
AayushSabharwal Feb 16, 2024
f48d30d
refactor: format
AayushSabharwal Feb 16, 2024
6abcc49
refactor: disable treating symbolic defaults as param dependencies
AayushSabharwal Feb 18, 2024
9eb96a5
feat: add support for parameter dependencies
AayushSabharwal Feb 19, 2024
55f2730
docs: update NEWS with parameter dependencies
AayushSabharwal Feb 19, 2024
c26d4d9
feat: un-scalarize inferred parameters, improve parameter initialization
AayushSabharwal Feb 19, 2024
812e004
feat: flatten equations to avoid scalarizing array arguments
AayushSabharwal Feb 19, 2024
059f6e8
fix formatting
ChrisRackauckas Feb 19, 2024
55cd1d8
fix typo
ChrisRackauckas Feb 19, 2024
0c26977
fix: fix `vars!`
AayushSabharwal Feb 20, 2024
8d7c677
fix: refactor IndexCache for non-scalarized unknowns
AayushSabharwal Feb 20, 2024
9e2c9bc
fix: do not call flatten_equations in JumpSystem
AayushSabharwal Feb 20, 2024
a6add74
fix: handle broadcasted equations and array variables in ODESystem co…
AayushSabharwal Feb 20, 2024
a41a64f
fix: use variable_index in calculate_massmatrix
AayushSabharwal Feb 20, 2024
c6c96dd
fix: do not scalarize in system constructors
AayushSabharwal Feb 20, 2024
d7265c1
test: fix mass matrix tests
AayushSabharwal Feb 20, 2024
1275d6e
fixup! fix: fix `vars!`
AayushSabharwal Feb 20, 2024
3e0aea0
fix: fix IndexCache to not put matrices as nonnumeric parameters
AayushSabharwal Feb 21, 2024
1218152
feat: add copy method for MTKParameters
AayushSabharwal Feb 21, 2024
9d5c211
Skip partial_state_selection test
YingboMa Feb 21, 2024
7ce0f57
format
ChrisRackauckas Feb 22, 2024
703da35
Update Project.toml
ChrisRackauckas Feb 22, 2024
3ede8ff
Update Project.toml
ChrisRackauckas Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,9 +703,12 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys
# metadata from the rescoped variable
rescoped = renamespace(n, O)
similarterm(O, operation(rescoped), renamed,
metadata = metadata(rescoped))::T
metadata = metadata(rescoped))
elseif Symbolics.isarraysymbolic(O)
# promote_symtype doesn't work for array symbolics
similarterm(O, operation(O), renamed, symtype(O), metadata = metadata(O))
else
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
similarterm(O, operation(O), renamed, metadata = metadata(O))
end
elseif isvariable(O)
renamespace(n, O)
Expand Down
23 changes: 15 additions & 8 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,18 +292,20 @@ function connection2set!(connectionsets, namespace, ss, isouter)
end
end

function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
function generate_connection_set(
sys::AbstractSystem, find = nothing, replace = nothing; scalarize = false)
connectionsets = ConnectionSet[]
domain_csets = ConnectionSet[]
sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace)
sys = generate_connection_set!(
connectionsets, domain_csets, sys, find, replace, scalarize)
csets = merge(connectionsets)
domain_csets = merge([csets; domain_csets], true)

sys, (csets, domain_csets)
end

function generate_connection_set!(connectionsets, domain_csets,
sys::AbstractSystem, find, replace, namespace = nothing)
sys::AbstractSystem, find, replace, scalarize, namespace = nothing)
subsys = get_systems(sys)

isouter = generate_isouter(sys)
Expand All @@ -325,8 +327,13 @@ function generate_connection_set!(connectionsets, domain_csets,
end
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
else
if lhs isa Number || lhs isa Symbolic
push!(eqs, eq) # split connections and equations
if lhs isa Number || lhs isa Symbolic || eltype(lhs) <: Symbolic
# split connections and equations
if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray
append!(eqs, Symbolics.scalarize(eq))
else
push!(eqs, eq)
end
elseif lhs isa Connection && get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
else
Expand Down Expand Up @@ -356,7 +363,7 @@ function generate_connection_set!(connectionsets, domain_csets,
end
@set! sys.systems = map(
s -> generate_connection_set!(connectionsets, domain_csets, s,
find, replace,
find, replace, scalarize,
renamespace(namespace, s)),
subsys)
@set! sys.eqs = eqs
Expand Down Expand Up @@ -471,8 +478,8 @@ function domain_defaults(sys, domain_csets)
end

function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
debug = false, tol = 1e-10)
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace)
debug = false, tol = 1e-10, scalarize = true)
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)
sys = flatten(sys, true)
Expand Down
27 changes: 27 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,21 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
nothing,
isdde = false,
kwargs...)
array_vars = Dict{Any, Vector{Int}}()
for (j, x) in enumerate(dvs)
if istree(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
end
end
subs = Dict()
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
inds = inds′
end
subs[k] = term(view, Sym{Any}(Symbol("ˍ₋arg1")), inds)
end
if isdde
eqs = delay_to_function(sys)
else
Expand All @@ -164,6 +179,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
# substitute x(t) by just x
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
[eq.rhs for eq in eqs]
rhss = fast_substitute(rhss, subs)

# TODO: add an optional check on the ordering of observed equations
u = map(x -> time_varying_as_func(value(x), sys), dvs)
Expand Down Expand Up @@ -764,6 +780,17 @@ function get_u0_p(sys,
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)
for (k, v) in defs
if Symbolics.isarraysymbolic(k)
ks = scalarize(k)
length(ks) == length(v) || error("$k has default value $v with unmatched size")
for (kk, vv) in zip(ks, v)
if !haskey(defs, kk)
defs[kk] = vv
end
end
end
end

if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
Expand Down
12 changes: 5 additions & 7 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,12 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
gui_metadata = nothing)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
deqs = scalarize(deqs)
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

iv′ = value(scalarize(iv))
ps′ = value.(scalarize(ps))
ctrl′ = value.(scalarize(controls))
dvs′ = value.(scalarize(dvs))
deqs = reduce(vcat, scalarize(deqs); init = Equation[])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should never scalarize during construction for MTKv9. We should just follow what I did in #2472

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If one wants a scalarized system for simulation, then one needs to call structural_simplify.

iv′ = value(iv)
ps′ = value.(ps)
ctrl′ = value.(controls)
dvs′ = value.(dvs)
dvs′ = filter(x -> !isdelay(x, iv), dvs′)

if !(isempty(default_u0) && isempty(default_p))
Expand Down Expand Up @@ -236,7 +235,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
end

function ODESystem(eqs, iv; kwargs...)
eqs = scalarize(eqs)
# NOTE: this assumes that the order of algebraic equations doesn't matter
diffvars = OrderedSet()
allunknowns = OrderedSet()
Expand Down
16 changes: 14 additions & 2 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ function TearingState(sys; quick_cancel = false, check = true)
end

vars = OrderedSet()
varsvec = []
for (i, eq′) in enumerate(eqs)
if eq′.lhs isa Connection
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
Expand All @@ -282,9 +283,18 @@ function TearingState(sys; quick_cancel = false, check = true)
eq = 0 ~ rhs - lhs
end
vars!(vars, eq.rhs, op = Symbolics.Operator)
for v in vars
v = scalarize(v)
if v isa AbstractArray
v = setmetadata.(v, VariableIrreducible, true)
append!(varsvec, v)
else
push!(varsvec, v)
end
end
isalgeq = true
unknownvars = []
for var in vars
for var in varsvec
ModelingToolkit.isdelay(var, iv) && continue
set_incidence = true
@label ANOTHER_VAR
Expand Down Expand Up @@ -340,6 +350,7 @@ function TearingState(sys; quick_cancel = false, check = true)
push!(symbolic_incidence, copy(unknownvars))
empty!(unknownvars)
empty!(vars)
empty!(varsvec)
if isalgeq
eqs[i] = eq
else
Expand All @@ -350,9 +361,10 @@ function TearingState(sys; quick_cancel = false, check = true)
# sort `fullvars` such that the mass matrix is as diagonal as possible.
dervaridxs = collect(dervaridxs)
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
var_to_old_var = Dict(zip(fullvars, fullvars))
for dervaridx in dervaridxs
dervar = fullvars[dervaridx]
diffvar = lower_order_var(dervar)
diffvar = var_to_old_var[lower_order_var(dervar)]
if !(diffvar in sorted_fullvars)
push!(sorted_fullvars, diffvar)
end
Expand Down
17 changes: 10 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,25 +342,28 @@ v == Set([D(y), u])
function vars(exprs::Symbolic; op = Differential)
istree(exprs) ? vars([exprs]; op = op) : Set([exprs])
end
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
vars(exprs; op = Differential) = foldl((x, y) -> vars!(x, y; op = op), exprs; init = Set())
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
function vars!(vars, eq::Equation; op = Differential)
(vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars)
end
function vars!(vars, O; op = Differential)
if isvariable(O)
if isvariable(O) && !(istree(O) && operation(O) === getindex)
return push!(vars, O)
end

!istree(O) && return vars
if operation(O) === (getindex)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want structural_simplify to work correctly, we have to differentiate x and x[1] alone. https://github.com/SciML/ModelingToolkit.jl/pull/2472/files#diff-47c27891e951c8cd946b850dc2df31082624afdf57446c21cb6992f5f4b74aa2R351-R369 has the correct implementation

arr = first(arguments(O))
return vars!(vars, arr)
end

operation(O) isa op && return push!(vars, O)

if operation(O) === (getindex) &&
isvariable(first(arguments(O)))
return push!(vars, O)
end

isvariable(operation(O)) && push!(vars, O)

for arg in arguments(O)
vars!(vars, arg; op = op)
end
Expand Down Expand Up @@ -807,7 +810,7 @@ end
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
end
fast_substitute(eqs::AbstractArray{<:Eq}, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(a, b) = substitute(a, b)
function fast_substitute(expr, pair::Pair)
a, b = pair
Expand Down
19 changes: 7 additions & 12 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,21 +517,16 @@ eqs = [D(x) ~ x * y
using StaticArrays
using SymbolicUtils: term
using SymbolicUtils.Code
using Symbolics: unwrap, wrap
function foo(a::Num, ms::AbstractVector)
a = unwrap(a)
ms = map(unwrap, ms)
wrap(term(foo, a, term(SVector, ms...)))
end
using Symbolics: unwrap, wrap, @register_symbolic
foo(a, ms::AbstractVector) = a + sum(ms)
@variables x(t) ms(t)[1:3]
ms = collect(ms)
eqs = [D(x) ~ foo(x, ms); D.(ms) .~ 1]
@register_symbolic foo(a, ms::AbstractVector)
@variables t x(t) ms(t)[1:3]
D = Differential(t)
eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)]
@named sys = ODESystem(eqs, t, [x; ms], [])
@named emptysys = ODESystem(Equation[], t)
@named outersys = compose(emptysys, sys)
outersys = complete(outersys)
prob = ODEProblem(outersys, [sys.x => 1.0; collect(sys.ms) .=> 1:3], (0, 1.0))
@mtkbuild outersys = compose(emptysys, sys)
prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0))
@test_nowarn solve(prob, Tsit5())

# x/x
Expand Down