Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: do not scalarize parameters, fix some tests #2469

Merged
merged 29 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
298c29c
Array equations/variables support in `structural_simplify`
YingboMa Feb 15, 2024
75d343f
Support array states in ODEProblem/ODEFunction
YingboMa Feb 16, 2024
ae6bd49
feat: scalarize equations in ODESystem, fix vars! and namespace_expr
AayushSabharwal Feb 19, 2024
0287c1d
feat!: do not scalarize parameters, fix some tests
AayushSabharwal Feb 15, 2024
36388d0
test: fix test file paths
AayushSabharwal Feb 15, 2024
f9f7ae1
refactor: support clock systems without index caches
AayushSabharwal Feb 16, 2024
1591342
test: test callbacks without parameter splitting
AayushSabharwal Feb 16, 2024
f48d30d
refactor: format
AayushSabharwal Feb 16, 2024
6abcc49
refactor: disable treating symbolic defaults as param dependencies
AayushSabharwal Feb 18, 2024
9eb96a5
feat: add support for parameter dependencies
AayushSabharwal Feb 19, 2024
55f2730
docs: update NEWS with parameter dependencies
AayushSabharwal Feb 19, 2024
c26d4d9
feat: un-scalarize inferred parameters, improve parameter initialization
AayushSabharwal Feb 19, 2024
812e004
feat: flatten equations to avoid scalarizing array arguments
AayushSabharwal Feb 19, 2024
059f6e8
fix formatting
ChrisRackauckas Feb 19, 2024
55cd1d8
fix typo
ChrisRackauckas Feb 19, 2024
0c26977
fix: fix `vars!`
AayushSabharwal Feb 20, 2024
8d7c677
fix: refactor IndexCache for non-scalarized unknowns
AayushSabharwal Feb 20, 2024
9e2c9bc
fix: do not call flatten_equations in JumpSystem
AayushSabharwal Feb 20, 2024
a6add74
fix: handle broadcasted equations and array variables in ODESystem co…
AayushSabharwal Feb 20, 2024
a41a64f
fix: use variable_index in calculate_massmatrix
AayushSabharwal Feb 20, 2024
c6c96dd
fix: do not scalarize in system constructors
AayushSabharwal Feb 20, 2024
d7265c1
test: fix mass matrix tests
AayushSabharwal Feb 20, 2024
1275d6e
fixup! fix: fix `vars!`
AayushSabharwal Feb 20, 2024
3e0aea0
fix: fix IndexCache to not put matrices as nonnumeric parameters
AayushSabharwal Feb 21, 2024
1218152
feat: add copy method for MTKParameters
AayushSabharwal Feb 21, 2024
9d5c211
Skip partial_state_selection test
YingboMa Feb 21, 2024
7ce0f57
format
ChrisRackauckas Feb 22, 2024
703da35
Update Project.toml
ChrisRackauckas Feb 22, 2024
3ede8ff
Update Project.toml
ChrisRackauckas Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@
equations. For example, `[p[1] => 1.0, p[2] => 2.0]` is no longer allowed in default equations, use
`[p => [1.0, 2.0]]` instead. Also, array equations like for `@variables u[1:2]` have `D(u) ~ A*u` as an
array equation. If the scalarized version is desired, use `scalarize(u)`.
- Parameter dependencies are now supported. They can be specified using the syntax
`(single_parameter => expression_involving_other_parameters)` and a `Vector` of these can be passed to
the `parameter_dependencies` keyword argument of `ODESystem`, `SDESystem` and `JumpSystem`. The dependent
parameters are updated whenever other parameters are modified, e.g. in callbacks.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.1"
SymbolicUtils = "1.0"
Symbolics = "5.7"
Symbolics = "5.21"
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using PrecompileTools, Reexport
using RecursiveArrayTools

using SymbolicIndexingInterface
export independent_variables, unknowns, parameters
export independent_variables, unknowns, parameters, full_parameters
import SymbolicUtils
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
Symbolic, isadd, ismul, ispow, issym, FnType,
Expand Down
3 changes: 2 additions & 1 deletion src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ function Base.push!(m::Matching, v)
end
end

function complete(m::Matching{U}, N = maximum((x for x in m.match if isa(x, Int)); init=0)) where {U}
function complete(m::Matching{U},
N = maximum((x for x in m.match if isa(x, Int)); init = 0)) where {U}
m.inv_match !== nothing && return m
inv_match = Union{U, Int}[unassigned for _ in 1:N]
for (i, eq) in enumerate(m.match)
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varl
old_level_vars = ()
ict = IncrementalCycleTracker(
DiCMOBiGraph{true}(graph,
complete(Matching(ndsts(graph)), nsrcs(graph))),
complete(Matching(ndsts(graph)), nsrcs(graph))),
dir = :in)

while level >= 0
Expand Down
105 changes: 63 additions & 42 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
h = getsymbolhash(sym)
return haskey(ic.unknown_idx, h) ||
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
hasname(sym) && is_variable(sys, getname(sym))
(istree(sym) && operation(sym) === getindex &&
is_variable(sys, first(arguments(sym))))
else
return any(isequal(sym), variable_symbols(sys)) ||
hasname(sym) && is_variable(sys, getname(sym))
Expand All @@ -214,18 +215,15 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return if haskey(ic.unknown_idx, h)
ic.unknown_idx[h]
else
h = getsymbolhash(default_toterm(sym))
if haskey(ic.unknown_idx, h)
ic.unknown_idx[h]
elseif hasname(sym)
variable_index(sys, getname(sym))
else
nothing
end
end
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]

