Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: do not scalarize parameters, fix some tests #2469

Merged
merged 29 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
298c29c
Array equations/variables support in `structural_simplify`
YingboMa Feb 15, 2024
75d343f
Support array states in ODEProblem/ODEFunction
YingboMa Feb 16, 2024
ae6bd49
feat: scalarize equations in ODESystem, fix vars! and namespace_expr
AayushSabharwal Feb 19, 2024
0287c1d
feat!: do not scalarize parameters, fix some tests
AayushSabharwal Feb 15, 2024
36388d0
test: fix test file paths
AayushSabharwal Feb 15, 2024
f9f7ae1
refactor: support clock systems without index caches
AayushSabharwal Feb 16, 2024
1591342
test: test callbacks without parameter splitting
AayushSabharwal Feb 16, 2024
f48d30d
refactor: format
AayushSabharwal Feb 16, 2024
6abcc49
refactor: disable treating symbolic defaults as param dependencies
AayushSabharwal Feb 18, 2024
9eb96a5
feat: add support for parameter dependencies
AayushSabharwal Feb 19, 2024
55f2730
docs: update NEWS with parameter dependencies
AayushSabharwal Feb 19, 2024
c26d4d9
feat: un-scalarize inferred parameters, improve parameter initialization
AayushSabharwal Feb 19, 2024
812e004
feat: flatten equations to avoid scalarizing array arguments
AayushSabharwal Feb 19, 2024
059f6e8
fix formatting
ChrisRackauckas Feb 19, 2024
55cd1d8
fix typo
ChrisRackauckas Feb 19, 2024
0c26977
fix: fix `vars!`
AayushSabharwal Feb 20, 2024
8d7c677
fix: refactor IndexCache for non-scalarized unknowns
AayushSabharwal Feb 20, 2024
9e2c9bc
fix: do not call flatten_equations in JumpSystem
AayushSabharwal Feb 20, 2024
a6add74
fix: handle broadcasted equations and array variables in ODESystem co…
AayushSabharwal Feb 20, 2024
a41a64f
fix: use variable_index in calculate_massmatrix
AayushSabharwal Feb 20, 2024
c6c96dd
fix: do not scalarize in system constructors
AayushSabharwal Feb 20, 2024
d7265c1
test: fix mass matrix tests
AayushSabharwal Feb 20, 2024
1275d6e
fixup! fix: fix `vars!`
AayushSabharwal Feb 20, 2024
3e0aea0
fix: fix IndexCache to not put matrices as nonnumeric parameters
AayushSabharwal Feb 21, 2024
1218152
feat: add copy method for MTKParameters
AayushSabharwal Feb 21, 2024
9d5c211
Skip partial_state_selection test
YingboMa Feb 21, 2024
7ce0f57
format
ChrisRackauckas Feb 22, 2024
703da35
Update Project.toml
ChrisRackauckas Feb 22, 2024
3ede8ff
Update Project.toml
ChrisRackauckas Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@
equations. For example, `[p[1] => 1.0, p[2] => 2.0]` is no longer allowed in default equations, use
`[p => [1.0, 2.0]]` instead. Also, array equations like for `@variables u[1:2]` have `D(u) ~ A*u` as an
array equation. If the scalarized version is desired, use `scalarize(u)`.
- Parameter dependencies are now supported. They can be specified using the syntax
`(single_parameter => expression_involving_other_parameters)` and a `Vector` of these can be passed to
the `parameter_dependencies` keyword argument of `ODESystem`, `SDESystem` and `JumpSystem`. The dependent
parameters are updated whenever other parameters are modified, e.g. in callbacks.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using PrecompileTools, Reexport
using RecursiveArrayTools

using SymbolicIndexingInterface
export independent_variables, unknowns, parameters
export independent_variables, unknowns, parameters, full_parameters
import SymbolicUtils
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
Symbolic, isadd, ismul, ispow, issym, FnType,
Expand Down
3 changes: 2 additions & 1 deletion src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ function Base.push!(m::Matching, v)
end
end

function complete(m::Matching{U}, N = maximum((x for x in m.match if isa(x, Int)); init=0)) where {U}
function complete(m::Matching{U},
N = maximum((x for x in m.match if isa(x, Int)); init = 0)) where {U}
m.inv_match !== nothing && return m
inv_match = Union{U, Int}[unassigned for _ in 1:N]
for (i, eq) in enumerate(m.match)
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varl
old_level_vars = ()
ict = IncrementalCycleTracker(
DiCMOBiGraph{true}(graph,
complete(Matching(ndsts(graph)), nsrcs(graph))),
complete(Matching(ndsts(graph)), nsrcs(graph))),
dir = :in)

while level >= 0
Expand Down
86 changes: 53 additions & 33 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return haskey(ic.unknown_idx, h) ||
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
hasname(sym) && is_variable(sys, getname(sym))
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym)))
else
return any(isequal(sym), variable_symbols(sys)) ||
hasname(sym) && is_variable(sys, getname(sym))
Expand All @@ -220,8 +219,6 @@
h = getsymbolhash(default_toterm(sym))
if haskey(ic.unknown_idx, h)
ic.unknown_idx[h]
elseif hasname(sym)
variable_index(sys, getname(sym))
else
nothing
end
Expand Down Expand Up @@ -264,8 +261,7 @@
else
h = getsymbolhash(default_toterm(sym))
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
hasname(sym) && is_parameter(sys, getname(sym))
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
end
end
return any(isequal(sym), parameter_symbols(sys)) ||
Expand All @@ -286,27 +282,12 @@
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return if haskey(ic.param_idx, h)
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
elseif haskey(ic.discrete_idx, h)
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
elseif haskey(ic.constant_idx, h)
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
elseif haskey(ic.dependent_idx, h)
ParameterIndex(nothing, ic.dependent_idx[h])
return if (idx = ParameterIndex(ic, sym)) !== nothing
idx
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
idx

