Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array equations/variables support in structural_simplify #2472

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,9 +703,9 @@
# metadata from the rescoped variable
rescoped = renamespace(n, O)
similarterm(O, operation(rescoped), renamed,
metadata = metadata(rescoped))::T
metadata = metadata(rescoped))
else
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
similarterm(O, operation(O), renamed, metadata = metadata(O))

Check warning on line 708 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L708

Added line #L708 was not covered by tests
end
elseif isvariable(O)
renamespace(n, O)
Expand Down
23 changes: 15 additions & 8 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,18 +292,20 @@
end
end

function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
function generate_connection_set(
sys::AbstractSystem, find = nothing, replace = nothing; scalarize = false)
connectionsets = ConnectionSet[]
domain_csets = ConnectionSet[]
sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace)
sys = generate_connection_set!(
connectionsets, domain_csets, sys, find, replace, scalarize)
csets = merge(connectionsets)
domain_csets = merge([csets; domain_csets], true)

sys, (csets, domain_csets)
end

function generate_connection_set!(connectionsets, domain_csets,
sys::AbstractSystem, find, replace, namespace = nothing)
sys::AbstractSystem, find, replace, scalarize, namespace = nothing)
subsys = get_systems(sys)

isouter = generate_isouter(sys)
Expand All @@ -325,8 +327,13 @@
end
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
else
if lhs isa Number || lhs isa Symbolic
push!(eqs, eq) # split connections and equations
if lhs isa Number || lhs isa Symbolic || eltype(lhs) <: Symbolic
# split connections and equations
if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray
append!(eqs, Symbolics.scalarize(eq))

Check warning on line 333 in src/systems/connectors.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/connectors.jl#L333

Added line #L333 was not covered by tests
else
push!(eqs, eq)
end
elseif lhs isa Connection && get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
else
Expand Down Expand Up @@ -356,7 +363,7 @@
end
@set! sys.systems = map(
s -> generate_connection_set!(connectionsets, domain_csets, s,
find, replace,
find, replace, scalarize,
renamespace(namespace, s)),
subsys)
@set! sys.eqs = eqs
Expand Down Expand Up @@ -471,8 +478,8 @@
end

function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
debug = false, tol = 1e-10)
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace)
debug = false, tol = 1e-10, scalarize = true)
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)
sys = flatten(sys, true)
Expand Down
27 changes: 27 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,21 @@
nothing,
isdde = false,
kwargs...)
array_vars = Dict{Any, Vector{Int}}()
for (j, x) in enumerate(dvs)
if istree(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)

Check warning on line 160 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L158-L160

Added lines #L158 - L160 were not covered by tests
end
end
subs = Dict()
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
inds = inds′

Check warning on line 166 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
end
subs[k] = term(view, Sym{Any}(Symbol("ˍ₋arg1")), inds)
end

Check warning on line 169 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L168-L169

Added lines #L168 - L169 were not covered by tests
if isdde
eqs = delay_to_function(sys)
else
Expand All @@ -164,6 +179,7 @@
# substitute x(t) by just x
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
[eq.rhs for eq in eqs]
rhss = fast_substitute(rhss, subs)

# TODO: add an optional check on the ordering of observed equations
u = map(x -> time_varying_as_func(value(x), sys), dvs)
Expand Down Expand Up @@ -764,6 +780,17 @@
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)
for (k, v) in defs
if Symbolics.isarraysymbolic(k)
ks = scalarize(k)
length(ks) == length(v) || error("$k has default value $v with unmatched size")
for (kk, vv) in zip(ks, v)
if !haskey(defs, kk)
defs[kk] = vv

Check warning on line 789 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L785-L789

Added lines #L785 - L789 were not covered by tests
end
end

Check warning on line 791 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L791

Added line #L791 was not covered by tests
end
end

if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
Expand Down
10 changes: 5 additions & 5 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
gui_metadata = nothing)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
deqs = scalarize(deqs)
#deqs = scalarize(deqs)
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

iv′ = value(scalarize(iv))
ps′ = value.(scalarize(ps))
ctrl′ = value.(scalarize(controls))
dvs′ = value.(scalarize(dvs))
iv′ = value(iv)
ps′ = value.(ps)
ctrl′ = value.(controls)
dvs′ = value.(dvs)
dvs′ = filter(x -> !isdelay(x, iv), dvs′)

