diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 1220d517cc..4a45e01b3a 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -24,7 +24,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di invalidate_cache!, Substitutions, get_or_construct_tearing_state, filter_kwargs, lower_varname_with_unit, setio, SparseMatrixCLIL, get_fullvars, has_equations, observed, - Schedule, schedule + Schedule, schedule, AliasGraph using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index 8a0ae5276e..18886e6572 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -170,11 +170,17 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match end function dummy_derivative_graph!(state::TransformationState, jac = nothing; - state_priority = nothing, log = Val(false), kwargs...) + state_priority = nothing, mm = nothing, log = Val(false), kwargs...) state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) complete!(state.structure) var_eq_matching = complete(pantelides!(state; kwargs...)) - dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log) + if mm === nothing + ag = nothing + else + ag = AliasGraph(mm, ndsts(state.structure.graph)) + end + dummy_derivative_graph!( + state.structure, var_eq_matching, jac, state_priority, ag, mm, log) end struct DummyDerivativeSummary @@ -182,30 +188,109 @@ struct DummyDerivativeSummary state_priority::Vector{Vector{Float64}} end +function extended_state_priority(state_priority, var_to_diff, ::Nothing) + sp = zeros(length(var_to_diff)) + diff_to_var = invview(var_to_diff) + for v in 1:length(var_to_diff) + var = v + min_p = max_p = 0.0 + while var_to_diff[var] !== nothing + var = var_to_diff[var] + end + while true + p = state_priority(var) + max_p = max(max_p, p) + min_p = min(min_p, p) + (var = diff_to_var[var]) === nothing && break + end + sp[v] = min_p < 0 ? min_p : max_p + end + sp +end + +function extended_state_priority(state_priority, var_to_diff, ag::AliasGraph) + sp = map(state_priority, 1:length(var_to_diff)) + prop_graph = SimpleDiGraph{Int}(ag) + for (v, dv) in enumerate(var_to_diff) + dv isa Int && add_edge!(prop_graph, v, dv) + end + prop_state_priority!(sp, prop_graph) + sp +end + +function maxabs(x, y) + ax = abs(x) + ay = abs(y) + ax == ay ? min(x, y) : (ax > ay ? x : y) +end + +function prop_state_priority!(sp, graph) + visited = BitSet() + function visit!(sp, graph, v) + push!(visited, v) + for n in outneighbors(graph, v) + n in visited && continue + sp[n] = maxabs(sp[n], sp[v]) + visit!(sp, graph, n) + end + end + for v in vertices(graph) + if isempty(inneighbors(graph, v)) + visit!(sp, graph, v) + empty!(visited) + end + end + sp +end +#= +State priority handling: + +Phase 1: +Before dummy derivatives we cannot assume differentiations are real, so we can +only set state priorities and not change the scheduling: + ia -> a (1) + ^ + || + (2) v -> D(v) (3) + ^ + || +x -> D(x) + + ^ +-> or | update priority by picking the max on these direction only (don't go +reverse) + +Phase 2: +After dummy derivatives we do: + + ia a + ^ ^ + || || + v -> D(v) + ^ + || +x -> D(x) + +x v + ia + +(priority, is_der) lex order + +Build an alias graph with alias edges (±1 <->) and derivative edges (->), when +both edges coincide, take the derivative edge, then pick the variable on each +non-highest differentiation level that has the highest value by the above +ordering to obtain state variables. +=# + function dummy_derivative_graph!( structure::SystemStructure, var_eq_matching, jac = nothing, - state_priority = nothing, ::Val{log} = Val(false)) where {log} + state_priority = nothing, ag = nothing, mm = nothing, ::Val{log} = Val(false)) where {log} @unpack eq_to_diff, var_to_diff, graph = structure diff_to_eq = invview(eq_to_diff) diff_to_var = invview(var_to_diff) invgraph = invview(graph) - extended_sp = let state_priority = state_priority, var_to_diff = var_to_diff, - diff_to_var = diff_to_var - - var -> begin - min_p = max_p = 0.0 - while var_to_diff[var] !== nothing - var = var_to_diff[var] - end - while true - p = state_priority(var) - max_p = max(max_p, p) - min_p = min(min_p, p) - (var = diff_to_var[var]) === nothing && break - end - min_p < 0 ? min_p : max_p - end - end + extended_sp_vec = extended_state_priority(state_priority, var_to_diff, ag) + extended_sp = Base.Fix1(getindex, extended_sp_vec) var_sccs = find_var_sccs(graph, var_eq_matching) var_perm = Int[] diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index fb4fedc920..2fbdc205c7 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -381,6 +381,31 @@ end swap!(v, i, j) = v[i], v[j] = v[j], v[i] +struct AliasGraph <: Graphs.AbstractGraph{Int} + graph::SimpleGraph{Int} + neg_edge::Set{Tuple{Int, Int}} +end +function AliasGraph(mm::SparseMatrixCLIL, nv::Int) + graph = SimpleGraph{Int}(nv) + neg_edge = Set{Tuple{Int, Int}}() + for r in eachrow(mm) + @unpack nzval, nzind = r.vec + length(nzval) == 2 || continue + (v1, v2) = nzval + (abs(v1) == abs(v2) == 1) || continue + (i1, i2) = nzind + add_edge!(graph, i1, i2) + push!(neg_edge, (i1, i2)) + end + AliasGraph(graph, neg_edge) +end +isneg(ag::AliasGraph, a, b) = (a, b) in ag.neg_edge +for f in [:dst, :edges, :edgetype, :has_edge, :has_vertex, + :inneighbors, :is_directed, :ne, :nv, :outneighbors, + :src, :vertices] + @eval Graphs.$f(ag::AliasGraph) = Graphs.$f(ag.graph) +end + """ $(SIGNATURES)