From 2dccdc39662477335a3ffe271e584df164c94f38 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 28 May 2024 15:46:47 +0530 Subject: [PATCH 1/4] feat: allow parameters in ODESystem to be unknowns in initialization system --- src/systems/nonlinear/initializesystem.jl | 29 ++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 2421f20bf2..8c997715dc 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -5,6 +5,7 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi """ function generate_initializesystem(sys::ODESystem; u0map = Dict(), + pmap = Dict(), name = nameof(sys), guesses = Dict(), check_defguess = false, default_dd_value = 0.0, @@ -69,6 +70,32 @@ function generate_initializesystem(sys::ODESystem; defs = merge(defaults(sys), filtered_u0) guesses = merge(get_guesses(sys), todict(guesses), dd_guess) + all_params = parameters(sys) + pars = [parameters(sys); get_iv(sys)] + paramsubs = Dict() + for p in all_params + haskey(pmap, p) && continue + paramsubs[p] = tovar(p) + push!(full_states, tovar(p)) + deleteat!(pars, findfirst(isequal(p), pars)) + if haskey(defs, p) + def = defs[p] + if def isa Equation + p ∉ keys(guesses) && check_defguess && + error("Invalid setup: parameter $(p) has an initial condition equation with no guess.") + push!(eqs_ics, def) + push!(u0, p => guesses[p]) + else + push!(eqs_ics, p ~ def) + push!(u0, p => def) + end + elseif haskey(guesses, p) + push!(u0, p => guesses[p]) + elseif check_defguess + error("Invalid setup: parameter $(p) has no default value or initial guess") + end + end + if !algebraic_only for st in full_states if st ∈ keys(defs) @@ -91,12 +118,12 @@ function generate_initializesystem(sys::ODESystem; end end - pars = [parameters(sys); get_iv(sys)] nleqs = if algebraic_only [eqs_ics; observed(sys)] else [eqs_ics; get_initialization_eqs(sys); observed(sys)] end + nleqs = fast_substitute(nleqs, paramsubs) sys_nl = NonlinearSystem(nleqs, full_states, From 41fe639d409565c4a4cb3f78c2e9d73adb1ac898 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 28 May 2024 15:47:18 +0530 Subject: [PATCH 2/4] feat: support unknown parameters during initialization --- src/systems/diffeqs/abstractodesystem.jl | 69 ++++++++++++++++++------ 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index b976945f79..b23ff7e51d 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -324,6 +324,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, split_idxs = nothing, initializeprob = nothing, initializeprobmap = nothing, + initializeprob_updatep! = nothing, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`") @@ -506,7 +507,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, sparsity = sparsity ? jacobian_sparsity(sys) : nothing, analytic = analytic, initializeprob = initializeprob, - initializeprobmap = initializeprobmap) + initializeprobmap = initializeprobmap, + initializeprob_updatep! = initializeprob_updatep!) end """ @@ -538,6 +540,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) checkbounds = false, initializeprob = nothing, initializeprobmap = nothing, + initializeprob_updatep! = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`") @@ -611,7 +614,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) jac_prototype = jac_prototype, observed = observedfun, initializeprob = initializeprob, - initializeprobmap = initializeprobmap) + initializeprobmap = initializeprobmap, + initializeprob_updatep! = initializeprob_updatep!) end function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...) @@ -862,7 +866,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; varmap = canonicalize_varmap(varmap) varlist = collect(map(unwrap, dvs)) missingvars = setdiff(varlist, collect(keys(varmap))) - # Append zeros to the variables which are determined by the initialization system # This essentially bypasses the check for if initial conditions are defined for DAEs # since they will be checked in the initialization problem's construction @@ -873,11 +876,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap)) elseif parammap isa AbstractArray if isempty(parammap) - parammap = SciMLBase.NullParameters() + parammap = Dict() else parammap = Dict(unwrap.(parameters(sys)) .=> parammap) end + elseif parammap === nothing || parammap isa SciMLBase.NullParameters + parammap = Dict() end + missingpars = setdiff(parameters(sys), keys(parammap)) if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing clockedparammap = Dict() @@ -886,7 +892,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; v = unwrap(v) is_discrete_domain(v) || continue op = operation(v) - if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() && + if !isa(op, Symbolics.Operator) && !isempty(parammap) && haskey(parammap, v) error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).") end @@ -909,7 +915,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; # TODO: make it work with clocks # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first if sys isa ODESystem && build_initializeprob && - (implicit_dae || !isempty(missingvars)) && + (implicit_dae || !isempty(missingvars) || !isempty(missingpars)) && all(isequal(Continuous()), ci.var_domain) && ModelingToolkit.get_tearing_state(sys) !== nothing && t !== nothing @@ -921,15 +927,43 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; end initializeprob = ModelingToolkit.InitializationProblem( sys, t, u0map, parammap; guesses, warn_initialize_determined) - initializeprobmap = getu(initializeprob, unknowns(sys)) - + unks = unknowns(sys) + initializeprobmap = isempty(unks) ? (_...) -> nothing : + getu(initializeprob, unknowns(sys)) + if any(p -> is_variable(initializeprob, p) || is_observed(initializeprob, p), + parameters(sys)) + punknowns = [p + for p in parameters(sys) + if is_variable(initializeprob, p) || + is_observed(initializeprob, p)] + initializeprob_updatep! = let getter = getu(initializeprob, tovar.(punknowns)), + setter = setp(sys, punknowns) + + function (ps, initsol) + setter(ps, getter(initsol)) + end + end + else + punknowns = [] + initializeprob_updatep! = nothing + end zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0) + zeropars = Dict() + for p in punknowns + zeropars[p] = if Symbolics.isarraysymbolic(p) + collect(unwrap.(zero(p))) + else + unwrap(zero(p)) + end + end trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map)) u0map isa StaticArraysCore.StaticArray && (trueinit = SVector{length(trueinit)}(trueinit)) else initializeprob = nothing initializeprobmap = nothing + initializeprob_updatep! = nothing + zeropars = Dict() trueinit = u0map end @@ -940,7 +974,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; parammap == SciMLBase.NullParameters() && isempty(defs) nothing else - MTKParameters(sys, parammap, trueinit) + if parammap === nothing || parammap == SciMLBase.NullParameters() + parammap = Dict() + else + parammap = todict(parammap) + end + MTKParameters(sys, merge(parammap, zeropars), trueinit) end else u0, p, defs = get_u0_p(sys, @@ -975,6 +1014,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; sparse = sparse, eval_expression = eval_expression, initializeprob = initializeprob, initializeprobmap = initializeprobmap, + initializeprob_updatep! = initializeprob_updatep!, kwargs...) implicit_dae ? (f, du0, u0, p) : (f, u0, p) end @@ -1602,13 +1642,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`") end + parammap = parammap isa SciMLBase.NullParameters ? Dict() : todict(parammap) if isempty(u0map) && get_initializesystem(sys) !== nothing isys = get_initializesystem(sys) elseif isempty(u0map) && get_initializesystem(sys) === nothing - isys = structural_simplify(generate_initializesystem(sys); fully_determined = false) + isys = structural_simplify( + generate_initializesystem(sys; pmap = parammap); fully_determined = false) else isys = structural_simplify( - generate_initializesystem(sys; u0map); fully_determined = false) + generate_initializesystem(sys; u0map, pmap = parammap); fully_determined = false) end uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)]) @@ -1628,10 +1670,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, if warn_initialize_determined && neqs < nunknown @warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false." end - - parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ? - [get_iv(sys) => t] : - merge(todict(parammap), Dict(get_iv(sys) => t)) + parammap[get_iv(sys)] = t if isempty(u0map) u0map = Dict() end From ccfed66b8a25b4ba8e301fa0521cc890dcb65711 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 May 2024 15:41:58 +0530 Subject: [PATCH 3/4] fixup! feat: support unknown parameters during initialization --- src/systems/abstractsystem.jl | 5 +- src/systems/diffeqs/abstractodesystem.jl | 69 ++++++++++++------------ 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 739451509b..c1557c9495 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1849,6 +1849,8 @@ function linearization_function(sys::AbstractSystem, inputs, end initfn = NonlinearFunction(initsys) initprobmap = getu(initsys, unknowns(sys)) + initprob_init! = generate_initializeprob_init(sys, initsys) + initprob_update! = generate_initializeprob_update(sys, initsys) ps = full_parameters(sys) lin_fun = let diff_idxs = diff_idxs, alge_idxs = alge_idxs, @@ -1856,7 +1858,8 @@ function linearization_function(sys::AbstractSystem, inputs, sts = unknowns(sys), get_initprob_u_p = get_initprob_u_p, fun = ODEFunction{true, SciMLBase.FullSpecialize}( - sys, unknowns(sys), ps; initializeprobmap = initprobmap), + sys, unknowns(sys), ps; initializeprob_init! = initprob_init!, + initializeprob_update! = initprob_update!), initfn = initfn, h = build_explicit_observed_function(sys, outputs), chunk = ForwardDiff.Chunk(input_idxs), diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index b23ff7e51d..15ded1c372 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -280,6 +280,25 @@ function isautonomous(sys::AbstractODESystem) all(iszero, tgrad) end +struct GetAndSetFunctor{G, S} + getter::G + setter::S +end + +function (gs::GetAndSetFunctor)(dest, source) + gs.setter(dest, gs.getter(source)) +end + +function generate_initializeprob_init(sys::AbstractSystem, initsys::AbstractSystem) + syms = vcat(variable_symbols(initsys), parameter_symbols(initsys)) + return GetAndSetFunctor(getu(sys, syms), setu(initsys, syms)) +end + +function generate_initializeprob_update(sys::AbstractSystem, initsys::AbstractSystem) + syms = vcat(variable_symbols(sys), parameter_symbols(sys)) + return GetAndSetFunctor(getu(initsys, syms), setu(sys, syms)) +end + """ ```julia DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys), @@ -323,8 +342,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, analytic = nothing, split_idxs = nothing, initializeprob = nothing, - initializeprobmap = nothing, - initializeprob_updatep! = nothing, + initializeprob_init! = nothing, + initializeprob_update! = nothing, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`") @@ -507,8 +526,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, sparsity = sparsity ? jacobian_sparsity(sys) : nothing, analytic = analytic, initializeprob = initializeprob, - initializeprobmap = initializeprobmap, - initializeprob_updatep! = initializeprob_updatep!) + initializeprob_init! = initializeprob_init!, + initializeprob_update! = initializeprob_update!) end """ @@ -539,8 +558,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) eval_module = @__MODULE__, checkbounds = false, initializeprob = nothing, - initializeprobmap = nothing, - initializeprob_updatep! = nothing, + initializeprob_init! = nothing, + initializeprob_update! = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`") @@ -614,8 +633,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) jac_prototype = jac_prototype, observed = observedfun, initializeprob = initializeprob, - initializeprobmap = initializeprobmap, - initializeprob_updatep! = initializeprob_updatep!) + initializeprob_init! = initializeprob_init!, + initializeprob_update! = initializeprob_update!) end function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...) @@ -927,26 +946,11 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; end initializeprob = ModelingToolkit.InitializationProblem( sys, t, u0map, parammap; guesses, warn_initialize_determined) - unks = unknowns(sys) - initializeprobmap = isempty(unks) ? (_...) -> nothing : - getu(initializeprob, unknowns(sys)) - if any(p -> is_variable(initializeprob, p) || is_observed(initializeprob, p), - parameters(sys)) - punknowns = [p - for p in parameters(sys) - if is_variable(initializeprob, p) || - is_observed(initializeprob, p)] - initializeprob_updatep! = let getter = getu(initializeprob, tovar.(punknowns)), - setter = setp(sys, punknowns) - - function (ps, initsol) - setter(ps, getter(initsol)) - end - end - else - punknowns = [] - initializeprob_updatep! = nothing - end + punknowns = [p + for p in parameters(sys) + if is_variable(initializeprob, p) || is_observed(initializeprob, p)] + initializeprob_init! = generate_initializeprob_init(sys, initializeprob.f.sys) + initializeprob_update! = generate_initializeprob_update(sys, initializeprob.f.sys) zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0) zeropars = Dict() for p in punknowns @@ -961,9 +965,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; (trueinit = SVector{length(trueinit)}(trueinit)) else initializeprob = nothing - initializeprobmap = nothing - initializeprob_updatep! = nothing zeropars = Dict() + initializeprob_init! = nothing + initializeprob_update! = nothing trueinit = u0map end @@ -1012,9 +1016,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; checkbounds = checkbounds, p = p, linenumbers = linenumbers, parallel = parallel, simplify = simplify, sparse = sparse, eval_expression = eval_expression, - initializeprob = initializeprob, - initializeprobmap = initializeprobmap, - initializeprob_updatep! = initializeprob_updatep!, + initializeprob = initializeprob, initializeprob_init! = initializeprob_init!, + initializeprob_update! = initializeprob_update!, kwargs...) implicit_dae ? (f, du0, u0, p) : (f, u0, p) end From f217bd39694e26f01b93c921c0d75ac09a3445e1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 May 2024 15:42:14 +0530 Subject: [PATCH 4/4] fix: fix SII.observed for time-independent systems --- src/systems/abstractsystem.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index c1557c9495..b75bd7dabf 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -495,10 +495,18 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym) end function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym) - return let _fn = build_explicit_observed_function(sys, sym) - fn(u, p, t) = _fn(u, p, t) - fn(u, p::MTKParameters, t) = _fn(u, p..., t) - fn + if is_time_dependent(sys) + return let _fn = build_explicit_observed_function(sys, sym) + fn(u, p, t) = _fn(u, p, t) + fn(u, p::MTKParameters, t) = _fn(u, p..., t) + fn + end + else + return let _fn = build_explicit_observed_function(sys, sym) + fn2(u, p) = _fn(u, p) + fn2(u, p::MTKParameters) = _fn(u, p...) + fn2 + end end end