From 45f5cc122bf500f8582090e1efa0f12935be4f43 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 25 Sep 2024 17:18:13 -0400 Subject: [PATCH 1/5] Alias aware state priority --- .../StructuralTransformations.jl | 2 +- .../partial_state_selection.jl | 125 +++++++++++++++--- src/systems/alias_elimination.jl | 25 ++++ 3 files changed, 131 insertions(+), 21 deletions(-) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 1220d517cc..4a45e01b3a 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -24,7 +24,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di invalidate_cache!, Substitutions, get_or_construct_tearing_state, filter_kwargs, lower_varname_with_unit, setio, SparseMatrixCLIL, get_fullvars, has_equations, observed, - Schedule, schedule + Schedule, schedule, AliasGraph using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index 8a0ae5276e..6d2838c544 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -170,11 +170,16 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match end function dummy_derivative_graph!(state::TransformationState, jac = nothing; - state_priority = nothing, log = Val(false), kwargs...) + state_priority = nothing, mm = nothing, log = Val(false), kwargs...) state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) complete!(state.structure) var_eq_matching = complete(pantelides!(state; kwargs...)) - dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log) + if mm === nothing + ag = nothing + else + ag = AliasGraph(mm, ndsts(state.structure.graph)) + end + dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, ag, log) end struct DummyDerivativeSummary @@ -182,30 +187,110 @@ struct DummyDerivativeSummary state_priority::Vector{Vector{Float64}} end +function extended_state_priority(state_priority, var_to_diff, ::Nothing) + sp = zeros(length(var_to_diff)) + diff_to_var = invview(var_to_diff) + for v in 1:length(var_to_diff) + var = v + min_p = max_p = 0.0 + while var_to_diff[var] !== nothing + var = var_to_diff[var] + end + while true + p = state_priority(var) + max_p = max(max_p, p) + min_p = min(min_p, p) + (var = diff_to_var[var]) === nothing && break + end + sp[v] = min_p < 0 ? min_p : max_p + end + sp +end + +function extended_state_priority(state_priority, var_to_diff, ag::AliasGraph) + sp = map(state_priority, 1:length(var_to_diff)) + diff_to_var = invview(var_to_diff) + prop_graph = SimpleDiGraph{Int}(ag) + for (v, dv) in enumerate(var_to_diff) + dv isa Int && add_edge!(prop_graph, v, dv) + end + prop_state_priority!(sp, prop_graph) + sp +end + +function maxabs(x, y) + ax = abs(x) + ay = abs(y) + ax == ay ? min(x, y) : (ax > ay ? x : y) +end + +function prop_state_priority!(sp, graph) + visited = BitSet() + function visit!(sp, graph, v) + push!(visited, v) + for n in inneighbors(graph, v) + n in visited && continue + visit!(sp, graph, n) + end + for n in outneighbors(graph, v) + sp[n] = maxabs(sp[n], sp[v]) + end + end + for v in vertices(graph) + visit!(sp, graph, v) + end + sp +end +#= +State priority handling: + +Phase 1: +Before dummy derivatives we cannot assume differentiations are real, so we can +only set state priorities and not change the scheduling: + ia -> a (1) + ^ + || + (2) v -> D(v) (3) + ^ + || +x -> D(x) + + ^ +-> or | update priority by picking the max on these direction only (don't go +reverse) + +Phase 2: +After dummy derivatives we do: + + ia a + ^ ^ + || || + v -> D(v) + ^ + || +x -> D(x) + +x v + ia + +(priority, is_der) lex order + +Build an alias graph with alias edges (±1 <->) and derivative edges (->), when +both edges coincide, take the derivative edge, then pick the variable on each +non-highest differentiation level that has the highest value by the above +ordering to obtain state variables. +=# + function dummy_derivative_graph!( structure::SystemStructure, var_eq_matching, jac = nothing, - state_priority = nothing, ::Val{log} = Val(false)) where {log} + state_priority = nothing, ag = nothing, ::Val{log} = Val(false)) where {log} @unpack eq_to_diff, var_to_diff, graph = structure diff_to_eq = invview(eq_to_diff) diff_to_var = invview(var_to_diff) invgraph = invview(graph) - extended_sp = let state_priority = state_priority, var_to_diff = var_to_diff, - diff_to_var = diff_to_var - - var -> begin - min_p = max_p = 0.0 - while var_to_diff[var] !== nothing - var = var_to_diff[var] - end - while true - p = state_priority(var) - max_p = max(max_p, p) - min_p = min(min_p, p) - (var = diff_to_var[var]) === nothing && break - end - min_p < 0 ? min_p : max_p - end - end + #ag = nothing + extended_sp_vec = extended_state_priority(state_priority, var_to_diff, ag) + extended_sp = Base.Fix1(getindex, extended_sp_vec) var_sccs = find_var_sccs(graph, var_eq_matching) var_perm = Int[] diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index fb4fedc920..70623d5037 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -381,6 +381,31 @@ end swap!(v, i, j) = v[i], v[j] = v[j], v[i] +struct AliasGraph <: Graphs.AbstractGraph{Int} + graph::SimpleGraph{Int} + neg_edge::Set{Tuple{Int, Int}} +end +function AliasGraph(mm::SparseMatrixCLIL, nv::Int) + graph = SimpleGraph{Int}(nv) + neg_edge = Set{Tuple{Int, Int}}() + for r in eachrow(mm) + @unpack nzval, nzind = r.vec + length(nzval) == 2 || continue + (v1, v2) = nzval + (abs(v1) == abs(v2) == 1) || continue + (i1, i2) = nzind + add_edge!(graph, i1, i2) + push!(neg_edge, (i1, i2)) + end + AliasGraph(graph, neg_edge) +end +isneg(ag::AliasGraph, a, b) = (a, b) in ag.neg_edge +for f in [:dst, :edges, :edgetype, :has_edge, :has_vertex, + :inneighbors, :is_directed, :ne, :nv, :outneighbors, + :src, :vertices] + @eval Graphs.$f(ag::AliasGraph) = Graphs.$f(ag.graph) +end + """ $(SIGNATURES) From 6954970f819d09a8fbd286e1ec4f68cedfd4bb98 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 26 Sep 2024 12:45:31 -0400 Subject: [PATCH 2/5] WIP --- src/structural_transformation/partial_state_selection.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index 6d2838c544..2385e65b9e 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -179,7 +179,7 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing; else ag = AliasGraph(mm, ndsts(state.structure.graph)) end - dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, ag, log) + dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, ag, mm, log) end struct DummyDerivativeSummary @@ -283,12 +283,12 @@ ordering to obtain state variables. function dummy_derivative_graph!( structure::SystemStructure, var_eq_matching, jac = nothing, - state_priority = nothing, ag = nothing, ::Val{log} = Val(false)) where {log} + state_priority = nothing, ag = nothing, mm = nothing, ::Val{log} = Val(false)) where {log} @unpack eq_to_diff, var_to_diff, graph = structure diff_to_eq = invview(eq_to_diff) diff_to_var = invview(var_to_diff) invgraph = invview(graph) - #ag = nothing + ag = nothing extended_sp_vec = extended_state_priority(state_priority, var_to_diff, ag) extended_sp = Base.Fix1(getindex, extended_sp_vec) From 79ee3a604a2da92c2b439d88a38f440044f0eb43 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 4 Oct 2024 17:43:11 -0400 Subject: [PATCH 3/5] Enable extended sp --- src/structural_transformation/partial_state_selection.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index 2385e65b9e..f44e32a126 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -288,7 +288,6 @@ function dummy_derivative_graph!( diff_to_eq = invview(eq_to_diff) diff_to_var = invview(var_to_diff) invgraph = invview(graph) - ag = nothing extended_sp_vec = extended_state_priority(state_priority, var_to_diff, ag) extended_sp = Base.Fix1(getindex, extended_sp_vec) From aa68a7187c228ebad1ad513bbe409eac3c8897a3 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 28 Oct 2024 13:11:59 -0400 Subject: [PATCH 4/5] Fix the propagation algorithm --- .../partial_state_selection.jl | 15 ++++++++------- src/systems/alias_elimination.jl | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index f44e32a126..0eab0da4fb 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -179,7 +179,8 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing; else ag = AliasGraph(mm, ndsts(state.structure.graph)) end - dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, ag, mm, log) + dummy_derivative_graph!( + state.structure, var_eq_matching, jac, state_priority, ag, mm, log) end struct DummyDerivativeSummary @@ -209,7 +210,6 @@ end function extended_state_priority(state_priority, var_to_diff, ag::AliasGraph) sp = map(state_priority, 1:length(var_to_diff)) - diff_to_var = invview(var_to_diff) prop_graph = SimpleDiGraph{Int}(ag) for (v, dv) in enumerate(var_to_diff) dv isa Int && add_edge!(prop_graph, v, dv) @@ -228,16 +228,17 @@ function prop_state_priority!(sp, graph) visited = BitSet() function visit!(sp, graph, v) push!(visited, v) - for n in inneighbors(graph, v) - n in visited && continue - visit!(sp, graph, n) - end for n in outneighbors(graph, v) + n in visited && continue sp[n] = maxabs(sp[n], sp[v]) + visit!(sp, graph, n) end end for v in vertices(graph) - visit!(sp, graph, v) + if all(x -> iszero(sp[v]), inneighbors(graph, v)) + visit!(sp, graph, v) + empty!(visited) + end end sp end diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 70623d5037..2fbdc205c7 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -401,8 +401,8 @@ function AliasGraph(mm::SparseMatrixCLIL, nv::Int) end isneg(ag::AliasGraph, a, b) = (a, b) in ag.neg_edge for f in [:dst, :edges, :edgetype, :has_edge, :has_vertex, - :inneighbors, :is_directed, :ne, :nv, :outneighbors, - :src, :vertices] + :inneighbors, :is_directed, :ne, :nv, :outneighbors, + :src, :vertices] @eval Graphs.$f(ag::AliasGraph) = Graphs.$f(ag.graph) end From babf18f8fcdadca85045251b6b23b74f5a0a7a30 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 28 Oct 2024 13:16:59 -0400 Subject: [PATCH 5/5] Fix starting condition --- src/structural_transformation/partial_state_selection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index 0eab0da4fb..18886e6572 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -235,7 +235,7 @@ function prop_state_priority!(sp, graph) end end for v in vertices(graph) - if all(x -> iszero(sp[v]), inneighbors(graph, v)) + if isempty(inneighbors(graph, v)) visit!(sp, graph, v) empty!(visited) end