diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index eff19afb07..2091b41125 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -11,9 +11,35 @@ function generate_initializesystem(sys::ODESystem; default_dd_guess = 0.0, algebraic_only = false, check_units = true, check_defguess = false, + implicit_dae = false, name = nameof(sys), kwargs...) trueobs, eqs = unhack_observed(observed(sys), equations(sys)) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) + + if implicit_dae + pre_simplification_sys = sys + while get_parent(pre_simplification_sys) !== nothing + pre_simplification_sys = get_parent(pre_simplification_sys) + end + schedule = get_schedule(sys) + if schedule === nothing + throw(ArgumentError("The system must be structurally simplified to create an initialization system for an implicit DAE.")) + end + old_eqs = equations(pre_simplification_sys) + inv_dummy_sub = Dict() + for (k, v) in schedule.dummy_sub + if isequal(default_toterm(k), v) + inv_dummy_sub[v] = k + end + end + new_eqs = Symbolics.fast_substitute.([trueobs; eqs], (inv_dummy_sub,)) + filter!(eq -> !isequal(eq.lhs, eq.rhs), new_eqs) + new_sys = ODESystem(new_eqs, get_iv(sys); name = nameof(sys)) + new_sys = dummy_derivative(new_sys; to_index_zero = true, array_hack = false, cse_hack = false) + trueobs = observed(new_sys) + eqs = equations(new_sys) + vars = unique([unknowns(new_sys); getfield.(trueobs, :lhs)]) + end vars_set = Set(vars) # for efficient in-lookup idxs_diff = isdiffeq.(eqs)