if !(isempty(default_u0) && isempty(default_p))
Expand Down
16 changes: 14 additions & 2 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@
end

vars = OrderedSet()
varsvec = []
for (i, eq′) in enumerate(eqs)
if eq′.lhs isa Connection
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
Expand All @@ -282,9 +283,18 @@
eq = 0 ~ rhs - lhs
end
vars!(vars, eq.rhs, op = Symbolics.Operator)
for v in vars
v = scalarize(v)
if v isa AbstractArray
v = setmetadata.(v, VariableIrreducible, true)
append!(varsvec, v)

Check warning on line 290 in src/systems/systemstructure.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/systemstructure.jl#L289-L290

Added lines #L289 - L290 were not covered by tests
else
push!(varsvec, v)
end
end
isalgeq = true
unknownvars = []
for var in vars
for var in varsvec
ModelingToolkit.isdelay(var, iv) && continue
set_incidence = true
@label ANOTHER_VAR
Expand Down Expand Up @@ -340,6 +350,7 @@
push!(symbolic_incidence, copy(unknownvars))
empty!(unknownvars)
empty!(vars)
empty!(varsvec)
if isalgeq
eqs[i] = eq
else
Expand All @@ -350,9 +361,10 @@
# sort `fullvars` such that the mass matrix is as diagonal as possible.
dervaridxs = collect(dervaridxs)
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
var_to_old_var = Dict(zip(fullvars, fullvars))
for dervaridx in dervaridxs
dervar = fullvars[dervaridx]
diffvar = lower_order_var(dervar)
diffvar = var_to_old_var[lower_order_var(dervar)]
if !(diffvar in sorted_fullvars)
push!(sorted_fullvars, diffvar)
end
Expand Down
17 changes: 10 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,19 +348,22 @@
(vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars)
end
function vars!(vars, O; op = Differential)
!istree(O) && return vars
if operation(O) === (getindex)
arr = first(arguments(O))
!istree(arr) && return vars
operation(arr) isa op && return push!(vars, O)
isvariable(operation(O)) && return push!(vars, O)

Check warning on line 356 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L353-L356

Added lines #L353 - L356 were not covered by tests
end

if isvariable(O)
return push!(vars, O)
end
!istree(O) && return vars

operation(O) isa op && return push!(vars, O)

if operation(O) === (getindex) &&
isvariable(first(arguments(O)))
return push!(vars, O)
end

isvariable(operation(O)) && push!(vars, O)

for arg in arguments(O)
vars!(vars, arg; op = op)
end
Expand Down Expand Up @@ -807,7 +810,7 @@
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
end
fast_substitute(eqs::AbstractArray{<:Eq}, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(a, b) = substitute(a, b)
function fast_substitute(expr, pair::Pair)
a, b = pair
Expand Down
19 changes: 7 additions & 12 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,21 +517,16 @@ eqs = [D(x) ~ x * y
using StaticArrays
using SymbolicUtils: term
using SymbolicUtils.Code
using Symbolics: unwrap, wrap
function foo(a::Num, ms::AbstractVector)
a = unwrap(a)
ms = map(unwrap, ms)
wrap(term(foo, a, term(SVector, ms...)))
end
using Symbolics: unwrap, wrap, @register_symbolic
foo(a, ms::AbstractVector) = a + sum(ms)
@variables x(t) ms(t)[1:3]
ms = collect(ms)
eqs = [D(x) ~ foo(x, ms); D.(ms) .~ 1]
@register_symbolic foo(a, ms::AbstractVector)
@variables t x(t) ms(t)[1:3]
D = Differential(t)
eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)]
@named sys = ODESystem(eqs, t, [x; ms], [])
@named emptysys = ODESystem(Equation[], t)
@named outersys = compose(emptysys, sys)
outersys = complete(outersys)
prob = ODEProblem(outersys, [sys.x => 1.0; collect(sys.ms) .=> 1:3], (0, 1.0))
@mtkbuild outersys = compose(emptysys, sys)
prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0))
@test_nowarn solve(prob, Tsit5())

# x/x
Expand Down
Loading