Check warning on line 288 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L285-L288

Added lines #L285 - L288 were not covered by tests
else
h = getsymbolhash(default_toterm(sym))
if haskey(ic.param_idx, h)
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
elseif haskey(ic.discrete_idx, h)
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
elseif haskey(ic.constant_idx, h)
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
elseif haskey(ic.dependent_idx, h)
ParameterIndex(nothing, ic.dependent_idx[h])
else
nothing
end
nothing

Check warning on line 290 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L290

Added line #L290 was not covered by tests
end
end

Expand All @@ -329,7 +310,7 @@
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
return parameters(sys)
return full_parameters(sys)

Check warning on line 313 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L313

Added line #L313 was not covered by tests
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
Expand Down Expand Up @@ -419,6 +400,7 @@
:metadata
:gui_metadata
:discrete_subsystems
:parameter_dependencies
:solved_unknowns
:split_idxs
:parent
Expand Down Expand Up @@ -703,9 +685,12 @@
# metadata from the rescoped variable
rescoped = renamespace(n, O)
similarterm(O, operation(rescoped), renamed,
metadata = metadata(rescoped))::T
metadata = metadata(rescoped))
elseif Symbolics.isarraysymbolic(O)
# promote_symtype doesn't work for array symbolics
similarterm(O, operation(O), renamed, symtype(O), metadata = metadata(O))

Check warning on line 691 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L691

Added line #L691 was not covered by tests
else
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
similarterm(O, operation(O), renamed, metadata = metadata(O))
end
elseif isvariable(O)
renamespace(n, O)
Expand Down Expand Up @@ -747,7 +732,29 @@
ps = first.(ps)
end
systems = get_systems(sys)
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
result = unique(isempty(systems) ? ps :
[ps; reduce(vcat, namespace_parameters.(systems))])
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
filter(result) do sym
!haskey(pdeps, sym)

Check warning on line 740 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L739-L740

Added lines #L739 - L740 were not covered by tests
end
else
result
end
end

function dependent_parameters(sys::AbstractSystem)
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
collect(keys(pdeps))

Check warning on line 750 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L750

Added line #L750 was not covered by tests
else
[]
end
end

function full_parameters(sys::AbstractSystem)
vcat(parameters(sys), dependent_parameters(sys))
end

# required in `src/connectors.jl:437`
Expand Down Expand Up @@ -1518,13 +1525,12 @@
sys = ssys
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
ps = parameters(sys)

Check warning on line 1528 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1528

Added line #L1528 was not covered by tests
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, p)
ps = reorder_parameters(sys, parameters(sys))
else
p = _p
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
Expand Down Expand Up @@ -1610,7 +1616,7 @@
kwargs...)
sts = unknowns(sys)
t = get_iv(sys)
ps = parameters(sys)
ps = full_parameters(sys)

Check warning on line 1619 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1619

Added line #L1619 was not covered by tests
p = reorder_parameters(sys, ps)

fun = generate_function(sys, sts, ps; expression = Val{false})[1]
Expand Down Expand Up @@ -2121,3 +2127,17 @@
error("substituting symbols is not supported for $(typeof(sys))")
end
end

function process_parameter_dependencies(pdeps, ps)
pdeps === nothing && return pdeps, ps
if pdeps isa Vector && eltype(pdeps) <: Pair
pdeps = Dict(pdeps)
elseif !(pdeps isa Dict)
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")

Check warning on line 2136 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L2133-L2136

Added lines #L2133 - L2136 were not covered by tests
end

ps = filter(ps) do p
!haskey(pdeps, p)

Check warning on line 2140 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L2139-L2140

Added lines #L2139 - L2140 were not covered by tests
end
return pdeps, ps

Check warning on line 2142 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L2142

Added line #L2142 was not covered by tests
end
13 changes: 5 additions & 8 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,8 @@
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
update_inds = map(update_vars) do sym
@unpack portion, idx = parameter_index(sys, sym)
if portion == SciMLStructures.Discrete()
idx += length(ic.param_idx)
end
idx
pind = parameter_index(sys, sym)
discrete_linear_index(ic, pind)

Check warning on line 394 in src/systems/callbacks.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/callbacks.jl#L393-L394

Added lines #L393 - L394 were not covered by tests
end
else
psind = Dict(reverse(en) for en in enumerate(ps))
Expand Down Expand Up @@ -436,14 +433,14 @@
end

function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
ps = full_parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
ps = full_parameters(sys); kwargs...)
eqs = map(cb -> cb.eqs, cbs)
num_eqs = length.(eqs)
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
Expand Down Expand Up @@ -559,7 +556,7 @@
end

function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
ps = parameters(sys); kwargs...)
ps = full_parameters(sys); kwargs...)
has_discrete_events(sys) || return nothing
symcbs = discrete_events(sys)
isempty(symcbs) && return nothing
Expand Down
Loading
Loading