diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index aa7ad8bd69..bf3bbb227d 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -5,109 +5,92 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi """ function generate_initializesystem(sys::ODESystem; u0map = Dict(), - name = nameof(sys), - guesses = Dict(), check_defguess = false, - default_dd_value = 0.0, - algebraic_only = false, initialization_eqs = [], - check_units = true, - kwargs...) - sts, eqs = unknowns(sys), equations(sys) + guesses = Dict(), + default_dd_guess = 0.0, + algebraic_only = false, + check_units = true, check_defguess = false, + name = nameof(sys), kwargs...) + vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)]) + vars_set = Set(vars) # for efficient in-lookup + + eqs = equations(sys) idxs_diff = isdiffeq.(eqs) idxs_alge = .!idxs_diff - num_alge = sum(idxs_alge) - - # Start the equations list with algebraic equations - eqs_ics = eqs[idxs_alge] - u0 = Vector{Pair}(undef, 0) + # prepare map for dummy derivative substitution eqs_diff = eqs[idxs_diff] - diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs)) - observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=> - Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs))) - full_diffmap = merge(diffmap, observed_diffmap) + D = Differential(get_iv(sys)) + diffmap = merge( + Dict(eq.lhs => eq.rhs for eq in eqs_diff), + Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys)) + ) - full_states = unique([sts; getfield.((observed(sys)), :lhs)]) - set_full_states = Set(full_states) + # 1) process dummy derivatives and u0map into initialization system + eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations + defs = copy(defaults(sys)) # copy so we don't modify sys.defaults guesses = merge(get_guesses(sys), todict(guesses)) schedule = getfield(sys, :schedule) - - if schedule !== nothing - guessmap = [x[1] => get(guesses, x[1], default_dd_value) - for x in schedule.dummy_sub] - dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap)) - if u0map === nothing || isempty(u0map) - filtered_u0 = u0map - else - filtered_u0 = Pair[] - for x in u0map - y = get(schedule.dummy_sub, x[1], x[1]) - y = ModelingToolkit.fixpoint_sub(y, full_diffmap) - - if y ∈ set_full_states - # defer initialization until defaults are merged below - push!(filtered_u0, y => x[2]) + if !isnothing(schedule) + for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub) + # set dummy derivatives to default_dd_guess unless specified + push!(defs, x[1] => get(guesses, x[1], default_dd_guess)) + end + if !isnothing(u0map) + for (y, x) in u0map + y = get(schedule.dummy_sub, y, y) + y = fixpoint_sub(y, diffmap) + if y ∈ vars_set + # variables specified in u0 overrides defaults + push!(defs, y => x) elseif y isa Symbolics.Arr - # scalarize array # TODO: don't scalarize arrays - _y = collect(y) - for i in eachindex(_y) - push!(filtered_u0, _y[i] => x[2][i]) - end + # TODO: don't scalarize arrays + push!(defs, collect(y) .=> x) elseif y isa Symbolics.BasicSymbolic - # y is a derivative expression expanded - # add to the initialization equations - push!(eqs_ics, y ~ x[2]) + # y is a derivative expression expanded; add it to the initialization equations + push!(eqs_ics, y ~ x) else error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.") end end - filtered_u0 = todict(filtered_u0) end - else - dd_guess = Dict() - filtered_u0 = todict(u0map) end - defs = merge(defaults(sys), filtered_u0) - - for st in full_states - if st ∈ keys(defs) - def = defs[st] - + # 2) process other variables + for var in vars + if var ∈ keys(defs) + def = defs[var] if def isa Equation - st ∉ keys(guesses) && check_defguess && - error("Invalid setup: unknown $(st) has an initial condition equation with no guess.") + # TODO: this behavior is not tested! + var ∉ keys(guesses) && check_defguess && + error("Invalid setup: variable $(var) has an initial condition equation with no guess.") push!(eqs_ics, def) - push!(u0, st => guesses[st]) + push!(defs, var => guesses[var]) else - push!(eqs_ics, st ~ def) - push!(u0, st => def) + push!(eqs_ics, var ~ def) end - elseif st ∈ keys(guesses) - push!(u0, st => guesses[st]) + elseif var ∈ keys(guesses) + push!(defs, var => guesses[var]) elseif check_defguess - error("Invalid setup: unknown $(st) has no default value or initial guess") + error("Invalid setup: variable $(var) has no default value or initial guess") end end + # 3) process explicitly provided initialization equations if !algebraic_only - for eq in [get_initialization_eqs(sys); initialization_eqs] - _eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap) - push!(eqs_ics, _eq) + initialization_eqs = [get_initialization_eqs(sys); initialization_eqs] + for eq in initialization_eqs + eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives + push!(eqs_ics, eq) end end - pars = [parameters(sys); get_iv(sys)] - nleqs = [eqs_ics; observed(sys)] - - sys_nl = NonlinearSystem(nleqs, - full_states, - pars; - defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess), - parameter_dependencies = parameter_dependencies(sys), + pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter + eqs_ics = [eqs_ics; observed(sys)] + return NonlinearSystem( + eqs_ics, vars, pars; + defaults = defs, parameter_dependencies = parameter_dependencies(sys), checks = check_units, - name, - kwargs...) - - return sys_nl + name, kwargs... + ) end