Skip to content

Commit

Permalink
Merge pull request #1599 from SciML/myb/io
Browse files Browse the repository at this point in the history
Add a pass that converts unbound inputs to parameters
  • Loading branch information
YingboMa authored May 25, 2022
2 parents 56e5844 + b20d4dc commit 4f999b1
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ using .BipartiteGraphs

include("variables.jl")
include("parameters.jl")
include("inputoutput.jl")

include("utils.jl")
include("domains.jl")
Expand Down Expand Up @@ -152,6 +151,7 @@ include("systems/alias_elimination.jl")
include("structural_transformation/StructuralTransformations.jl")

@reexport using .StructuralTransformations
include("inputoutput.jl")

for S in subtypes(ModelingToolkit.AbstractSystem)
S = nameof(S)
Expand Down
62 changes: 62 additions & 0 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,65 @@ function toparam(sys, ctrls::AbstractVector)
end
ODESystem(eqs, name = nameof(sys))
end

function inputs_to_parameters!(state::TransformationState)
@unpack structure, fullvars, sys = state
@unpack var_to_diff, graph, solvable_graph = structure
@assert solvable_graph === nothing

inputs = BitSet()
var_reidx = zeros(Int, length(fullvars))
ninputs = 0
nvar = 0
new_parameters = []
input_to_parameters = Dict()
new_fullvars = []
for (i, v) in enumerate(fullvars)
if isinput(v) && !is_bound(sys, v)
if var_to_diff[i] !== nothing
error("Input $(fullvars[i]) is differentiated!")
end
push!(inputs, i)
ninputs += 1
var_reidx[i] = -1
p = toparam(v)
push!(new_parameters, p)
input_to_parameters[v] = p
else
nvar += 1
var_reidx[i] = nvar
push!(new_fullvars, v)
end
end
ninputs == 0 && return state

nvars = ndsts(graph) - ninputs
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))

for ie in 1:nsrcs(graph)
for iv in 𝑠neighbors(graph, ie)
iv = var_reidx[iv]
iv > 0 || continue
add_edge!(new_graph, ie, iv)
end
end

new_var_to_diff = DiffGraph(nvars, true)
for (i, v) in enumerate(var_to_diff)
new_i = var_reidx[i]
(new_i < 1 || v === nothing) && continue
new_v = var_reidx[v]
@assert new_v > 0
new_var_to_diff[new_i] = new_v
end
@set! structure.var_to_diff = new_var_to_diff
@set! structure.graph = new_graph

@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))
@set! sys.ps = [parameters(sys); new_parameters]

@set! state.sys = sys
@set! state.fullvars = new_fullvars
@set! state.structure = structure
end
2 changes: 1 addition & 1 deletion src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero = false,
to_rm = Int[]
for j in 𝑠neighbors(graph, ieq)
var = fullvars[j]
isinput(var) && continue
#isinput(var) && continue
a, b, islinear = linear_expansion(term, var)
a = unwrap(a)
islinear || continue
Expand Down
2 changes: 2 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,8 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
sys = expand_connections(sys)
sys = alias_elimination(sys)
state = TearingState(sys)
state = inputs_to_parameters!(state)
sys = state.sys
check_consistency(state)
if sys isa ODESystem
sys = dae_order_lowering(dummy_derivative(sys, state))
Expand Down
18 changes: 18 additions & 0 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,32 @@ D = Differential(tv)
@test !is_bound(sys2, sys.u)
@test !is_bound(sys2, sys2.sys.u)

fsys2 = flatten(sys2)
@test is_bound(fsys2, sys.x)
@test !is_bound(fsys2, sys.u)
@test !is_bound(fsys2, sys2.sys.u)


@test is_bound(sys3, sys.u) # I would like to write sys3.sys.u here but that's not how the variable is stored in the equations
@test is_bound(sys3, sys.x)

@test is_bound(sys4, sys.u)
@test !is_bound(sys4, u)

fsys4 = flatten(sys4)
@test is_bound(fsys4, sys.u)
@test !is_bound(fsys4, u)

@test isequal(inputs(sys), [u])
@test isequal(inputs(sys2), [sys.u])

@test isempty(bound_inputs(sys))
@test isequal(unbound_inputs(sys), [u])

@test isempty(bound_inputs(sys2))
@test isempty(bound_inputs(fsys2))
@test isequal(unbound_inputs(sys2), [sys.u])
@test isequal(unbound_inputs(fsys2), [sys.u])

@test isequal(bound_inputs(sys3), [sys.u])
@test isempty(unbound_inputs(sys3))
Expand Down Expand Up @@ -161,3 +173,9 @@ p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)
x = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), dvs)
u = [rand()]
@test f[1](x, u, p, 1) == [u; 0; 0; 0]

@parameters t
@variables x(t) u(t) [input=true]
eqs = [Differential(t)(x) ~ u]
@named sys = ODESystem(eqs, t)
structural_simplify(sys)
3 changes: 1 addition & 2 deletions test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,9 @@ D = Differential(t)

eqs = [D(x) ~ σ * (y - x)
D(y) ~ x *- z) - y + β
0 ~ z - x + y
0 ~ a + z
u ~ z + a]

lorenz1 = ODESystem(eqs, t, name = :lorenz1)
lorenz1_reduced = structural_simplify(lorenz1)
@test z in Set(states(lorenz1_reduced))
@test z in Set(parameters(lorenz1_reduced))

0 comments on commit 4f999b1

Please sign in to comment.