h = getsymbolhash(default_toterm(sym))
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
sym = unwrap(sym)
istree(sym) && operation(sym) === getindex || return nothing
idx = variable_index(sys, first(arguments(sym)))
idx === nothing && return nothing
return idx[arguments(sym)[(begin + 1):end]...]
end
idx = findfirst(isequal(sym), variable_symbols(sys))
if idx === nothing && hasname(sym)
Expand Down Expand Up @@ -264,8 +262,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
else
h = getsymbolhash(default_toterm(sym))
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
hasname(sym) && is_parameter(sys, getname(sym))
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
end
end
return any(isequal(sym), parameter_symbols(sys)) ||
Expand All @@ -286,27 +283,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return if haskey(ic.param_idx, h)
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
elseif haskey(ic.discrete_idx, h)
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
elseif haskey(ic.constant_idx, h)
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
elseif haskey(ic.dependent_idx, h)
ParameterIndex(nothing, ic.dependent_idx[h])
return if (idx = ParameterIndex(ic, sym)) !== nothing
idx
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
idx
else
h = getsymbolhash(default_toterm(sym))
if haskey(ic.param_idx, h)
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
elseif haskey(ic.discrete_idx, h)
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
elseif haskey(ic.constant_idx, h)
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
elseif haskey(ic.dependent_idx, h)
ParameterIndex(nothing, ic.dependent_idx[h])
else
nothing
end
nothing
end
end

Expand All @@ -329,7 +311,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
return parameters(sys)
return full_parameters(sys)
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
Expand Down Expand Up @@ -419,6 +401,7 @@ for prop in [:eqs
:metadata
:gui_metadata
:discrete_subsystems
:parameter_dependencies
:solved_unknowns
:split_idxs
:parent
Expand Down Expand Up @@ -703,9 +686,12 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys
# metadata from the rescoped variable
rescoped = renamespace(n, O)
similarterm(O, operation(rescoped), renamed,
metadata = metadata(rescoped))::T
metadata = metadata(rescoped))
elseif Symbolics.isarraysymbolic(O)
# promote_symtype doesn't work for array symbolics
similarterm(O, operation(O), renamed, symtype(O), metadata = metadata(O))
else
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
similarterm(O, operation(O), renamed, metadata = metadata(O))
end
elseif isvariable(O)
renamespace(n, O)
Expand Down Expand Up @@ -747,7 +733,29 @@ function parameters(sys::AbstractSystem)
ps = first.(ps)
end
systems = get_systems(sys)
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
result = unique(isempty(systems) ? ps :
[ps; reduce(vcat, namespace_parameters.(systems))])
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
filter(result) do sym
!haskey(pdeps, sym)
end
else
result
end
end

function dependent_parameters(sys::AbstractSystem)
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
collect(keys(pdeps))
else
[]
end
end

function full_parameters(sys::AbstractSystem)
vcat(parameters(sys), dependent_parameters(sys))
end

# required in `src/connectors.jl:437`
Expand Down Expand Up @@ -1518,13 +1526,12 @@ function linearization_function(sys::AbstractSystem, inputs,
sys = ssys
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
ps = parameters(sys)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, p)
ps = reorder_parameters(sys, parameters(sys))
else
p = _p
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
Expand Down Expand Up @@ -1610,7 +1617,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
kwargs...)
sts = unknowns(sys)
t = get_iv(sys)
ps = parameters(sys)
ps = full_parameters(sys)
p = reorder_parameters(sys, ps)

fun = generate_function(sys, sts, ps; expression = Val{false})[1]
Expand Down Expand Up @@ -2121,3 +2128,17 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
error("substituting symbols is not supported for $(typeof(sys))")
end
end

function process_parameter_dependencies(pdeps, ps)
pdeps === nothing && return pdeps, ps
if pdeps isa Vector && eltype(pdeps) <: Pair
pdeps = Dict(pdeps)
elseif !(pdeps isa Dict)
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")
end

ps = filter(ps) do p
!haskey(pdeps, p)
end
return pdeps, ps
end
13 changes: 5 additions & 8 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
update_inds = map(update_vars) do sym
@unpack portion, idx = parameter_index(sys, sym)
if portion == SciMLStructures.Discrete()
idx += length(ic.param_idx)
end
idx
pind = parameter_index(sys, sym)
discrete_linear_index(ic, pind)
end
else
psind = Dict(reverse(en) for en in enumerate(ps))
Expand Down Expand Up @@ -436,14 +433,14 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
end

function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
ps = full_parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
ps = full_parameters(sys); kwargs...)
eqs = map(cb -> cb.eqs, cbs)
num_eqs = length.(eqs)
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
Expand Down Expand Up @@ -559,7 +556,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
end

function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
ps = full_parameters(sys); kwargs...)
has_discrete_events(sys) || return nothing
symcbs = discrete_events(sys)
isempty(symcbs) && return nothing
Expand Down
Loading
Loading