Skip to content

Commit

Permalink
Clean up generate_initializesystem()
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle committed Oct 3, 2024
1 parent 85d8d10 commit b237706
Showing 1 changed file with 58 additions and 75 deletions.
133 changes: 58 additions & 75 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,109 +5,92 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
algebraic_only = false,
initialization_eqs = [],
check_units = true,
kwargs...)
sts, eqs = unknowns(sys), equations(sys)
guesses = Dict(),
default_dd_guess = 0.0,
algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff
num_alge = sum(idxs_alge)

# Start the equations list with algebraic equations
eqs_ics = eqs[idxs_alge]
u0 = Vector{Pair}(undef, 0)

# prepare map for dummy derivative substitution
eqs_diff = eqs[idxs_diff]
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
full_diffmap = merge(diffmap, observed_diffmap)
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
)

full_states = unique([sts; getfield.((observed(sys)), :lhs)])
set_full_states = Set(full_states)
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
schedule = getfield(sys, :schedule)

if schedule !== nothing
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
filtered_u0 = u0map
else
filtered_u0 = Pair[]
for x in u0map
y = get(schedule.dummy_sub, x[1], x[1])
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)

if y set_full_states
# defer initialization until defaults are merged below
push!(filtered_u0, y => x[2])
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
if !isnothing(u0map)
for (y, x) in u0map
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# scalarize array # TODO: don't scalarize arrays
_y = collect(y)
for i in eachindex(_y)
push!(filtered_u0, _y[i] => x[2][i])
end
# TODO: don't scalarize arrays
push!(defs, collect(y) .=> x)
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded
# add to the initialization equations
push!(eqs_ics, y ~ x[2])
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
end
filtered_u0 = todict(filtered_u0)
end
else
dd_guess = Dict()
filtered_u0 = todict(u0map)
end

defs = merge(defaults(sys), filtered_u0)

for st in full_states
if st keys(defs)
def = defs[st]

# 2) process other variables
for var in vars
if var keys(defs)
def = defs[var]
if def isa Equation
st keys(guesses) && check_defguess &&
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
# TODO: this behavior is not tested!
var keys(guesses) && check_defguess &&
error("Invalid setup: variable $(var) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, st => guesses[st])
push!(defs, var => guesses[var])
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
push!(eqs_ics, var ~ def)
end
elseif st keys(guesses)
push!(u0, st => guesses[st])
elseif var keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")
error("Invalid setup: variable $(var) has no default value or initial guess")
end
end

# 3) process explicitly provided initialization equations
if !algebraic_only
for eq in [get_initialization_eqs(sys); initialization_eqs]
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
push!(eqs_ics, _eq)
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
for eq in initialization_eqs
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
push!(eqs_ics, eq)
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; observed(sys)]

sys_nl = NonlinearSystem(nleqs,
full_states,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
parameter_dependencies = parameter_dependencies(sys),
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
eqs_ics = [eqs_ics; observed(sys)]
return NonlinearSystem(
eqs_ics, vars, pars;
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
checks = check_units,
name,
kwargs...)

return sys_nl
name, kwargs...
)
end

0 comments on commit b237706

Please sign in to comment.