From 298c29c88ae432cd4c4b6280484187a4c1d799c9 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 15 Feb 2024 17:33:52 -0500 Subject: [PATCH 01/29] Array equations/variables support in `structural_simplify` --- src/systems/abstractsystem.jl | 4 ++-- src/systems/connectors.jl | 23 +++++++++++++++-------- src/systems/diffeqs/odesystem.jl | 10 +++++----- src/systems/systemstructure.jl | 16 ++++++++++++++-- src/utils.jl | 15 +++++++++------ test/odesystem.jl | 19 +++++++------------ 6 files changed, 52 insertions(+), 35 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 65e6ee7228..ee6eadd200 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -703,9 +703,9 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys # metadata from the rescoped variable rescoped = renamespace(n, O) similarterm(O, operation(rescoped), renamed, - metadata = metadata(rescoped))::T + metadata = metadata(rescoped)) else - similarterm(O, operation(O), renamed, metadata = metadata(O))::T + similarterm(O, operation(O), renamed, metadata = metadata(O)) end elseif isvariable(O) renamespace(n, O) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 7419d5c3db..11cfcd4e59 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -292,10 +292,12 @@ function connection2set!(connectionsets, namespace, ss, isouter) end end -function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing) +function generate_connection_set( + sys::AbstractSystem, find = nothing, replace = nothing; scalarize = false) connectionsets = ConnectionSet[] domain_csets = ConnectionSet[] - sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace) + sys = generate_connection_set!( + connectionsets, domain_csets, sys, find, replace, scalarize) csets = merge(connectionsets) domain_csets = merge([csets; domain_csets], true) @@ -303,7 +305,7 @@ function generate_connection_set(sys::AbstractSystem, find = nothing, replace = end function generate_connection_set!(connectionsets, domain_csets, - sys::AbstractSystem, find, replace, namespace = nothing) + sys::AbstractSystem, find, replace, scalarize, namespace = nothing) subsys = get_systems(sys) isouter = generate_isouter(sys) @@ -325,8 +327,13 @@ function generate_connection_set!(connectionsets, domain_csets, end neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq) else - if lhs isa Number || lhs isa Symbolic - push!(eqs, eq) # split connections and equations + if lhs isa Number || lhs isa Symbolic || eltype(lhs) <: Symbolic + # split connections and equations + if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray + append!(eqs, Symbolics.scalarize(eq)) + else + push!(eqs, eq) + end elseif lhs isa Connection && get_systems(lhs) === :domain connection2set!(domain_csets, namespace, get_systems(rhs), isouter) else @@ -356,7 +363,7 @@ function generate_connection_set!(connectionsets, domain_csets, end @set! sys.systems = map( s -> generate_connection_set!(connectionsets, domain_csets, s, - find, replace, + find, replace, scalarize, renamespace(namespace, s)), subsys) @set! sys.eqs = eqs @@ -471,8 +478,8 @@ function domain_defaults(sys, domain_csets) end function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing; - debug = false, tol = 1e-10) - sys, (csets, domain_csets) = generate_connection_set(sys, find, replace) + debug = false, tol = 1e-10, scalarize = true) + sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize) ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) _sys = expand_instream(instream_csets, sys; debug = debug, tol = tol) sys = flatten(sys, true) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 5e238cb7c8..935b1955fb 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -195,13 +195,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; gui_metadata = nothing) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - deqs = scalarize(deqs) + #deqs = scalarize(deqs) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." - iv′ = value(scalarize(iv)) - ps′ = value.(scalarize(ps)) - ctrl′ = value.(scalarize(controls)) - dvs′ = value.(scalarize(dvs)) + iv′ = value(iv) + ps′ = value.(ps) + ctrl′ = value.(controls) + dvs′ = value.(dvs) dvs′ = filter(x -> !isdelay(x, iv), dvs′) if !(isempty(default_u0) && isempty(default_p)) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index c7e3e2230c..3f00410b10 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -268,6 +268,7 @@ function TearingState(sys; quick_cancel = false, check = true) end vars = OrderedSet() + varsvec = [] for (i, eq′) in enumerate(eqs) if eq′.lhs isa Connection check ? error("$(nameof(sys)) has unexpanded `connect` statements") : @@ -282,9 +283,18 @@ function TearingState(sys; quick_cancel = false, check = true) eq = 0 ~ rhs - lhs end vars!(vars, eq.rhs, op = Symbolics.Operator) + for v in vars + v = scalarize(v) + if v isa AbstractArray + v = setmetadata.(v, VariableIrreducible, true) + append!(varsvec, v) + else + push!(varsvec, v) + end + end isalgeq = true unknownvars = [] - for var in vars + for var in varsvec ModelingToolkit.isdelay(var, iv) && continue set_incidence = true @label ANOTHER_VAR @@ -340,6 +350,7 @@ function TearingState(sys; quick_cancel = false, check = true) push!(symbolic_incidence, copy(unknownvars)) empty!(unknownvars) empty!(vars) + empty!(varsvec) if isalgeq eqs[i] = eq else @@ -350,9 +361,10 @@ function TearingState(sys; quick_cancel = false, check = true) # sort `fullvars` such that the mass matrix is as diagonal as possible. dervaridxs = collect(dervaridxs) sorted_fullvars = OrderedSet(fullvars[dervaridxs]) + var_to_old_var = Dict(zip(fullvars, fullvars)) for dervaridx in dervaridxs dervar = fullvars[dervaridx] - diffvar = lower_order_var(dervar) + diffvar = var_to_old_var[lower_order_var(dervar)] if !(diffvar in sorted_fullvars) push!(sorted_fullvars, diffvar) end diff --git a/src/utils.jl b/src/utils.jl index 7a329ce96b..a8d54e4a34 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -348,19 +348,22 @@ function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end function vars!(vars, O; op = Differential) + !istree(O) && return vars + if operation(O) === (getindex) + arr = first(arguments(O)) + !istree(arr) && return vars + operation(arr) isa op && return push!(vars, O) + isvariable(operation(O)) && return push!(vars, O) + end + if isvariable(O) return push!(vars, O) end - !istree(O) && return vars operation(O) isa op && return push!(vars, O) - if operation(O) === (getindex) && - isvariable(first(arguments(O))) - return push!(vars, O) - end - isvariable(operation(O)) && push!(vars, O) + for arg in arguments(O) vars!(vars, arg; op = op) end diff --git a/test/odesystem.jl b/test/odesystem.jl index 56bf2376d5..dd89428f16 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -517,21 +517,16 @@ eqs = [D(x) ~ x * y using StaticArrays using SymbolicUtils: term using SymbolicUtils.Code -using Symbolics: unwrap, wrap -function foo(a::Num, ms::AbstractVector) - a = unwrap(a) - ms = map(unwrap, ms) - wrap(term(foo, a, term(SVector, ms...))) -end +using Symbolics: unwrap, wrap, @register_symbolic foo(a, ms::AbstractVector) = a + sum(ms) -@variables x(t) ms(t)[1:3] -ms = collect(ms) -eqs = [D(x) ~ foo(x, ms); D.(ms) .~ 1] +@register_symbolic foo(a, ms::AbstractVector) +@variables t x(t) ms(t)[1:3] +D = Differential(t) +eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)] @named sys = ODESystem(eqs, t, [x; ms], []) @named emptysys = ODESystem(Equation[], t) -@named outersys = compose(emptysys, sys) -outersys = complete(outersys) -prob = ODEProblem(outersys, [sys.x => 1.0; collect(sys.ms) .=> 1:3], (0, 1.0)) +@mtkbuild outersys = compose(emptysys, sys) +prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0)) @test_nowarn solve(prob, Tsit5()) # x/x From 75d343fd3dc7e4e971bd66504396d9abaac7244a Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 16 Feb 2024 13:22:15 -0500 Subject: [PATCH 02/29] Support array states in ODEProblem/ODEFunction Co-authored-by: Aayush Sabharwal --- src/systems/diffeqs/abstractodesystem.jl | 27 ++++++++++++++++++++++++ src/utils.jl | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 561d13cc28..5c433b321d 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -152,6 +152,21 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), nothing, isdde = false, kwargs...) + array_vars = Dict{Any, Vector{Int}}() + for (j, x) in enumerate(dvs) + if istree(x) && operation(x) == getindex + arg = arguments(x)[1] + inds = get!(() -> Int[], array_vars, arg) + push!(inds, j) + end + end + subs = Dict() + for (k, inds) in array_vars + if inds == (inds′ = inds[1]:inds[end]) + inds = inds′ + end + subs[k] = term(view, Sym{Any}(Symbol("ˍ₋arg1")), inds) + end if isdde eqs = delay_to_function(sys) else @@ -164,6 +179,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), # substitute x(t) by just x rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] : [eq.rhs for eq in eqs] + rhss = fast_substitute(rhss, subs) # TODO: add an optional check on the ordering of observed equations u = map(x -> time_varying_as_func(value(x), sys), dvs) @@ -764,6 +780,17 @@ function get_u0_p(sys, defs = mergedefaults(defs, parammap, ps) end defs = mergedefaults(defs, u0map, dvs) + for (k, v) in defs + if Symbolics.isarraysymbolic(k) + ks = scalarize(k) + length(ks) == length(v) || error("$k has default value $v with unmatched size") + for (kk, vv) in zip(ks, v) + if !haskey(defs, kk) + defs[kk] = vv + end + end + end + end if symbolic_u0 u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) diff --git a/src/utils.jl b/src/utils.jl index a8d54e4a34..dd1bb5340c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -810,7 +810,7 @@ end function fast_substitute(eq::T, subs::Pair) where {T <: Eq} T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs)) end -fast_substitute(eqs::AbstractArray{<:Eq}, subs) = fast_substitute.(eqs, (subs,)) +fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,)) fast_substitute(a, b) = substitute(a, b) function fast_substitute(expr, pair::Pair) a, b = pair From ae6bd4900f101ebd7bb2e055060f9a382a0bfa3a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 Feb 2024 17:37:37 +0530 Subject: [PATCH 03/29] feat: scalarize equations in ODESystem, fix vars! and namespace_expr --- src/systems/abstractsystem.jl | 3 +++ src/systems/diffeqs/odesystem.jl | 4 +--- src/utils.jl | 14 +++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ee6eadd200..efacfe6cc6 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -704,6 +704,9 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys rescoped = renamespace(n, O) similarterm(O, operation(rescoped), renamed, metadata = metadata(rescoped)) + elseif Symbolics.isarraysymbolic(O) + # promote_symtype doesn't work for array symbolics + similarterm(O, operation(O), renamed, symtype(O), metadata = metadata(O)) else similarterm(O, operation(O), renamed, metadata = metadata(O)) end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 935b1955fb..f2035cee50 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -195,9 +195,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; gui_metadata = nothing) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - #deqs = scalarize(deqs) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." - + deqs = reduce(vcat, scalarize(deqs); init = Equation[]) iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) @@ -236,7 +235,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; end function ODESystem(eqs, iv; kwargs...) - eqs = scalarize(eqs) # NOTE: this assumes that the order of algebraic equations doesn't matter diffvars = OrderedSet() allunknowns = OrderedSet() diff --git a/src/utils.jl b/src/utils.jl index dd1bb5340c..d4e068388b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -342,22 +342,22 @@ v == Set([D(y), u]) function vars(exprs::Symbolic; op = Differential) istree(exprs) ? vars([exprs]; op = op) : Set([exprs]) end +vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) +vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op) vars(exprs; op = Differential) = foldl((x, y) -> vars!(x, y; op = op), exprs; init = Set()) vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end function vars!(vars, O; op = Differential) + if isvariable(O) && !(istree(O) && operation(O) === getindex) + return push!(vars, O) + end + !istree(O) && return vars if operation(O) === (getindex) arr = first(arguments(O)) - !istree(arr) && return vars - operation(arr) isa op && return push!(vars, O) - isvariable(operation(O)) && return push!(vars, O) - end - - if isvariable(O) - return push!(vars, O) + return vars!(vars, arr) end operation(O) isa op && return push!(vars, O) From 0287c1d4fb70409275114a3ed4086ac7ce1c0dcf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 15 Feb 2024 17:41:10 +0530 Subject: [PATCH 04/29] feat!: do not scalarize parameters, fix some tests --- src/systems/abstractsystem.jl | 3 +- src/systems/callbacks.jl | 7 +- src/systems/clock_inference.jl | 53 ++-- src/systems/diffeqs/abstractodesystem.jl | 49 ++-- src/systems/diffeqs/odesystem.jl | 4 +- src/systems/diffeqs/sdesystem.jl | 1 - src/systems/index_cache.jl | 114 ++++++-- .../optimization/constraints_system.jl | 2 +- .../optimization/optimizationsystem.jl | 2 +- src/systems/parameter_buffer.jl | 261 +++++++++++------- src/utils.jl | 2 +- src/variables.jl | 2 +- test/mass_matrix.jl | 7 +- test/split_parameters.jl | 42 +-- 14 files changed, 349 insertions(+), 200 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index efacfe6cc6..d177f78c33 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1521,13 +1521,12 @@ function linearization_function(sys::AbstractSystem, inputs, sys = ssys x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op) u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true) + ps = parameters(sys) if has_index_cache(sys) && get_index_cache(sys) !== nothing p = MTKParameters(sys, p) - ps = reorder_parameters(sys, parameters(sys)) else p = _p p, split_idxs = split_parameters_by_type(p) - ps = parameters(sys) if p isa Tuple ps = Base.Fix1(getindex, ps).(split_idxs) ps = (ps...,) #if p is Tuple, ps should be Tuple diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 5efea3a77d..e0e0e7e7c8 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -390,11 +390,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin if has_index_cache(sys) && get_index_cache(sys) !== nothing ic = get_index_cache(sys) update_inds = map(update_vars) do sym - @unpack portion, idx = parameter_index(sys, sym) - if portion == SciMLStructures.Discrete() - idx += length(ic.param_idx) - end - idx + pind = parameter_index(sys, sym) + discrete_linear_index(ic, pind) end else psind = Dict(reverse(en) for en in enumerate(ps)) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 1e6da7786a..07b684de88 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -146,6 +146,7 @@ function generate_discrete_affect( @static if VERSION < v"1.7" error("The `generate_discrete_affect` function requires at least Julia 1.7") end + has_index_cache(osys) && get_index_cache(osys) !== nothing || error("System must have index_cache for clock support") out = Sym{Any}(:out) appended_parameters = parameters(syss[continuous_id]) param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p) @@ -169,14 +170,14 @@ function generate_discrete_affect( push!(fullvars, s) end needed_disc_to_cont_obs = [] - disc_to_cont_idxs = Int[] + disc_to_cont_idxs = ParameterIndex[] for v in inputs[continuous_id] vv = arguments(v)[1] if vv in fullvars push!(needed_disc_to_cont_obs, vv) # @show param_to_idx[v] v # @assert param_to_idx[v].portion isa SciMLStructures.Discrete # TOOD: remove - push!(disc_to_cont_idxs, param_to_idx[v].idx) + push!(disc_to_cont_idxs, param_to_idx[v]) end end append!(appended_parameters, input, unknowns(sys)) @@ -201,39 +202,36 @@ function generate_discrete_affect( ], [], let_block) - cont_to_disc_idxs = [parameter_index(osys, sym).idx for sym in input] - disc_range = [parameter_index(osys, sym).idx for sym in unknowns(sys)] + cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] + disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] save_vec = Expr(:ref, :Float64) for unk in unknowns(sys) - idx = parameter_index(osys, unk).idx - push!(save_vec.args, :(discretes[$idx])) + idx = parameter_index(osys, unk) + push!(save_vec.args, :($(parameter_values)(p, $idx))) end empty_disc = isempty(disc_range) disc_init = :(function (p, t) d2c_obs = $disc_to_cont_obs + disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range) + result = d2c_obs(disc_state, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + # prevent multiple updates to dependents + _set_parameter_unchecked!(p, val, i; update_dependent = false) + end discretes, repack, _ = $(SciMLStructures.canonicalize)( $(SciMLStructures.Discrete()), p) - d2c_view = view(discretes, $disc_to_cont_idxs) - disc_state = view(discretes, $disc_range) - copyto!(d2c_view, d2c_obs(disc_state, p..., t)) - repack(discretes) + repack(discretes) # to force recalculation of dependents end) # @show disc_to_cont_idxs # @show cont_to_disc_idxs # @show disc_range - affect! = :(function (integrator, saved_values) @unpack u, p, t = integrator c2d_obs = $cont_to_disc_obs d2c_obs = $disc_to_cont_obs - # Like Sample - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - c2d_view = view(discretes, $cont_to_disc_idxs) - # Like Hold - d2c_view = view(discretes, $disc_to_cont_idxs) - disc_unknowns = view(discretes, $disc_range) + # TODO: find a way to do this without allocating + disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range] disc = $disc push!(saved_values.t, t) @@ -248,12 +246,25 @@ function generate_discrete_affect( # d2c comes last # @show t # @show "incoming", p - copyto!(c2d_view, c2d_obs(integrator.u, p..., t)) + result = c2d_obs(integrator.u, p..., t) + for (val, i) in zip(result, $cont_to_disc_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end # @show "after c2d", p - $empty_disc || disc(disc_unknowns, integrator.u, p..., t) + if !$empty_disc + disc(disc_unknowns, integrator.u, p..., t) + for (val, i) in zip(disc_unknowns, $disc_range) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + end # @show "after state update", p - copyto!(d2c_view, d2c_obs(disc_unknowns, p..., t)) + result = d2c_obs(disc_unknowns, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end # @show "after d2c", p + discretes, repack, _ = $(SciMLStructures.canonicalize)( + $(SciMLStructures.Discrete()), p) repack(discretes) end) sv = SavedValues(Float64, Vector{Float64}) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 5c433b321d..0454d02952 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -183,11 +183,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), # TODO: add an optional check on the ordering of observed equations u = map(x -> time_varying_as_func(value(x), sys), dvs) - p = if has_index_cache(sys) && get_index_cache(sys) !== nothing - reorder_parameters(get_index_cache(sys), ps isa Tuple ? reduce(vcat, ps) : ps) - else - (map(x -> time_varying_as_func(value(x), sys), ps),) - end + p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps)) t = get_iv(sys) if isdde @@ -802,6 +798,23 @@ function get_u0_p(sys, u0, p, defs end +function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false) + dvs = unknowns(sys) + ps = parameters(sys) + defs = defaults(sys) + if parammap !== nothing + defs = mergedefaults(defs, parammap, ps) + end + defs = mergedefaults(defs, u0map, dvs) + + if symbolic_u0 + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) + else + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) + end + return u0, defs +end + function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; implicit_dae = false, du0map = nothing, version = nothing, tgrad = false, @@ -820,20 +833,24 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; ps = parameters(sys) iv = get_iv(sys) - u0, _p, defs = get_u0_p(sys, - u0map, - parammap; - tofloat, - use_union, - symbolic_u0) - if u0 !== nothing - u0 = u0_constructor(u0) - end - if has_index_cache(sys) && get_index_cache(sys) !== nothing + u0, defs = get_u0(sys, u0map, parammap; symbolic_u0) p = MTKParameters(sys, parammap) else - p = _p + u0, p, defs = get_u0_p(sys, + u0map, + parammap; + tofloat, + use_union, + symbolic_u0) + p, split_idxs = split_parameters_by_type(p) + if p isa Tuple + ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + ps = (ps...,) #if p is Tuple, ps should be Tuple + end + end + if u0 !== nothing + u0 = u0_constructor(u0) end if implicit_dae && du0map !== nothing diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index f2035cee50..b3254f8b47 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -202,7 +202,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; ctrl′ = value.(controls) dvs′ = value.(dvs) dvs′ = filter(x -> !isdelay(x, iv), dvs′) - if !(isempty(default_u0) && isempty(default_p)) Base.depwarn( "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", @@ -210,7 +209,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; end defaults = todict(defaults) defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults)) - var_to_name = Dict() process_variables!(var_to_name, defaults, dvs′) process_variables!(var_to_name, defaults, ps′) @@ -277,7 +275,7 @@ function ODESystem(eqs, iv; kwargs...) algevars = setdiff(allunknowns, diffvars) # the orders here are very important! return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv, - collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...) + collect(Iterators.flatten((diffvars, algevars))), collect(ps); kwargs...) end # NOTE: equality does not check cached Jacobian diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 751502275c..e873e27ef0 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -407,7 +407,6 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys), error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`") end dvs = scalarize.(dvs) - ps = scalarize.(ps) f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...) f_oop, f_iip = eval_expression ? diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 42bb90e804..11cd9fc2bf 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -1,7 +1,8 @@ abstract type SymbolHash end function getsymbolhash(sym) - hasmetadata(sym, SymbolHash) ? getmetadata(sym, SymbolHash) : hash(unwrap(sym)) + sym = unwrap(sym) + hasmetadata(sym, SymbolHash) ? getmetadata(sym, SymbolHash) : hash(sym) end struct BufferTemplate @@ -9,21 +10,28 @@ struct BufferTemplate length::Int end -struct ParameterIndex{P} +const DEPENDENT_PORTION = :dependent +const NONNUMERIC_PORTION = :nonnumeric + +struct ParameterIndex{P, I} portion::P - idx::Int + idx::I end +const IndexMap = Dict{UInt, Tuple{Int, Int}} + struct IndexCache unknown_idx::Dict{UInt, Int} - discrete_idx::Dict{UInt, Int} - param_idx::Dict{UInt, Int} - constant_idx::Dict{UInt, Int} - dependent_idx::Dict{UInt, Int} + discrete_idx::IndexMap + param_idx::IndexMap + constant_idx::IndexMap + dependent_idx::IndexMap + nonnumeric_idx::IndexMap discrete_buffer_sizes::Vector{BufferTemplate} param_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} dependent_buffer_sizes::Vector{BufferTemplate} + nonnumeric_buffer_sizes::Vector{BufferTemplate} end function IndexCache(sys::AbstractSystem) @@ -38,6 +46,7 @@ function IndexCache(sys::AbstractSystem) tunable_buffers = Dict{DataType, Set{BasicSymbolic}}() constant_buffers = Dict{DataType, Set{BasicSymbolic}}() dependent_buffers = Dict{DataType, Set{BasicSymbolic}}() + nonnumeric_buffers = Dict{DataType, Set{BasicSymbolic}}() function insert_by_type!(buffers::Dict{DataType, Set{BasicSymbolic}}, sym) sym = unwrap(sym) @@ -83,28 +92,30 @@ function IndexCache(sys::AbstractSystem) haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue insert_by_type!( - if is_discrete_domain(p) - disc_buffers - elseif istunable(p, true) - tunable_buffers + if ctype <: Real || ctype <: Vector{<:Real} + if is_discrete_domain(p) + disc_buffers + elseif istunable(p, true) && size(p) !== Symbolics.Unknown() + tunable_buffers + else + constant_buffers + end else - constant_buffers + nonnumeric_buffers end, p ) end - function get_buffer_sizes_and_idxs(buffers::Dict{DataType, Set{BasicSymbolic}}) - idxs = Dict{UInt, Int}() + function get_buffer_sizes_and_idxs(buffers::Dict{DataType, Set{BasicSymbolic}}, track_linear_index = true) + idxs = IndexMap() buffer_sizes = BufferTemplate[] - idx = 1 - for (T, buf) in buffers - for p in buf + for (i, (T, buf)) in enumerate(buffers) + for (j, p) in enumerate(buf) h = getsymbolhash(p) - idxs[h] = idx + idxs[h] = (i, j) h = getsymbolhash(default_toterm(p)) - idxs[h] = idx - idx += 1 + idxs[h] = (i, j) end push!(buffer_sizes, BufferTemplate(T, length(buf))) end @@ -115,6 +126,7 @@ function IndexCache(sys::AbstractSystem) param_idxs, param_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers) const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers) dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers) + nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers) return IndexCache( unk_idxs, @@ -122,13 +134,47 @@ function IndexCache(sys::AbstractSystem) param_idxs, const_idxs, dependent_idxs, + nonnumeric_idxs, discrete_buffer_sizes, param_buffer_sizes, const_buffer_sizes, - dependent_buffer_sizes + dependent_buffer_sizes, + nonnumeric_buffer_sizes, ) end +function ParameterIndex(ic::IndexCache, p) + p = unwrap(p) + if istree(p) && operation(p) === getindex + sub_idx = Base.tail(arguments(p)) + p = arguments(p)[begin] + else + sub_idx = () + end + h = getsymbolhash(p) + return if haskey(ic.param_idx, h) + ParameterIndex(SciMLStructures.Tunable(), (ic.param_idx[h]..., sub_idx...)) + elseif haskey(ic.discrete_idx, h) + ParameterIndex(SciMLStructures.Discrete(), (ic.discrete_idx[h]..., sub_idx...)) + elseif haskey(ic.constant_idx, h) + ParameterIndex(SciMLStructures.Constants(), (ic.constant_idx[h]..., sub_idx...)) + elseif haskey(ic.dependent_idx, h) + ParameterIndex(DEPENDENT_PORTION, (ic.dependent_idx[h]..., sub_idx...)) + elseif haskey(ic.nonnumeric_idx, h) + ParameterIndex(NONNUMERIC_PORTION, (ic.nonnumeric_idx[h]..., sub_idx...)) + else + nothing + end +end + +function discrete_linear_index(ic::IndexCache, idx::ParameterIndex) + idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected") + ind = sum(temp.length for temp in ic.param_buffer_sizes; init = 0) + ind += sum(temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1); init = 0) + ind += idx.idx[2] + return ind +end + function reorder_parameters(sys::AbstractSystem, ps; kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing reorder_parameters(get_index_cache(sys), ps; kwargs...) @@ -140,28 +186,36 @@ function reorder_parameters(sys::AbstractSystem, ps; kwargs...) end function reorder_parameters(ic::IndexCache, ps; drop_missing = false) - param_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.param_buffer_sizes)...) - disc_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.discrete_buffer_sizes)...) - const_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.constant_buffer_sizes)...) - dep_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.dependent_buffer_sizes)...) + param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.param_buffer_sizes) + disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.discrete_buffer_sizes) + const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.constant_buffer_sizes) + dep_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.dependent_buffer_sizes) + nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.nonnumeric_buffer_sizes) for p in ps h = getsymbolhash(p) if haskey(ic.discrete_idx, h) - disc_buf[ic.discrete_idx[h]] = unwrap(p) + i, j = ic.discrete_idx[h] + disc_buf[i][j] = unwrap(p) elseif haskey(ic.param_idx, h) - param_buf[ic.param_idx[h]] = unwrap(p) + i, j = ic.param_idx[h] + param_buf[i][j] = unwrap(p) elseif haskey(ic.constant_idx, h) - const_buf[ic.constant_idx[h]] = unwrap(p) + i, j = ic.constant_idx[h] + const_buf[i][j] = unwrap(p) elseif haskey(ic.dependent_idx, h) - dep_buf[ic.dependent_idx[h]] = unwrap(p) + i, j = ic.dependent_idx[h] + dep_buf[i][j] = unwrap(p) + elseif haskey(ic.nonnumeric_idx, h) + i, j = ic.nonnumeric_idx[h] + nonnumeric_buf[i][j] = unwrap(p) else error("Invalid parameter $p") end end result = broadcast.( - unwrap, (param_buf.x..., disc_buf.x..., const_buf.x..., dep_buf.x...)) + unwrap, (param_buf..., disc_buf..., const_buf..., nonnumeric_buf..., dep_buf...)) if drop_missing result = map(result) do buf filter(buf) do sym diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl index 2701b12ee3..d17486b0a7 100644 --- a/src/systems/optimization/constraints_system.jl +++ b/src/systems/optimization/constraints_system.jl @@ -119,7 +119,7 @@ function ConstraintsSystem(constraints, unknowns, ps; cstr = value.(Symbolics.canonical_form.(scalarize(constraints))) unknowns′ = value.(scalarize(unknowns)) - ps′ = value.(scalarize(ps)) + ps′ = value.(ps) if !(isempty(default_u0) && isempty(default_p)) Base.depwarn( diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 5af43dd2bc..de16453776 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -98,7 +98,7 @@ function OptimizationSystem(op, unknowns, ps; throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) constraints = value.(scalarize(constraints)) unknowns′ = value.(scalarize(unknowns)) - ps′ = value.(scalarize(ps)) + ps′ = value.(ps) op′ = value(scalarize(op)) if !(isempty(default_u0) && isempty(default_p)) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index e18d1ccbf0..f93900906e 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,8 +1,9 @@ -struct MTKParameters{T, D, C, E, F, G} +struct MTKParameters{T, D, C, E, N, F, G} tunable::T discrete::D constant::C dependent::E + nonnumeric::N dependent_update_iip::F dependent_update_oop::G end @@ -35,22 +36,30 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals for (k, v) in p if !haskey(extra_params, unwrap(k))) end - tunable_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.param_buffer_sizes)...) - disc_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.discrete_buffer_sizes)...) - const_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.constant_buffer_sizes)...) - dep_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.dependent_buffer_sizes)...) + tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.param_buffer_sizes) + disc_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.discrete_buffer_sizes) + const_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.constant_buffer_sizes) + dep_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.dependent_buffer_sizes) + nonnumeric_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.nonnumeric_buffer_sizes) dependencies = Dict{Num, Num}() function set_value(sym, val) h = getsymbolhash(sym) if haskey(ic.param_idx, h) - tunable_buffer[ic.param_idx[h]] = val + i, j = ic.param_idx[h] + tunable_buffer[i][j] = val elseif haskey(ic.discrete_idx, h) - disc_buffer[ic.discrete_idx[h]] = val + i, j = ic.discrete_idx[h] + disc_buffer[i][j] = val elseif haskey(ic.constant_idx, h) - const_buffer[ic.constant_idx[h]] = val + i, j = ic.constant_idx[h] + const_buffer[i][j] = val elseif haskey(ic.dependent_idx, h) - dep_buffer[ic.dependent_idx[h]] = val + i, j = ic.dependent_idx[h] + dep_buffer[i][j] = val dependencies[wrap(sym)] = wrap(p[sym]) + elseif haskey(ic.nonnumeric_idx, h) + i, j = ic.nonnumeric_idx[h] + nonnumeric_buffer[i][j] = val elseif !isequal(default_toterm(sym), sym) set_value(default_toterm(sym), val) else @@ -62,25 +71,16 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals sym = unwrap(sym) ctype = concrete_symtype(sym) val = convert(ctype, fixpoint_sub(val, p)) - if size(sym) == () - set_value(sym, val) - else - if length(sym) != length(val) - error("Size of $sym does not match size of initial value $val") - end - for (i, j) in zip(eachindex(sym), eachindex(val)) - set_value(sym[i], val[j]) - end - end + set_value(sym, val) end - dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer.x)...) + dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...) for (sym, val) in dependencies h = getsymbolhash(sym) - idx = ic.dependent_idx[h] - dep_exprs[idx] = wrap(fixpoint_sub(val, dependencies)) + i, j = ic.dependent_idx[h] + dep_exprs.x[i][j] = wrap(fixpoint_sub(val, dependencies)) end - p = reorder_parameters(ic, parameters(sys))[begin:(end - length(dep_buffer.x))] + p = reorder_parameters(ic, parameters(sys))[begin:(end - length(dep_buffer))] update_function_iip, update_function_oop = if isempty(dep_exprs.x) nothing, nothing else @@ -90,33 +90,39 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals end # everything is an ArrayPartition so it's easy to figure out how many # distinct vectors we have for each portion as `ArrayPartition.x` - if isempty(tunable_buffer.x) - tunable_buffer = Float64[] - end - if isempty(disc_buffer.x) - disc_buffer = Float64[] - end - if isempty(const_buffer.x) - const_buffer = Float64[] - end - if isempty(dep_buffer.x) - dep_buffer = Float64[] - end - if use_union - tunable_buffer = restrict_array_to_union(tunable_buffer) - disc_buffer = restrict_array_to_union(disc_buffer) - const_buffer = restrict_array_to_union(const_buffer) - dep_buffer = restrict_array_to_union(dep_buffer) - elseif tofloat - tunable_buffer = Float64.(tunable_buffer) - disc_buffer = Float64.(disc_buffer) - const_buffer = Float64.(const_buffer) - dep_buffer = Float64.(dep_buffer) - end + # if use_union + # tunable_buffer = restrict_array_to_union(ArrayPartition(tunable_buffer)) + # disc_buffer = restrict_array_to_union(ArrayPartition(disc_buffer)) + # const_buffer = restrict_array_to_union(ArrayPartition(const_buffer)) + # dep_buffer = restrict_array_to_union(ArrayPartition(dep_buffer)) + # elseif tofloat + # tunable_buffer = Float64.(tunable_buffer) + # disc_buffer = Float64.(disc_buffer) + # const_buffer = Float64.(const_buffer) + # dep_buffer = Float64.(dep_buffer) + # end return MTKParameters{typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer), - typeof(dep_buffer), typeof(update_function_iip), typeof(update_function_oop)}( - tunable_buffer, disc_buffer, const_buffer, dep_buffer, update_function_iip, - update_function_oop) + typeof(dep_buffer), typeof(nonnumeric_buffer), typeof(update_function_iip), + typeof(update_function_oop)}(tunable_buffer, disc_buffer, const_buffer, dep_buffer, + nonnumeric_buffer, update_function_iip, update_function_oop) +end + +function buffer_to_arraypartition(buf) + return ArrayPartition((eltype(v) isa AbstractArray ? buffer_to_arraypartition(v) : v for v in buf)...) +end + +function split_into_buffers(raw::AbstractArray, buf; recurse = true) + idx = 1 + function _helper(buf_v; recurse = true) + if eltype(buf_v) isa AbstractArray && recurse + return _helper.(buf_v; recurse = false) + else + res = raw[idx:idx+length(buf_v)-1] + idx += length(buf_v) + return res + end + end + return Tuple(_helper(buf_v; recurse) for buf_v in buf) end SciMLStructures.isscimlstructure(::MTKParameters) = true @@ -127,43 +133,48 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) (SciMLStructures.Discrete, :discrete) (SciMLStructures.Constants, :constant)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) - function repack(values) - p.$field .= values + function repack(_) # aliases, so we don't need to use the parameter if p.dependent_update_iip !== nothing - p.dependent_update_iip(p.dependent, p...) + p.dependent_update_iip(ArrayPartition(p.dependent), p...) end p end - return p.$field, repack, true + return buffer_to_arraypartition(p.$field), repack, true end @eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals) - @set! p.$field = newvals + @set! p.$field = split_into_buffers(newvals, p.$field) if p.dependent_update_oop !== nothing - @set! p.dependent = ArrayPartition(p.dependent_update_oop(p...)) + raw = p.dependent_update_oop(p...) + @set! p.dependent = split_into_buffers(raw, p.dependent; recurse = false) end p end @eval function SciMLStructures.replace!(::$Portion, p::MTKParameters, newvals) - p.$field .= newvals + src = split_into_buffers(newvals, p.$field) + dst = buffer_to_arraypartition(newvals) + dst .= src if p.dependent_update_iip !== nothing - p.dependent_update_iip(p.dependent, p...) + p.dependent_update_iip(ArrayPartition(p.dependent), p...) end nothing end end -function SymbolicIndexingInterface.parameter_values(p::MTKParameters, i::ParameterIndex) - @unpack portion, idx = i +function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex) + @unpack portion, idx = pind + i, j, k... = idx if portion isa SciMLStructures.Tunable - return p.tunable[idx] + return p.tunable[i][j][k...] elseif portion isa SciMLStructures.Discrete - return p.discrete[idx] + return p.discrete[i][j][k...] elseif portion isa SciMLStructures.Constants - return p.constant[idx] - elseif portion === nothing - return p.dependent[idx] + return p.constant[i][j][k...] + elseif portion === DEPENDENT_PORTION + return p.dependent[i][j][k...] + elseif portion === NONNUMERIC_PORTION + return isempty(k) ? p.nonnumeric[i][j] : p.nonnumeric[i][j][k...] else error("Unhandled portion $portion") end @@ -172,45 +183,88 @@ end function SymbolicIndexingInterface.set_parameter!( p::MTKParameters, val, idx::ParameterIndex) @unpack portion, idx = idx + i, j, k... = idx if portion isa SciMLStructures.Tunable - p.tunable[idx] = val + if isempty(k) + p.tunable[i][j] = val + else + p.tunable[i][j][k...] = val + end elseif portion isa SciMLStructures.Discrete - p.discrete[idx] = val + if isempty(k) + p.discrete[i][j] = val + else + p.discrete[i][j][k...] = val + end elseif portion isa SciMLStructures.Constants - p.constant[idx] = val - elseif portion === nothing - error("Cannot set value of parameter: ") + if isempty(k) + p.constant[i][j] = val + else + p.constant[i][j][k...] = val + end + elseif portion === DEPENDENT_PORTION + error("Cannot set value of dependent parameter") + elseif portion === NONNUMERIC_PORTION + if isempty(k) + p.nonnumeric[i][j] = val + else + p.nonnumeric[i][j][k...] = val + end else error("Unhandled portion $portion") end if p.dependent_update_iip !== nothing - p.dependent_update_iip(p.dependent, p...) + p.dependent_update_iip(ArrayPartition(p.dependent), p...) end end -function _set_parameter_unchecked!(p::MTKParameters, val, idx::ParameterIndex) +function _set_parameter_unchecked!(p::MTKParameters, val, idx::ParameterIndex; update_dependent = true) @unpack portion, idx = idx - update_dependent = true + i, j, k... = idx if portion isa SciMLStructures.Tunable - p.tunable[idx] = val + if isempty(k) + p.tunable[i][j] = val + else + p.tunable[i][j][k...] = val + end elseif portion isa SciMLStructures.Discrete - p.discrete[idx] = val + if isempty(k) + p.discrete[i][j] = val + else + p.discrete[i][j][k...] = val + end elseif portion isa SciMLStructures.Constants - p.constant[idx] = val - elseif portion === nothing - p.dependent[idx] = val + if isempty(k) + p.constant[i][j] = val + else + p.constant[i][j][k...] = val + end + elseif portion === DEPENDENT_PORTION + if isempty(k) + p.dependent[i][j] = val + else + p.dependent[i][j][k...] = val + end update_dependent = false + elseif portion === NONNUMERIC_PORTION + if isempty(k) + p.nonnumeric[i][j] = val + else + p.nonnumeric[i][j][k...] = val + end else error("Unhandled portion $portion") end update_dependent && p.dependent_update_iip !== nothing && - p.dependent_update_iip(p.dependent, p...) + p.dependent_update_iip(ArrayPartition(p.dependent), p...) end _subarrays(v::AbstractVector) = isempty(v) ? () : (v,) _subarrays(v::ArrayPartition) = v.x +_subarrays(v::Tuple) = v _num_subarrays(v::AbstractVector) = 1 _num_subarrays(v::ArrayPartition) = length(v.x) +_num_subarrays(v::Tuple) = length(v) # for compiling callbacks # getindex indexes the vectors, setindex! linearly indexes values # it's inconsistent, but we need it to be this way @@ -227,28 +281,42 @@ function Base.getindex(buf::MTKParameters, i) i <= _num_subarrays(buf.constant) && return _subarrays(buf.constant)[i] i -= _num_subarrays(buf.constant) end - isempty(buf.dependent) || return _subarrays(buf.dependent)[i] + if !isempty(buf.nonnumeric) + i <= _num_subarrays(buf.nonnumeric) && return _subarrays(buf.nonnumeric)[i] + i -= _num_subarrays(buf.nonnumeric) + end + if !isempty(buf.dependent) + i <= _num_subarrays(buf.dependent) && return _subarrays(buf.dependent)[i] + i -= _num_subarrays(buf.dependent) + end throw(BoundsError(buf, i)) end -function Base.setindex!(buf::MTKParameters, val, i) - if i <= length(buf.tunable) - buf.tunable[i] = val - elseif i <= length(buf.tunable) + length(buf.discrete) - buf.discrete[i - length(buf.tunable)] = val - else - buf.constant[i - length(buf.tunable) - length(buf.discrete)] = val +function Base.setindex!(p::MTKParameters, val, i) + function _helper(buf) + done = false + for v in buf + if i <= length(v) + v[i] = val + done = true + else + i -= length(v) + end + end + done end - if buf.dependent_update_iip !== nothing - buf.dependent_update_iip(buf.dependent, buf...) + _helper(p.tunable) || _helper(p.discrete) || _helper(p.constant) || _helper(p.nonnumeric) || throw(BoundsError(p, i)) + if p.dependent_update_iip !== nothing + p.dependent_update_iip(ArrayPartition(p.dependent), p...) end end function Base.iterate(buf::MTKParameters, state = 1) total_len = 0 - isempty(buf.tunable) || (total_len += _num_subarrays(buf.tunable)) - isempty(buf.discrete) || (total_len += _num_subarrays(buf.discrete)) - isempty(buf.constant) || (total_len += _num_subarrays(buf.constant)) - isempty(buf.dependent) || (total_len += _num_subarrays(buf.dependent)) + total_len += _num_subarrays(buf.tunable) + total_len += _num_subarrays(buf.discrete) + total_len += _num_subarrays(buf.constant) + total_len += _num_subarrays(buf.nonnumeric) + total_len += _num_subarrays(buf.dependent) if state <= total_len return (buf[state], state + 1) else @@ -258,15 +326,16 @@ end function Base.:(==)(a::MTKParameters, b::MTKParameters) return a.tunable == b.tunable && a.discrete == b.discrete && - a.constant == b.constant && a.dependent == b.dependent + a.constant == b.constant && a.dependent == b.dependent && + a.nonnumeric == b.nonnumeric end # to support linearize/linearization_function function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where {F, C} - T = eltype(p.tunable) + tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) + T = eltype(tunable) tag = ForwardDiff.Tag(pf, T) dualtype = ForwardDiff.Dual{typeof(tag), T, ForwardDiff.chunksize(chunk)} - tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) p_big = SciMLStructures.replace(SciMLStructures.Tunable(), p, dualtype.(tunable)) p_closure = let pf = pf, input_idxs = input_idxs, @@ -280,7 +349,7 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where # tunable[input_idxs] .= p_small_inner # p_big = repack(tunable) return if pf isa SciMLBase.ParamJacobianWrapper - buffer = similar(p_big.tunable, size(pf.u)) + buffer = Array{dualtype}(undef, size(pf.u)) pf(buffer, p_big) buffer else diff --git a/src/utils.jl b/src/utils.jl index d4e068388b..85d8308c2d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -232,7 +232,7 @@ end function collect_defaults!(defs, vars) for v in vars - (haskey(defs, v) || !hasdefault(v)) && continue + (haskey(defs, v) || !hasdefault(unwrap(v))) && continue defs[v] = getdefault(v) end return defs diff --git a/src/variables.jl b/src/variables.jl index 999c7a9836..ca6d1b6954 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -360,7 +360,7 @@ struct VariableDescription end Symbolics.option_to_metadata_type(::Val{:description}) = VariableDescription getdescription(x::Num) = getdescription(Symbolics.unwrap(x)) - +getdescription(x::Symbolics.Arr) = getdescription(Symbolics.unwrap(x)) """ getdescription(x) diff --git a/test/mass_matrix.jl b/test/mass_matrix.jl index 300bdbc9e9..b67f2da870 100644 --- a/test/mass_matrix.jl +++ b/test/mass_matrix.jl @@ -8,7 +8,7 @@ eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3], D(y[2]) ~ k[1] * y[1] - k[3] * y[2] * y[3] - k[2] * y[2]^2, 0 ~ y[1] + y[2] + y[3] - 1] -@named sys = ODESystem(eqs, t, y, k) +@named sys = ODESystem(eqs, t, y, [k]) sys = complete(sys) @test_throws ArgumentError ODESystem(eqs, y[1]) M = calculate_massmatrix(sys) @@ -16,9 +16,8 @@ M = calculate_massmatrix(sys) 0 1 0 0 0 0] -f = ODEFunction(sys) -prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), - MTKParameters(sys, (k[1] => 0.04, k[2] => 3e7, k[3] => 1e4))) +prob_mm = ODEProblem(sys, [1.0, 0.0, 0.0], (0.0, 1e5), + [k => [0.04, 3e7, 1e4]]) sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8) function rober(du, u, p, t) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 97a4fc7c78..ff101dcbad 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -33,13 +33,24 @@ t_end = 10.0 time = 0:dt:t_end x = @. time^2 + 1.0 -get_value(data, t, dt) = data[round(Int, t / dt + 1)] -@register_symbolic get_value(data::Vector, t, dt) +struct Interpolator + data::Vector{Float64} + dt::Float64 +end + +function (i::Interpolator)(t) + return i.data[round(Int, t / i.dt + 1)] +end +@register_symbolic (i::Interpolator)(t) -function Sampled(; name, dt = 0.0, n = length(data)) +get_value(interp::Interpolator, t) = interp(t) +@register_symbolic get_value(interp::Interpolator, t) +# get_value(data, t, dt) = data[round(Int, t / dt + 1)] +# @register_symbolic get_value(data::Vector, t, dt) + +function Sampled(; name, interp = Interpolator(Float64[], 0.0)) pars = @parameters begin - data[1:n] - dt = dt + interpolator::Interpolator = interp end vars = [] @@ -48,15 +59,15 @@ function Sampled(; name, dt = 0.0, n = length(data)) end eqs = [ - output.u ~ get_value(data, t, dt) + output.u ~ get_value(interpolator, t) ] - return ODESystem(eqs, t, vars, [data..., dt]; name, systems, - defaults = [output.u => data[1]]) + return ODESystem(eqs, t, vars, [interpolator]; name, systems, + defaults = [output.u => interp.data[1]]) end vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 -@named src = Sampled(; dt, n = length(x)) +@named src = Sampled(; interp = Interpolator(x, dt)) @named int = Integrator() eqs = [y ~ src.output.u @@ -67,25 +78,20 @@ eqs = [y ~ src.output.u @named sys = ODESystem(eqs, t, vars, []; systems = [int, src]) s = complete(sys) sys = structural_simplify(sys) -prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; tofloat = false) -@test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}} +prob = ODEProblem(sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)]; tofloat = false) sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success @test sol[y][end] == x[end] #TODO: remake becomes more complicated now, how to improve? defs = ModelingToolkit.defaults(sys) -defs[s.src.data] = 2x -p′ = ModelingToolkit.varmap_to_vars(defs, parameters(sys); tofloat = false) -p′, = ModelingToolkit.split_parameters_by_type(p′) #NOTE: we need to ensure this is called now before calling remake() +defs[s.src.interpolator] = Interpolator(2x, dt) +p′ = ModelingToolkit.MTKParameters(sys, defs) prob′ = remake(prob; p = p′) sol = solve(prob′, ImplicitEuler()); @test sol.retcode == ReturnCode.Success @test sol[y][end] == 2x[end] -prob′′ = remake(prob; p = [s.src.data => x]) -@test_broken prob′′.p isa Tuple - # ------------------------ Mixed Type Converted to float (default behavior) vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 @@ -95,7 +101,7 @@ eqs = [D(y) ~ dy * a ddy ~ sin(t) * c] @named model = ODESystem(eqs, t, vars, pars) -sys = structural_simplify(model) +sys = structural_simplify(model; split = false) tspan = (0.0, t_end) prob = ODEProblem(sys, [], tspan, []) From 36388d0204d0c3b6be72ac92d356158d3758fddd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 16 Feb 2024 00:23:27 +0530 Subject: [PATCH 05/29] test: fix test file paths --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 938db5550e..a12ba563c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,8 +77,8 @@ end if GROUP == "All" || GROUP == "Downstream" activate_downstream_env() - @safetestset "Linearization Tests" include("linearize.jl") - @safetestset "Inverse Models Test" include("inversemodel.jl") + @safetestset "Linearization Tests" include("downstream/linearize.jl") + @safetestset "Inverse Models Test" include("downstream/inversemodel.jl") end if GROUP == "All" || GROUP == "Extensions" From f9f7ae192c6068ee2d89c1ee341903e185651502 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 16 Feb 2024 12:09:01 +0530 Subject: [PATCH 06/29] refactor: support clock systems without index caches --- src/systems/clock_inference.jl | 157 ++++++++++++++++++++++++--------- test/clock.jl | 13 +++ 2 files changed, 127 insertions(+), 43 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 07b684de88..b315234878 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -146,12 +146,15 @@ function generate_discrete_affect( @static if VERSION < v"1.7" error("The `generate_discrete_affect` function requires at least Julia 1.7") end - has_index_cache(osys) && get_index_cache(osys) !== nothing || error("System must have index_cache for clock support") + use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing out = Sym{Any}(:out) appended_parameters = parameters(syss[continuous_id]) - param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p) - for p in appended_parameters) offset = length(appended_parameters) + param_to_idx = if use_index_cache + Dict{Any, ParameterIndex}(p => parameter_index(osys, p) for p in appended_parameters) + else + Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters)) + end affect_funs = [] init_funs = [] svs = [] @@ -170,18 +173,20 @@ function generate_discrete_affect( push!(fullvars, s) end needed_disc_to_cont_obs = [] - disc_to_cont_idxs = ParameterIndex[] + if use_index_cache + disc_to_cont_idxs = ParameterIndex[] + else + disc_to_cont_idxs = Int[] + end for v in inputs[continuous_id] vv = arguments(v)[1] if vv in fullvars push!(needed_disc_to_cont_obs, vv) - # @show param_to_idx[v] v - # @assert param_to_idx[v].portion isa SciMLStructures.Discrete # TOOD: remove push!(disc_to_cont_idxs, param_to_idx[v]) end end append!(appended_parameters, input, unknowns(sys)) - cont_to_disc_obs = build_explicit_observed_function(osys, + cont_to_disc_obs = build_explicit_observed_function(use_index_cache ? osys : syss[continuous_id], needed_cont_to_disc_obs, throw = false, expression = true, @@ -192,36 +197,62 @@ function generate_discrete_affect( expression = true, output_type = SVector, ps = reorder_parameters(osys, parameters(sys))) + ni = length(input) + ns = length(unknowns(sys)) disc = Func( [ out, DestructuredArgs(unknowns(osys)), - DestructuredArgs.(reorder_parameters(osys, parameters(osys)))..., - # DestructuredArgs(appended_parameters), + if use_index_cache + DestructuredArgs.(reorder_parameters(osys, parameters(osys))) + else + (DestructuredArgs(appended_parameters),) + end..., get_iv(sys) ], [], let_block) - cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] - disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] + if use_index_cache + cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] + disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] + else + cont_to_disc_idxs = (offset + 1):(offset += ni) + input_offset = offset + disc_range = (offset + 1):(offset += ns) + end save_vec = Expr(:ref, :Float64) - for unk in unknowns(sys) - idx = parameter_index(osys, unk) - push!(save_vec.args, :($(parameter_values)(p, $idx))) + if use_index_cache + for unk in unknowns(sys) + idx = parameter_index(osys, unk) + push!(save_vec.args, :($(parameter_values)(p, $idx))) + end + else + for i in 1:ns + push!(save_vec.args, :(p[$(input_offset + i)])) + end end empty_disc = isempty(disc_range) - disc_init = :(function (p, t) - d2c_obs = $disc_to_cont_obs - disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range) - result = d2c_obs(disc_state, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - # prevent multiple updates to dependents - _set_parameter_unchecked!(p, val, i; update_dependent = false) - end - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) # to force recalculation of dependents - end) + disc_init = if use_index_cache + :(function (p, t) + d2c_obs = $disc_to_cont_obs + disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range) + result = d2c_obs(disc_state, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + # prevent multiple updates to dependents + _set_parameter_unchecked!(p, val, i; update_dependent = false) + end + discretes, repack, _ = $(SciMLStructures.canonicalize)( + $(SciMLStructures.Discrete()), p) + repack(discretes) # to force recalculation of dependents + end) + else + :(function (p, t) + d2c_obs = $disc_to_cont_obs + d2c_view = view(p, $disc_to_cont_idxs) + disc_state = view(p, $disc_range) + copyto!(d2c_view, d2c_obs(disc_state, p, t)) + end) + end # @show disc_to_cont_idxs # @show cont_to_disc_idxs @@ -230,8 +261,18 @@ function generate_discrete_affect( @unpack u, p, t = integrator c2d_obs = $cont_to_disc_obs d2c_obs = $disc_to_cont_obs + $( + if use_index_cache + :(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]) + else + quote + c2d_view = view(p, $cont_to_disc_idxs) + d2c_view = view(p, $disc_to_cont_idxs) + disc_unknowns = view(p, $disc_range) + end + end + ) # TODO: find a way to do this without allocating - disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range] disc = $disc push!(saved_values.t, t) @@ -246,26 +287,56 @@ function generate_discrete_affect( # d2c comes last # @show t # @show "incoming", p - result = c2d_obs(integrator.u, p..., t) - for (val, i) in zip(result, $cont_to_disc_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end + $( + if use_index_cache + quote + result = c2d_obs(integrator.u, p..., t) + for (val, i) in zip(result, $cont_to_disc_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + end + else + :(copyto!(c2d_view, c2d_obs(integrator.u, p, t))) + end + ) # @show "after c2d", p - if !$empty_disc - disc(disc_unknowns, integrator.u, p..., t) - for (val, i) in zip(disc_unknowns, $disc_range) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + $( + if use_index_cache + quote + if !$empty_disc + disc(disc_unknowns, integrator.u, p..., t) + for (val, i) in zip(disc_unknowns, $disc_range) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + end + end + else + :($empty_disc || disc(disc_unknowns, disc_unknowns, p, t)) end - end + ) # @show "after state update", p - result = d2c_obs(disc_unknowns, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end + $( + if use_index_cache + quote + result = d2c_obs(disc_unknowns, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + end + else + :(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))) + end + ) # @show "after d2c", p - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) + $( + if use_index_cache + quote + discretes, repack, _ = $(SciMLStructures.canonicalize)( + $(SciMLStructures.Discrete()), p) + repack(discretes) + end + end + ) end) sv = SavedValues(Float64, Vector{Float64}) push!(affect_funs, affect!) diff --git a/test/clock.jl b/test/clock.jl index 2dc37a4505..26b416fe26 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -118,6 +118,12 @@ prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, Tf), [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) @test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) + +ss_nosplit = structural_simplify(sys; split = false) +prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0, y => 0.0], (0.0, Tf), + [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) +@test sort(prob_nosplit.p) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud +sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent) # For all inputs in parameters, just initialize them to 0.0, and then set them # in the callback. @@ -154,8 +160,11 @@ prob = ODEProblem(foo!, [0.0], (0.0, Tf), [1.0, 4.0, 2.0, 3.0], callback = cb) # ud initializes to kp * (r - yd) + z = 1 * (1 - 0) + 3 = 4 sol2 = solve(prob, Tsit5()) @test sol.u == sol2.u +@test sol_nosplit.u == sol2.u @test saved_values.t == sol.prob.kwargs[:disc_saved_values][1].t +@test saved_values.t == sol_nosplit.prob.kwargs[:disc_saved_values][1].t @test saved_values.saveval == sol.prob.kwargs[:disc_saved_values][1].saveval +@test saved_values.saveval == sol_nosplit.prob.kwargs[:disc_saved_values][1].saveval @info "Testing multi-rate hybrid system" dt = 0.1 @@ -280,10 +289,13 @@ ci, varmap = infer_clocks(cl) @test varmap[u] == Continuous() ss = structural_simplify(cl) +ss_nosplit = structural_simplify(cl; split = false) if VERSION >= v"1.7" prob = ODEProblem(ss, [x => 0.0], (0.0, 1.0), [kp => 1.0]) + prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0], (0.0, 1.0), [kp => 1.0]) sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) + sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent) function foo!(dx, x, p, t) kp, ud1, ud2 = p @@ -314,6 +326,7 @@ if VERSION >= v"1.7" sol2 = solve(prob, Tsit5()) @test sol.u≈sol2.u atol=1e-6 + @test sol_nosplit.u≈sol2.u atol=1e-6 end ## From 15913422bb548d09383c53dbc566ec5f0f1ab60e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 16 Feb 2024 12:17:58 +0530 Subject: [PATCH 07/29] test: test callbacks without parameter splitting --- test/symbolic_events.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index ac16db3e50..b15769e630 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -138,6 +138,7 @@ fsys = flatten(sys) @test isequal(ModelingToolkit.continuous_events(sys2)[2].eqs[], sys.x ~ 1) sys = complete(sys) +sys_nosplit = complete(sys; split = false) sys2 = complete(sys2) # Functions should be generated for root-finding equations prob = ODEProblem(sys, Pair[], (0.0, 2.0)) @@ -155,15 +156,22 @@ cond.rf_ip(out, [2], p0, t0) @test out[] ≈ 1 # signature is u,p,t prob = ODEProblem(sys, Pair[], (0.0, 2.0)) +prob_nosplit = ODEProblem(sys_nosplit, Pair[], (0.0, 2.0)) sol = solve(prob, Tsit5()) +sol_nosplit = solve(prob_nosplit, Tsit5()) @test minimum(t -> abs(t - 1), sol.t) < 1e-10 # test that the solver stepped at the root +@test minimum(t -> abs(t - 1), sol_nosplit.t) < 1e-10 # test that the solver stepped at the root # Test that a user provided callback is respected test_callback = DiscreteCallback(x -> x, x -> x) prob = ODEProblem(sys, Pair[], (0.0, 2.0), callback = test_callback) +prob_nosplit = ODEProblem(sys_nosplit, Pair[], (0.0, 2.0), callback = test_callback) cbs = get_callback(prob) +cbs_nosplit = get_callback(prob_nosplit) @test cbs isa CallbackSet @test cbs.discrete_callbacks[1] == test_callback +@test cbs_nosplit isa CallbackSet +@test cbs_nosplit.discrete_callbacks[1] == test_callback prob = ODEProblem(sys2, Pair[], (0.0, 3.0)) cb = get_callback(prob) @@ -234,9 +242,11 @@ continuous_events = [[x ~ 0] => [vx ~ -vx] D(vy) ~ -0.01vy], t; continuous_events) ball = structural_simplify(ball) +ball_nosplit = structural_simplify(ball; split = false) tspan = (0.0, 5.0) prob = ODEProblem(ball, Pair[], tspan) +prob_nosplit = ODEProblem(ball_nosplit, Pair[], tspan) cb = get_callback(prob) @test cb isa ModelingToolkit.DiffEqCallbacks.VectorContinuousCallback @@ -250,9 +260,13 @@ cond.rf_ip(out, [0, 0, 0, 0], p0, t0) @test out ≈ [0, 1.5, -1.5] sol = solve(prob, Tsit5()) +sol_nosplit = solve(prob_nosplit, Tsit5()) @test 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close @test minimum(sol[y]) ≈ -1.5 # check wall conditions @test maximum(sol[y]) ≈ 1.5 # check wall conditions +@test 0 <= minimum(sol_nosplit[x]) <= 1e-10 # the ball never went through the floor but got very close +@test minimum(sol_nosplit[y]) ≈ -1.5 # check wall conditions +@test maximum(sol_nosplit[y]) ≈ 1.5 # check wall conditions # tv = sort([LinRange(0, 5, 200); sol.t]) # plot(sol(tv)[y], sol(tv)[x], line_z=tv) @@ -270,13 +284,18 @@ continuous_events = [ D(vx) ~ -1 D(vy) ~ 0], t; continuous_events) +ball_nosplit = structural_simplify(ball) ball = structural_simplify(ball) tspan = (0.0, 5.0) prob = ODEProblem(ball, Pair[], tspan) +prob_nosplit = ODEProblem(ball_nosplit, Pair[], tspan) sol = solve(prob, Tsit5()) +sol_nosplit = solve(prob_nosplit, Tsit5()) @test 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close @test -minimum(sol[y]) ≈ maximum(sol[y]) ≈ sqrt(2) # the ball will never go further than √2 in either direction (gravity was changed to 1 to get this particular number) +@test 0 <= minimum(sol_nosplit[x]) <= 1e-10 # the ball never went through the floor but got very close +@test -minimum(sol_nosplit[y]) ≈ maximum(sol_nosplit[y]) ≈ sqrt(2) # the ball will never go further than √2 in either direction (gravity was changed to 1 to get this particular number) # tv = sort([LinRange(0, 5, 200); sol.t]) # plot(sol(tv)[y], sol(tv)[x], line_z=tv) From f48d30d00aa364cc76389bbf701a29d44615bd87 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 16 Feb 2024 12:20:23 +0530 Subject: [PATCH 08/29] refactor: format --- src/systems/clock_inference.jl | 74 +++++++++++++++++---------------- src/systems/index_cache.jl | 24 +++++++---- src/systems/parameter_buffer.jl | 23 ++++++---- test/split_parameters.jl | 3 +- 4 files changed, 71 insertions(+), 53 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index b315234878..76766ef07c 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -151,7 +151,8 @@ function generate_discrete_affect( appended_parameters = parameters(syss[continuous_id]) offset = length(appended_parameters) param_to_idx = if use_index_cache - Dict{Any, ParameterIndex}(p => parameter_index(osys, p) for p in appended_parameters) + Dict{Any, ParameterIndex}(p => parameter_index(osys, p) + for p in appended_parameters) else Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters)) end @@ -186,7 +187,8 @@ function generate_discrete_affect( end end append!(appended_parameters, input, unknowns(sys)) - cont_to_disc_obs = build_explicit_observed_function(use_index_cache ? osys : syss[continuous_id], + cont_to_disc_obs = build_explicit_observed_function( + use_index_cache ? osys : syss[continuous_id], needed_cont_to_disc_obs, throw = false, expression = true, @@ -263,14 +265,14 @@ function generate_discrete_affect( d2c_obs = $disc_to_cont_obs $( if use_index_cache - :(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]) - else - quote - c2d_view = view(p, $cont_to_disc_idxs) - d2c_view = view(p, $disc_to_cont_idxs) - disc_unknowns = view(p, $disc_range) - end + :(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]) + else + quote + c2d_view = view(p, $cont_to_disc_idxs) + d2c_view = view(p, $disc_to_cont_idxs) + disc_unknowns = view(p, $disc_range) end + end ) # TODO: find a way to do this without allocating disc = $disc @@ -289,53 +291,53 @@ function generate_discrete_affect( # @show "incoming", p $( if use_index_cache - quote - result = c2d_obs(integrator.u, p..., t) - for (val, i) in zip(result, $cont_to_disc_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end + quote + result = c2d_obs(integrator.u, p..., t) + for (val, i) in zip(result, $cont_to_disc_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end - else - :(copyto!(c2d_view, c2d_obs(integrator.u, p, t))) end + else + :(copyto!(c2d_view, c2d_obs(integrator.u, p, t))) + end ) # @show "after c2d", p $( if use_index_cache - quote - if !$empty_disc - disc(disc_unknowns, integrator.u, p..., t) - for (val, i) in zip(disc_unknowns, $disc_range) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end + quote + if !$empty_disc + disc(disc_unknowns, integrator.u, p..., t) + for (val, i) in zip(disc_unknowns, $disc_range) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end end - else - :($empty_disc || disc(disc_unknowns, disc_unknowns, p, t)) end + else + :($empty_disc || disc(disc_unknowns, disc_unknowns, p, t)) + end ) # @show "after state update", p $( if use_index_cache - quote - result = d2c_obs(disc_unknowns, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end + quote + result = d2c_obs(disc_unknowns, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end - else - :(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))) end + else + :(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))) + end ) # @show "after d2c", p $( if use_index_cache - quote - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) - end + quote + discretes, repack, _ = $(SciMLStructures.canonicalize)( + $(SciMLStructures.Discrete()), p) + repack(discretes) end + end ) end) sv = SavedValues(Float64, Vector{Float64}) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 11cd9fc2bf..ac0511de0a 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -107,7 +107,8 @@ function IndexCache(sys::AbstractSystem) ) end - function get_buffer_sizes_and_idxs(buffers::Dict{DataType, Set{BasicSymbolic}}, track_linear_index = true) + function get_buffer_sizes_and_idxs( + buffers::Dict{DataType, Set{BasicSymbolic}}, track_linear_index = true) idxs = IndexMap() buffer_sizes = BufferTemplate[] for (i, (T, buf)) in enumerate(buffers) @@ -139,7 +140,7 @@ function IndexCache(sys::AbstractSystem) param_buffer_sizes, const_buffer_sizes, dependent_buffer_sizes, - nonnumeric_buffer_sizes, + nonnumeric_buffer_sizes ) end @@ -170,7 +171,9 @@ end function discrete_linear_index(ic::IndexCache, idx::ParameterIndex) idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected") ind = sum(temp.length for temp in ic.param_buffer_sizes; init = 0) - ind += sum(temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1); init = 0) + ind += sum( + temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1); + init = 0) ind += idx.idx[2] return ind end @@ -186,11 +189,16 @@ function reorder_parameters(sys::AbstractSystem, ps; kwargs...) end function reorder_parameters(ic::IndexCache, ps; drop_missing = false) - param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.param_buffer_sizes) - disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.discrete_buffer_sizes) - const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.constant_buffer_sizes) - dep_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.dependent_buffer_sizes) - nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:temp.length] for temp in ic.nonnumeric_buffer_sizes) + param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] + for temp in ic.param_buffer_sizes) + disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] + for temp in ic.discrete_buffer_sizes) + const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] + for temp in ic.constant_buffer_sizes) + dep_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] + for temp in ic.dependent_buffer_sizes) + nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] + for temp in ic.nonnumeric_buffer_sizes) for p in ps h = getsymbolhash(p) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index f93900906e..e07874607a 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -36,11 +36,16 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals for (k, v) in p if !haskey(extra_params, unwrap(k))) end - tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.param_buffer_sizes) - disc_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.discrete_buffer_sizes) - const_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.constant_buffer_sizes) - dep_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.dependent_buffer_sizes) - nonnumeric_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.nonnumeric_buffer_sizes) + tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) + for temp in ic.param_buffer_sizes) + disc_buffer = Tuple(Vector{temp.type}(undef, temp.length) + for temp in ic.discrete_buffer_sizes) + const_buffer = Tuple(Vector{temp.type}(undef, temp.length) + for temp in ic.constant_buffer_sizes) + dep_buffer = Tuple(Vector{temp.type}(undef, temp.length) + for temp in ic.dependent_buffer_sizes) + nonnumeric_buffer = Tuple(Vector{temp.type}(undef, temp.length) + for temp in ic.nonnumeric_buffer_sizes) dependencies = Dict{Num, Num}() function set_value(sym, val) h = getsymbolhash(sym) @@ -117,7 +122,7 @@ function split_into_buffers(raw::AbstractArray, buf; recurse = true) if eltype(buf_v) isa AbstractArray && recurse return _helper.(buf_v; recurse = false) else - res = raw[idx:idx+length(buf_v)-1] + res = raw[idx:(idx + length(buf_v) - 1)] idx += length(buf_v) return res end @@ -218,7 +223,8 @@ function SymbolicIndexingInterface.set_parameter!( end end -function _set_parameter_unchecked!(p::MTKParameters, val, idx::ParameterIndex; update_dependent = true) +function _set_parameter_unchecked!( + p::MTKParameters, val, idx::ParameterIndex; update_dependent = true) @unpack portion, idx = idx i, j, k... = idx if portion isa SciMLStructures.Tunable @@ -304,7 +310,8 @@ function Base.setindex!(p::MTKParameters, val, i) end done end - _helper(p.tunable) || _helper(p.discrete) || _helper(p.constant) || _helper(p.nonnumeric) || throw(BoundsError(p, i)) + _helper(p.tunable) || _helper(p.discrete) || _helper(p.constant) || + _helper(p.nonnumeric) || throw(BoundsError(p, i)) if p.dependent_update_iip !== nothing p.dependent_update_iip(ArrayPartition(p.dependent), p...) end diff --git a/test/split_parameters.jl b/test/split_parameters.jl index ff101dcbad..2aaea23f98 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -78,7 +78,8 @@ eqs = [y ~ src.output.u @named sys = ODESystem(eqs, t, vars, []; systems = [int, src]) s = complete(sys) sys = structural_simplify(sys) -prob = ODEProblem(sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)]; tofloat = false) +prob = ODEProblem( + sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)]; tofloat = false) sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success @test sol[y][end] == x[end] From 6abcc4968f312394be3bae5dd2bfa0387ceb8304 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 18 Feb 2024 20:30:23 +0530 Subject: [PATCH 09/29] refactor: disable treating symbolic defaults as param dependencies --- src/systems/index_cache.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index ac0511de0a..d26e9f0ca4 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -77,13 +77,13 @@ function IndexCache(sys::AbstractSystem) end end - all_ps = Set(unwrap.(parameters(sys))) - for (sym, value) in defaults(sys) - sym = unwrap(sym) - if sym in all_ps && symbolic_type(unwrap(value)) !== NotSymbolic() - insert_by_type!(dependent_buffers, sym) - end - end + # all_ps = Set(unwrap.(parameters(sys))) + # for (sym, value) in defaults(sys) + # sym = unwrap(sym) + # if sym in all_ps && symbolic_type(unwrap(value)) !== NotSymbolic() + # insert_by_type!(dependent_buffers, sym) + # end + # end for p in parameters(sys) p = unwrap(p) From 9eb96a50567aae01cf332c0d5ba7fc1cc265b37f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 Feb 2024 12:51:31 +0530 Subject: [PATCH 10/29] feat: add support for parameter dependencies --- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 68 ++++-- src/systems/callbacks.jl | 6 +- src/systems/clock_inference.jl | 4 +- src/systems/diffeqs/abstractodesystem.jl | 19 +- src/systems/diffeqs/odesystem.jl | 18 +- src/systems/diffeqs/sdesystem.jl | 22 +- src/systems/index_cache.jl | 28 ++- src/systems/jumps/jumpsystem.jl | 22 +- src/systems/nonlinear/nonlinearsystem.jl | 17 +- .../optimization/optimizationsystem.jl | 6 +- src/systems/parameter_buffer.jl | 54 +++-- test/parameter_dependencies.jl | 200 ++++++++++++++++++ test/runtests.jl | 1 + 14 files changed, 353 insertions(+), 114 deletions(-) create mode 100644 test/parameter_dependencies.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 79cfa9e8d8..224ed22724 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -36,7 +36,7 @@ using PrecompileTools, Reexport using RecursiveArrayTools using SymbolicIndexingInterface - export independent_variables, unknowns, parameters + export independent_variables, unknowns, parameters, full_parameters import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, Symbolic, isadd, ismul, ispow, issym, FnType, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d177f78c33..22253e5791 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -286,27 +286,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) if has_index_cache(sys) && get_index_cache(sys) !== nothing ic = get_index_cache(sys) h = getsymbolhash(sym) - return if haskey(ic.param_idx, h) - ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h]) - elseif haskey(ic.discrete_idx, h) - ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h]) - elseif haskey(ic.constant_idx, h) - ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h]) - elseif haskey(ic.dependent_idx, h) - ParameterIndex(nothing, ic.dependent_idx[h]) + return if (idx = ParameterIndex(ic, sym)) !== nothing + idx + elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing + idx else - h = getsymbolhash(default_toterm(sym)) - if haskey(ic.param_idx, h) - ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h]) - elseif haskey(ic.discrete_idx, h) - ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h]) - elseif haskey(ic.constant_idx, h) - ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h]) - elseif haskey(ic.dependent_idx, h) - ParameterIndex(nothing, ic.dependent_idx[h]) - else - nothing - end + nothing end end @@ -329,7 +314,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym end function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem) - return parameters(sys) + return full_parameters(sys) end function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym) @@ -419,6 +404,7 @@ for prop in [:eqs :metadata :gui_metadata :discrete_subsystems + :parameter_dependencies :solved_unknowns :split_idxs :parent @@ -750,7 +736,29 @@ function parameters(sys::AbstractSystem) ps = first.(ps) end systems = get_systems(sys) - unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))]) + result = unique(isempty(systems) ? ps : + [ps; reduce(vcat, namespace_parameters.(systems))]) + if has_parameter_dependencies(sys) && + (pdeps = get_parameter_dependencies(sys)) !== nothing + filter(result) do sym + !haskey(pdeps, sym) + end + else + result + end +end + +function dependent_parameters(sys::AbstractSystem) + if has_parameter_dependencies(sys) && + (pdeps = get_parameter_dependencies(sys)) !== nothing + collect(keys(pdeps)) + else + [] + end +end + +function full_parameters(sys::AbstractSystem) + vcat(parameters(sys), dependent_parameters(sys)) end # required in `src/connectors.jl:437` @@ -1612,7 +1620,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs, kwargs...) sts = unknowns(sys) t = get_iv(sys) - ps = parameters(sys) + ps = full_parameters(sys) p = reorder_parameters(sys, ps) fun = generate_function(sys, sts, ps; expression = Val{false})[1] @@ -2123,3 +2131,17 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, error("substituting symbols is not supported for $(typeof(sys))") end end + +function process_parameter_dependencies(pdeps, ps) + pdeps === nothing && return pdeps, ps + if pdeps isa Vector && eltype(pdeps) <: Pair + pdeps = Dict(pdeps) + elseif !(pdeps isa Dict) + error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`") + end + + ps = filter(ps) do p + !haskey(pdeps, p) + end + return pdeps, ps +end diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index e0e0e7e7c8..15d09a9f3d 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -433,14 +433,14 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin end function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); kwargs...) + ps = full_parameters(sys); kwargs...) cbs = continuous_events(sys) isempty(cbs) && return nothing generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...) end function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); kwargs...) + ps = full_parameters(sys); kwargs...) eqs = map(cb -> cb.eqs, cbs) num_eqs = length.(eqs) (isempty(eqs) || sum(num_eqs) == 0) && return nothing @@ -556,7 +556,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = end function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys), - ps = parameters(sys); kwargs...) + ps = full_parameters(sys); kwargs...) has_discrete_events(sys) || return nothing symcbs = discrete_events(sys) isempty(symcbs) && return nothing diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 76766ef07c..dab56cf916 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -198,7 +198,7 @@ function generate_discrete_affect( throw = false, expression = true, output_type = SVector, - ps = reorder_parameters(osys, parameters(sys))) + ps = reorder_parameters(osys, full_parameters(sys))) ni = length(input) ns = length(unknowns(sys)) disc = Func( @@ -206,7 +206,7 @@ function generate_discrete_affect( out, DestructuredArgs(unknowns(osys)), if use_index_cache - DestructuredArgs.(reorder_parameters(osys, parameters(osys))) + DestructuredArgs.(reorder_parameters(osys, full_parameters(osys))) else (DestructuredArgs(appended_parameters),) end..., diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 0454d02952..a9e726bb62 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -80,7 +80,8 @@ function calculate_control_jacobian(sys::AbstractODESystem; return jac end -function generate_tgrad(sys::AbstractODESystem, dvs = unknowns(sys), ps = parameters(sys); +function generate_tgrad( + sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys); simplify = false, kwargs...) tgrad = calculate_tgrad(sys, simplify = simplify) pre = get_preprocess_constants(tgrad) @@ -100,7 +101,7 @@ function generate_tgrad(sys::AbstractODESystem, dvs = unknowns(sys), ps = parame end function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); simplify = false, sparse = false, kwargs...) jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse) pre = get_preprocess_constants(jac) @@ -118,7 +119,7 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), end function generate_control_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); simplify = false, sparse = false, kwargs...) jac = calculate_control_jacobian(sys; simplify = simplify, sparse = sparse) p = reorder_parameters(sys, ps) @@ -146,7 +147,7 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), end function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); implicit_dae = false, ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) : nothing, @@ -314,7 +315,7 @@ end function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys), u0 = nothing; + ps = full_parameters(sys), u0 = nothing; version = nothing, tgrad = false, jac = false, p = nothing, t = nothing, @@ -830,7 +831,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; kwargs...) eqs = equations(sys) dvs = unknowns(sys) - ps = parameters(sys) + ps = full_parameters(sys) iv = get_iv(sys) if has_index_cache(sys) && get_index_cache(sys) !== nothing @@ -845,7 +846,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; symbolic_u0) p, split_idxs = split_parameters_by_type(p) if p isa Tuple - ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + ps = Base.Fix1(getindex, full_parameters(sys)).(split_idxs) ps = (ps...,) #if p is Tuple, ps should be Tuple end end @@ -997,7 +998,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = cbs = CallbackSet(discrete_cbs...) end else - cbs = CallbackSet(cbs, discrete_cbs) + cbs = CallbackSet(cbs, discrete_cbs...) end else svs = nothing @@ -1060,7 +1061,7 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan end function generate_history(sys::AbstractODESystem, u0; kwargs...) - p = reorder_parameters(sys, parameters(sys)) + p = reorder_parameters(sys, full_parameters(sys)) build_function(u0, p..., get_iv(sys); expression = Val{false}, kwargs...) end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index b3254f8b47..653ccfadd0 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -111,6 +111,11 @@ struct ODESystem <: AbstractODESystem """ discrete_events::Vector{SymbolicDiscreteCallback} """ + A mapping from dependent parameters to expressions describing how they are calculated from + other parameters. + """ + parameter_dependencies::Union{Nothing, Dict} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -154,7 +159,7 @@ struct ODESystem <: AbstractODESystem function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching, connector_type, preface, cevents, - devents, metadata = nothing, gui_metadata = nothing, + devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, discrete_subsystems = nothing, solved_unknowns = nothing, @@ -171,8 +176,8 @@ struct ODESystem <: AbstractODESystem end new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching, - connector_type, preface, cevents, devents, metadata, gui_metadata, - tearing_state, substitutions, complete, index_cache, + connector_type, preface, cevents, devents, parameter_dependencies, metadata, + gui_metadata, tearing_state, substitutions, complete, index_cache, discrete_subsystems, solved_unknowns, split_idxs, parent) end end @@ -190,6 +195,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; preface = nothing, continuous_events = nothing, discrete_events = nothing, + parameter_dependencies = nothing, checks = true, metadata = nothing, gui_metadata = nothing) @@ -225,10 +231,12 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; end cont_callbacks = SymbolicContinuousCallbacks(continuous_events) disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) + parameter_dependencies, ps′ = process_parameter_dependencies( + parameter_dependencies, ps′) ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, - connector_type, preface, cont_callbacks, disc_callbacks, + connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata, checks = checks) end @@ -323,7 +331,7 @@ function build_explicit_observed_function(sys, ts; output_type = Array, checkbounds = true, drop_expr = drop_expr, - ps = parameters(sys), + ps = full_parameters(sys), throw = true) if (isscalar = !(ts isa AbstractVector)) ts = [ts] diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index e873e27ef0..c61c83ceda 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -104,6 +104,11 @@ struct SDESystem <: AbstractODESystem """ discrete_events::Vector{SymbolicDiscreteCallback} """ + A mapping from dependent parameters to expressions describing how they are calculated from + other parameters. + """ + parameter_dependencies::Union{Nothing, Dict} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -128,7 +133,7 @@ struct SDESystem <: AbstractODESystem tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, - cevents, devents, metadata = nothing, gui_metadata = nothing, + cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, complete = false, index_cache = nothing, parent = nothing; checks::Union{Bool, Int} = true) if checks == true || (checks & CheckComponents) > 0 @@ -144,7 +149,7 @@ struct SDESystem <: AbstractODESystem new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents, - metadata, gui_metadata, complete, index_cache, parent) + parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent) end end @@ -161,6 +166,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv checks = true, continuous_events = nothing, discrete_events = nothing, + parameter_dependencies = nothing, metadata = nothing, gui_metadata = nothing) name === nothing && @@ -195,11 +201,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv Wfact_t = RefValue(EMPTY_JAC) cont_callbacks = SymbolicContinuousCallbacks(continuous_events) disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) - + parameter_dependencies, ps′ = process_parameter_dependencies( + parameter_dependencies, ps′) SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, - cont_callbacks, disc_callbacks, metadata, gui_metadata; checks = checks) + cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks) end function SDESystem(sys::ODESystem, neqs; kwargs...) @@ -220,7 +227,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem) end function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys), - ps = parameters(sys); isdde = false, kwargs...) + ps = full_parameters(sys); isdde = false, kwargs...) eqs = get_noiseeqs(sys) if isdde eqs = delay_to_function(sys, eqs) @@ -285,7 +292,7 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor) end SDESystem(deqs, get_noiseeqs(sys), get_iv(sys), unknowns(sys), parameters(sys), - name = name, checks = false) + name = name, parameter_dependencies = get_parameter_dependencies(sys), checks = false) end """ @@ -393,7 +400,8 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0) # return modified SDE System SDESystem(deqs, noiseeqs, get_iv(sys), unknown_vars, parameters(sys); defaults = Dict(θ => θ0), observed = [weight ~ θ / θ0], - name = name, checks = false) + name = name, parameter_dependencies = get_parameter_dependencies(sys), + checks = false) end function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys), diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index d26e9f0ca4..ed2ae98414 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -77,13 +77,13 @@ function IndexCache(sys::AbstractSystem) end end - # all_ps = Set(unwrap.(parameters(sys))) - # for (sym, value) in defaults(sys) - # sym = unwrap(sym) - # if sym in all_ps && symbolic_type(unwrap(value)) !== NotSymbolic() - # insert_by_type!(dependent_buffers, sym) - # end - # end + if has_parameter_dependencies(sys) && + (pdeps = get_parameter_dependencies(sys)) !== nothing + for (sym, value) in pdeps + sym = unwrap(sym) + insert_by_type!(dependent_buffers, sym) + end + end for p in parameters(sys) p = unwrap(p) @@ -107,8 +107,7 @@ function IndexCache(sys::AbstractSystem) ) end - function get_buffer_sizes_and_idxs( - buffers::Dict{DataType, Set{BasicSymbolic}}, track_linear_index = true) + function get_buffer_sizes_and_idxs(buffers::Dict{DataType, Set{BasicSymbolic}}) idxs = IndexMap() buffer_sizes = BufferTemplate[] for (i, (T, buf)) in enumerate(buffers) @@ -144,14 +143,8 @@ function IndexCache(sys::AbstractSystem) ) end -function ParameterIndex(ic::IndexCache, p) +function ParameterIndex(ic::IndexCache, p, sub_idx = ()) p = unwrap(p) - if istree(p) && operation(p) === getindex - sub_idx = Base.tail(arguments(p)) - p = arguments(p)[begin] - else - sub_idx = () - end h = getsymbolhash(p) return if haskey(ic.param_idx, h) ParameterIndex(SciMLStructures.Tunable(), (ic.param_idx[h]..., sub_idx...)) @@ -163,6 +156,9 @@ function ParameterIndex(ic::IndexCache, p) ParameterIndex(DEPENDENT_PORTION, (ic.dependent_idx[h]..., sub_idx...)) elseif haskey(ic.nonnumeric_idx, h) ParameterIndex(NONNUMERIC_PORTION, (ic.nonnumeric_idx[h]..., sub_idx...)) + elseif istree(p) && operation(p) === getindex + _p, sub_idx... = arguments(p) + ParameterIndex(ic, _p, sub_idx) else nothing end diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 5f442473b8..d98078324a 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -90,6 +90,11 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem """ discrete_events::Vector{SymbolicDiscreteCallback} """ + A mapping from dependent parameters to expressions describing how they are calculated from + other parameters. + """ + parameter_dependencies::Union{Nothing, Dict} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -108,7 +113,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem function JumpSystem{U}(tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, systems, - defaults, connector_type, devents, + defaults, connector_type, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, complete = false, index_cache = nothing; checks::Union{Bool, Int} = true) where {U <: ArrayPartition} @@ -121,7 +126,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem check_units(u, ap, iv) end new{U}(tag, ap, iv, unknowns, ps, var_to_name, observed, name, systems, defaults, - connector_type, devents, metadata, gui_metadata, complete, index_cache) + connector_type, devents, parameter_dependencies, metadata, gui_metadata, + complete, index_cache) end end function JumpSystem(tag, ap, iv, states, ps, var_to_name, args...; kwargs...) @@ -139,6 +145,7 @@ function JumpSystem(eqs, iv, unknowns, ps; checks = true, continuous_events = nothing, discrete_events = nothing, + parameter_dependencies = nothing, metadata = nothing, gui_metadata = nothing, kwargs...) @@ -177,11 +184,11 @@ function JumpSystem(eqs, iv, unknowns, ps; (continuous_events === nothing) || error("JumpSystems currently only support discrete events.") disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) - + parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps) JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), ap, value(iv), unknowns, ps, var_to_name, observed, name, systems, - defaults, connector_type, disc_callbacks, metadata, gui_metadata, - checks = checks) + defaults, connector_type, disc_callbacks, parameter_dependencies, + metadata, gui_metadata, checks = checks) end function generate_rate_function(js::JumpSystem, rate) @@ -190,7 +197,7 @@ function generate_rate_function(js::JumpSystem, rate) csubs = Dict(c => getdefault(c) for c in consts) rate = substitute(rate, csubs) end - p = reorder_parameters(js, parameters(js)) + p = reorder_parameters(js, full_parameters(js)) rf = build_function(rate, unknowns(js), p..., get_iv(js), expression = Val{true}) @@ -202,8 +209,7 @@ function generate_affect_function(js::JumpSystem, affect, outputidxs) csubs = Dict(c => getdefault(c) for c in consts) affect = substitute(affect, csubs) end - p = reorder_parameters(js, parameters(js)) - compile_affect(affect, js, unknowns(js), p...; outputidxs = outputidxs, + compile_affect(affect, js, unknowns(js), full_parameters(js); outputidxs = outputidxs, expression = Val{true}, checkvars = false) end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index d36d0bf69c..ae14509ab6 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -171,7 +171,8 @@ function calculate_jacobian(sys::NonlinearSystem; sparse = false, simplify = fal return jac end -function generate_jacobian(sys::NonlinearSystem, vs = unknowns(sys), ps = parameters(sys); +function generate_jacobian( + sys::NonlinearSystem, vs = unknowns(sys), ps = full_parameters(sys); sparse = false, simplify = false, kwargs...) jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify) pre = get_preprocess_constants(jac) @@ -190,7 +191,8 @@ function calculate_hessian(sys::NonlinearSystem; sparse = false, simplify = fals return hess end -function generate_hessian(sys::NonlinearSystem, vs = unknowns(sys), ps = parameters(sys); +function generate_hessian( + sys::NonlinearSystem, vs = unknowns(sys), ps = full_parameters(sys); sparse = false, simplify = false, kwargs...) hess = calculate_hessian(sys, sparse = sparse, simplify = simplify) pre = get_preprocess_constants(hess) @@ -198,7 +200,8 @@ function generate_hessian(sys::NonlinearSystem, vs = unknowns(sys), ps = paramet return build_function(hess, vs, p...; postprocess_fbody = pre, kwargs...) end -function generate_function(sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys); +function generate_function( + sys::NonlinearSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...) rhss = [deq.rhs for deq in equations(sys)] pre, sol_states = get_substitutions_and_solved_unknowns(sys) @@ -221,7 +224,7 @@ end """ ```julia SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); version = nothing, jac = false, sparse = false, @@ -237,7 +240,7 @@ function SciMLBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...) end function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), - ps = parameters(sys), u0 = nothing; + ps = full_parameters(sys), u0 = nothing; version = nothing, jac = false, eval_expression = true, @@ -294,7 +297,7 @@ end """ ```julia SciMLBase.NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); version = nothing, jac = false, sparse = false, @@ -308,7 +311,7 @@ variable and parameter vectors, respectively. struct NonlinearFunctionExpr{iip} end function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys), - ps = parameters(sys), u0 = nothing; + ps = full_parameters(sys), u0 = nothing; version = nothing, tgrad = false, jac = false, linenumbers = false, diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index de16453776..12b44074e6 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -131,7 +131,7 @@ function calculate_gradient(sys::OptimizationSystem) end function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); kwargs...) grad = calculate_gradient(sys) pre = get_preprocess_constants(grad) @@ -145,7 +145,7 @@ function calculate_hessian(sys::OptimizationSystem) end function generate_hessian( - sys::OptimizationSystem, vs = unknowns(sys), ps = parameters(sys); + sys::OptimizationSystem, vs = unknowns(sys), ps = full_parameters(sys); sparse = false, kwargs...) if sparse hess = sparsehessian(objective(sys), unknowns(sys)) @@ -159,7 +159,7 @@ function generate_hessian( end function generate_function(sys::OptimizationSystem, vs = unknowns(sys), - ps = parameters(sys); + ps = full_parameters(sys); kwargs...) eqs = subs_constants(objective(sys)) p = if has_index_cache(sys) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index e07874607a..01ed40a707 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -46,7 +46,6 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals for temp in ic.dependent_buffer_sizes) nonnumeric_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.nonnumeric_buffer_sizes) - dependencies = Dict{Num, Num}() function set_value(sym, val) h = getsymbolhash(sym) if haskey(ic.param_idx, h) @@ -61,7 +60,6 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals elseif haskey(ic.dependent_idx, h) i, j = ic.dependent_idx[h] dep_buffer[i][j] = val - dependencies[wrap(sym)] = wrap(p[sym]) elseif haskey(ic.nonnumeric_idx, h) i, j = ic.nonnumeric_idx[h] nonnumeric_buffer[i][j] = val @@ -79,37 +77,32 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals set_value(sym, val) end - dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...) - for (sym, val) in dependencies - h = getsymbolhash(sym) - i, j = ic.dependent_idx[h] - dep_exprs.x[i][j] = wrap(fixpoint_sub(val, dependencies)) - end - p = reorder_parameters(ic, parameters(sys))[begin:(end - length(dep_buffer))] - update_function_iip, update_function_oop = if isempty(dep_exprs.x) - nothing, nothing - else + if has_parameter_dependencies(sys) && + (pdeps = get_parameter_dependencies(sys)) !== nothing + pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps) + dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...) + for (sym, val) in pdeps + h = getsymbolhash(sym) + i, j = ic.dependent_idx[h] + dep_exprs.x[i][j] = wrap(val) + end + p = reorder_parameters(ic, parameters(sys)) oop, iip = build_function(dep_exprs, p...) - RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip), + update_function_iip, update_function_oop = RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip), RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop) + else + update_function_iip = update_function_oop = nothing end - # everything is an ArrayPartition so it's easy to figure out how many - # distinct vectors we have for each portion as `ArrayPartition.x` - # if use_union - # tunable_buffer = restrict_array_to_union(ArrayPartition(tunable_buffer)) - # disc_buffer = restrict_array_to_union(ArrayPartition(disc_buffer)) - # const_buffer = restrict_array_to_union(ArrayPartition(const_buffer)) - # dep_buffer = restrict_array_to_union(ArrayPartition(dep_buffer)) - # elseif tofloat - # tunable_buffer = Float64.(tunable_buffer) - # disc_buffer = Float64.(disc_buffer) - # const_buffer = Float64.(const_buffer) - # dep_buffer = Float64.(dep_buffer) - # end - return MTKParameters{typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer), + + mtkps = MTKParameters{ + typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer), typeof(dep_buffer), typeof(nonnumeric_buffer), typeof(update_function_iip), typeof(update_function_oop)}(tunable_buffer, disc_buffer, const_buffer, dep_buffer, nonnumeric_buffer, update_function_iip, update_function_oop) + if mtkps.dependent_update_iip !== nothing + mtkps.dependent_update_iip(ArrayPartition(mtkps.dependent), mtkps...) + end + return mtkps end function buffer_to_arraypartition(buf) @@ -122,7 +115,7 @@ function split_into_buffers(raw::AbstractArray, buf; recurse = true) if eltype(buf_v) isa AbstractArray && recurse return _helper.(buf_v; recurse = false) else - res = raw[idx:(idx + length(buf_v) - 1)] + res = reshape(raw[idx:(idx + length(buf_v) - 1)], size(buf_v)) idx += length(buf_v) return res end @@ -158,8 +151,9 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) @eval function SciMLStructures.replace!(::$Portion, p::MTKParameters, newvals) src = split_into_buffers(newvals, p.$field) - dst = buffer_to_arraypartition(newvals) - dst .= src + for i in 1:length(p.$field) + (p.$field)[i] .= src[i] + end if p.dependent_update_iip !== nothing p.dependent_update_iip(ArrayPartition(p.dependent), p...) end diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl new file mode 100644 index 0000000000..78538d23c7 --- /dev/null +++ b/test/parameter_dependencies.jl @@ -0,0 +1,200 @@ +using ModelingToolkit +using Test +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq +using StochasticDiffEq +using JumpProcesses +using StableRNGs +using SciMLStructures: canonicalize, Tunable, replace, replace! +using SymbolicIndexingInterface + +@testset "ODESystem with callbacks" begin + @parameters p1=1.0 p2=1.0 + @variables x(t) + cb1 = [x ~ 2.0] => [p1 ~ 2.0] # triggers at t=-2+√6 + function affect1!(integ, u, p, ctx) + integ.ps[p[1]] = integ.ps[p[2]] + end + cb2 = [x ~ 4.0] => (affect1!, [], [p1, p2], [p1]) # triggers at t=-2+√7 + cb3 = [1.0] => [p1 ~ 5.0] + + @mtkbuild sys = ODESystem( + [D(x) ~ p1 * t + p2], + t; + parameter_dependencies = [p2 => 2p1], + continuous_events = [cb1, cb2], + discrete_events = [cb3] + ) + @test isequal(only(parameters(sys)), p1) + @test Set(full_parameters(sys)) == Set([p1, p2]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.5), jac = true) + @test prob.ps[p1] == 1.0 + @test prob.ps[p2] == 2.0 + @test_nowarn solve(prob, Tsit5()) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.5), [p1 => 1.0], jac = true) + @test prob.ps[p1] == 1.0 + @test prob.ps[p2] == 2.0 + integ = init(prob, Tsit5()) + @test integ.ps[p1] == 1.0 + @test integ.ps[p2] == 2.0 + step!(integ, 0.5, true) # after cb1, before cb2 + @test integ.ps[p1] == 2.0 + @test integ.ps[p2] == 4.0 + step!(integ, 0.4, true) # after cb2, before cb3 + @test integ.ps[p1] == 4.0 + @test integ.ps[p2] == 8.0 + step!(integ, 0.2, true) # after cb3 + @test integ.ps[p1] == 5.0 + @test integ.ps[p2] == 10.0 +end + +@testset "Clock system" begin + dt = 0.1 + @variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t) + @parameters kp kq + d = Clock(t, dt) + k = ShiftIndex(d) + + eqs = [yd ~ Sample(t, dt)(y) + ud ~ kp * (r - yd) + kq * z + r ~ 1.0 + u ~ Hold(ud) + D(x) ~ -x + u + y ~ x + z(k + 2) ~ z(k) + yd] + @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp]) + + Tf = 1.0 + prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), + [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) + @test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent) + + @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp], + discrete_events = [[0.5] => [kp ~ 2.0]]) + prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), + [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) + @test prob.ps[kp] == 1.0 + @test prob.ps[kq] == 2.0 + @test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) + prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), + [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) + integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent) + @test integ.ps[kp] == 1.0 + @test integ.ps[kq] == 2.0 + step!(integ, 0.6) + @test integ.ps[kp] == 2.0 + @test integ.ps[kq] == 4.0 +end + +@testset "SDESystem" begin + @parameters σ ρ β + @variables x(t) y(t) z(t) + + eqs = [D(x) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + noiseeqs = [0.1 * x, + 0.1 * y, + 0.1 * z] + + @named sys = ODESystem(eqs, t) + @named sdesys = SDESystem(sys, noiseeqs; parameter_dependencies = [ρ => 2σ]) + sdesys = complete(sdesys) + @test Set(parameters(sdesys)) == Set([σ, β]) + @test Set(full_parameters(sdesys)) == Set([σ, β, ρ]) + + prob = SDEProblem( + sdesys, [x => 1.0, y => 0.0, z => 0.0], (0.0, 100.0), [σ => 10.0, β => 2.33]) + @test prob.ps[ρ] == 2prob.ps[σ] + @test_nowarn solve(prob, SRIW1()) + + @named sys = ODESystem(eqs, t) + @named sdesys = SDESystem(sys, noiseeqs; parameter_dependencies = [ρ => 2σ], + discrete_events = [[10.0] => [σ ~ 15.0]]) + sdesys = complete(sdesys) + prob = SDEProblem( + sdesys, [x => 1.0, y => 0.0, z => 0.0], (0.0, 100.0), [σ => 10.0, β => 2.33]) + integ = init(prob, SRIW1()) + @test integ.ps[σ] == 10.0 + @test integ.ps[ρ] == 20.0 + step!(integ, 11.0) + @test integ.ps[σ] == 15.0 + @test integ.ps[ρ] == 30.0 +end + +@testset "JumpSystem" begin + rng = StableRNG(12345) + @parameters β γ + @constants h = 1 + @variables S(t) I(t) R(t) + rate₁ = β * S * I * h + affect₁ = [S ~ S - 1 * h, I ~ I + 1] + rate₃ = γ * I * h + affect₃ = [I ~ I * h - 1, R ~ R + 1] + j₁ = ConstantRateJump(rate₁, affect₁) + j₃ = ConstantRateJump(rate₃, affect₃) + @named js2 = JumpSystem( + [j₁, j₃], t, [S, I, R], [γ]; parameter_dependencies = [β => 0.01γ]) + @test isequal(only(parameters(js2)), γ) + @test Set(full_parameters(js2)) == Set([γ, β]) + js2 = complete(js2) + tspan = (0.0, 250.0) + u₀map = [S => 999, I => 1, R => 0] + parammap = [γ => 0.01] + dprob = DiscreteProblem(js2, u₀map, tspan, parammap) + jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng) + @test jprob.ps[γ] == 0.01 + @test jprob.ps[β] == 0.0001 + @test_nowarn solve(jprob, SSAStepper()) + + @named js2 = JumpSystem( + [j₁, j₃], t, [S, I, R], [γ]; parameter_dependencies = [β => 0.01γ], + discrete_events = [[10.0] => [γ ~ 0.02]]) + js2 = complete(js2) + dprob = DiscreteProblem(js2, u₀map, tspan, parammap) + jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng) + integ = init(jprob, SSAStepper()) + @test integ.ps[γ] == 0.01 + @test integ.ps[β] == 0.0001 + step!(integ, 11.0) + @test integ.ps[γ] == 0.02 + @test integ.ps[β] == 0.0002 +end + +@testset "SciMLStructures interface" begin + @parameters p1=1.0 p2=1.0 + @variables x(t) + cb1 = [x ~ 2.0] => [p1 ~ 2.0] # triggers at t=-2+√6 + function affect1!(integ, u, p, ctx) + integ.ps[p[1]] = integ.ps[p[2]] + end + cb2 = [x ~ 4.0] => (affect1!, [], [p1, p2], [p1]) # triggers at t=-2+√7 + cb3 = [1.0] => [p1 ~ 5.0] + + @mtkbuild sys = ODESystem( + [D(x) ~ p1 * t + p2], + t; + parameter_dependencies = [p2 => 2p1] + ) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.5), [p1 => 1.0], jac = true) + prob.ps[p1] = 3.0 + @test prob.ps[p1] == 3.0 + @test prob.ps[p2] == 6.0 + + ps = prob.p + buffer, repack, _ = canonicalize(Tunable(), ps) + @test only(buffer) == 3.0 + buffer[1] = 4.0 + ps = repack(buffer) + @test getp(sys, p1)(ps) == 4.0 + @test getp(sys, p2)(ps) == 8.0 + + replace!(Tunable(), ps, [1.0]) + @test getp(sys, p1)(ps) == 1.0 + @test getp(sys, p2)(ps) == 2.0 + + ps2 = replace(Tunable(), ps, [2.0]) + @test getp(sys, p1)(ps2) == 2.0 + @test getp(sys, p2)(ps2) == 4.0 +end diff --git a/test/runtests.jl b/test/runtests.jl index a12ba563c0..610ed55354 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ end @safetestset "OptimizationSystem Test" include("optimizationsystem.jl") @safetestset "FuncAffect Test" include("funcaffect.jl") @safetestset "Constants Test" include("constants.jl") + @safetestset "Parameter Dependency Test" include("parameter_dependencies.jl") end if GROUP == "All" || GROUP == "InterfaceII" From 55f273018a0b4230f02c7979c0abb2ead39baa4c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 Feb 2024 13:26:36 +0530 Subject: [PATCH 11/29] docs: update NEWS with parameter dependencies --- NEWS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NEWS.md b/NEWS.md index 8eeec8b6df..d253351b3d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -45,3 +45,7 @@ equations. For example, `[p[1] => 1.0, p[2] => 2.0]` is no longer allowed in default equations, use `[p => [1.0, 2.0]]` instead. Also, array equations like for `@variables u[1:2]` have `D(u) ~ A*u` as an array equation. If the scalarized version is desired, use `scalarize(u)`. + - Parameter dependencies are now supported. They can be specified using the syntax + `(single_parameter => expression_involving_other_parameters)` and a `Vector` of these can be passed to + the `parameter_dependencies` keyword argument of `ODESystem`, `SDESystem` and `JumpSystem`. The dependent + parameters are updated whenever other parameters are modified, e.g. in callbacks. From c26d4d9d8b5f2d74eb22b0fda42a07bb339d2c14 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 Feb 2024 16:09:54 +0530 Subject: [PATCH 12/29] feat: un-scalarize inferred parameters, improve parameter initialization --- src/systems/abstractsystem.jl | 8 ++------ src/systems/diffeqs/odesystem.jl | 15 ++++++++++++++- src/systems/parameter_buffer.jl | 20 +++++++++++++++++--- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 22253e5791..5d7010526c 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -192,8 +192,7 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym) ic = get_index_cache(sys) h = getsymbolhash(sym) return haskey(ic.unknown_idx, h) || - haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) || - hasname(sym) && is_variable(sys, getname(sym)) + haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) else return any(isequal(sym), variable_symbols(sys)) || hasname(sym) && is_variable(sys, getname(sym)) @@ -220,8 +219,6 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym) h = getsymbolhash(default_toterm(sym)) if haskey(ic.unknown_idx, h) ic.unknown_idx[h] - elseif hasname(sym) - variable_index(sys, getname(sym)) else nothing end @@ -264,8 +261,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) else h = getsymbolhash(default_toterm(sym)) haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) || - haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) || - hasname(sym) && is_parameter(sys, getname(sym)) + haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) end end return any(isequal(sym), parameter_symbols(sys)) || diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 653ccfadd0..55155151a2 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -280,10 +280,23 @@ function ODESystem(eqs, iv; kwargs...) isdelay(v, iv) || continue collect_vars!(allunknowns, ps, arguments(v)[1], iv) end + new_ps = OrderedSet() + for p in ps + if istree(p) && operation(p) === getindex + par = arguments(p)[begin] + if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && all(par[i] in ps for i in eachindex(par)) + push!(new_ps, par) + else + push!(new_ps, p) + end + else + push!(new_ps, p) + end + end algevars = setdiff(allunknowns, diffvars) # the orders here are very important! return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv, - collect(Iterators.flatten((diffvars, algevars))), collect(ps); kwargs...) + collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...) end # NOTE: equality does not check cached Jacobian diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 01ed40a707..9bcbc2a6e8 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -36,6 +36,12 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals for (k, v) in p if !haskey(extra_params, unwrap(k))) end + for (sym, _) in p + if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin]) + # error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`") + end + end + tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.param_buffer_sizes) disc_buffer = Tuple(Vector{temp.type}(undef, temp.length) @@ -48,6 +54,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals for temp in ic.nonnumeric_buffer_sizes) function set_value(sym, val) h = getsymbolhash(sym) + done = true if haskey(ic.param_idx, h) i, j = ic.param_idx[h] tunable_buffer[i][j] = val @@ -64,17 +71,24 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals i, j = ic.nonnumeric_idx[h] nonnumeric_buffer[i][j] = val elseif !isequal(default_toterm(sym), sym) - set_value(default_toterm(sym), val) + done = set_value(default_toterm(sym), val) else - error("Symbol $sym does not have an index") + done = false end + return done end for (sym, val) in p sym = unwrap(sym) ctype = concrete_symtype(sym) val = convert(ctype, fixpoint_sub(val, p)) - set_value(sym, val) + done = set_value(sym, val) + if !done && Symbolics.isarraysymbolic(sym) + done = all(set_value.(collect(sym), val)) + end + if !done + error("Symbol $sym does not have an index") + end end if has_parameter_dependencies(sys) && From 812e004be7d8345c973039c9d3a6594114c238fa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 Feb 2024 18:10:49 +0530 Subject: [PATCH 13/29] feat: flatten equations to avoid scalarizing array arguments --- src/systems/diffeqs/abstractodesystem.jl | 27 ++++++++++++++++++++++++ src/systems/diffeqs/odesystem.jl | 17 +++++++++++++-- src/systems/diffeqs/sdesystem.jl | 9 +++++++- src/systems/jumps/jumpsystem.jl | 2 +- src/systems/parameter_buffer.jl | 3 ++- test/odesystem.jl | 18 ++++++++++++++-- 6 files changed, 69 insertions(+), 7 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index a9e726bb62..edd0cc0e59 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -807,6 +807,17 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false) defs = mergedefaults(defs, parammap, ps) end defs = mergedefaults(defs, u0map, dvs) + for (k, v) in defs + if Symbolics.isarraysymbolic(k) + ks = scalarize(k) + length(ks) == length(v) || error("$k has default value $v with unmatched size") + for (kk, vv) in zip(ks, v) + if !haskey(defs, kk) + defs[kk] = vv + end + end + end + end if symbolic_u0 u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) @@ -1415,3 +1426,19 @@ function isisomorphic(sys1::AbstractODESystem, sys2::AbstractODESystem) end return false end + +function flatten_equations(eqs) + mapreduce(vcat, eqs; init = Equation[]) do eq + islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) + isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) + if islhsarr || isrhsarr + islhsarr && isrhsarr || + error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar") + size(eq.lhs) == size(eq.rhs) || + error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))") + return collect(eq.lhs) .~ collect(eq.rhs) + else + eq + end + end +end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 55155151a2..cfd707c92b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -202,7 +202,19 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." - deqs = reduce(vcat, scalarize(deqs); init = Equation[]) + deqs = mapreduce(vcat, deqs; init = Equation[]) do eq + islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) + isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) + if islhsarr || isrhsarr + islhsarr && isrhsarr || + error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar") + size(eq.lhs) == size(eq.rhs) || + error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))") + return collect(eq.lhs) .~ collect(eq.rhs) + else + eq + end + end iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) @@ -284,7 +296,8 @@ function ODESystem(eqs, iv; kwargs...) for p in ps if istree(p) && operation(p) === getindex par = arguments(p)[begin] - if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && all(par[i] in ps for i in eachindex(par)) + if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && + all(par[i] in ps for i in eachindex(par)) push!(new_ps, par) else push!(new_ps, p) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index c61c83ceda..d7a001f937 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -171,7 +171,14 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv gui_metadata = nothing) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - deqs = scalarize(deqs) + deqs = flatten_equations(deqs) + neqs = mapreduce(vcat, neqs) do expr + if expr isa AbstractArray || Symbolics.isarraysymbolic(expr) + collect(expr) + else + expr + end + end iv′ = value(iv) dvs′ = value.(dvs) ps′ = value.(ps) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index d98078324a..abe6648ea9 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -151,7 +151,7 @@ function JumpSystem(eqs, iv, unknowns, ps; kwargs...) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - eqs = scalarize(eqs) + eqs = flatten_equations(eqs) sysnames = nameof.(systems) if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 9bcbc2a6e8..d9a2bc797e 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -37,7 +37,8 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals end for (sym, _) in p - if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin]) + if istree(sym) && operation(sym) === getindex && + is_parameter(sys, arguments(sym)[begin]) # error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`") end end diff --git a/test/odesystem.jl b/test/odesystem.jl index dd89428f16..efea02b628 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -520,8 +520,7 @@ using SymbolicUtils.Code using Symbolics: unwrap, wrap, @register_symbolic foo(a, ms::AbstractVector) = a + sum(ms) @register_symbolic foo(a, ms::AbstractVector) -@variables t x(t) ms(t)[1:3] -D = Differential(t) +@variables x(t) ms(t)[1:3] eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)] @named sys = ODESystem(eqs, t, [x; ms], []) @named emptysys = ODESystem(Equation[], t) @@ -529,6 +528,21 @@ eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)] prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0)) @test_nowarn solve(prob, Tsit5()) +# array equations +bar(x, p) = p * x +@register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin + size = size(x) + eltype = promote_type(eltype(x), eltype(p)) +end +@parameters p[1:3, 1:3] +eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)] +@named sys = ODESystem(eqs, t) +@named emptysys = ODESystem(Equation[], t) +@mtkbuild outersys = compose(emptysys, sys) +prob = ODEProblem( + outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)]) +@test_nowarn solve(prob, Tsit5()) + # x/x @variables x(t) @named sys = ODESystem([D(x) ~ x / x], t) From 059f6e8665499943d595954c8c68b0691a84f323 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 19 Feb 2024 07:59:33 -0500 Subject: [PATCH 14/29] fix formatting --- src/bipartite_graph.jl | 3 ++- src/structural_transformation/partial_state_selection.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index 2b9b0b3d75..2c12b52c11 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -88,7 +88,8 @@ function Base.push!(m::Matching, v) end end -function complete(m::Matching{U}, N = maximum((x for x in m.match if isa(x, Int)); init=0)) where {U} +function complete(m::Matching{U}, + N = maximum((x for x in m.match if isa(x, Int)); init = 0)) where {U} m.inv_match !== nothing && return m inv_match = Union{U, Int}[unassigned for _ in 1:N] for (i, eq) in enumerate(m.match) diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index f61677c2cd..36ea47fc52 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -51,7 +51,7 @@ function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varl old_level_vars = () ict = IncrementalCycleTracker( DiCMOBiGraph{true}(graph, - complete(Matching(ndsts(graph)), nsrcs(graph))), + complete(Matching(ndsts(graph)), nsrcs(graph))), dir = :in) while level >= 0 From 55cd1d878417773b0aef4ad8c95677dbd7987717 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 19 Feb 2024 08:03:29 -0500 Subject: [PATCH 15/29] fix typo --- src/systems/index_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index ed2ae98414..8d1c726365 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -72,7 +72,7 @@ function IndexCache(sys::AbstractSystem) if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing _, inputs, continuous_id, _ = get_discrete_subsystems(sys) for par in inputs[continuous_id] - is_parameter(sys, par) || error("Discrete subsytem input is not a parameter") + is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") insert_by_type!(disc_buffers, par) end end From 0c26977c7206732f53e7e43b74a88ec5e6d83be0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 17:12:12 +0530 Subject: [PATCH 16/29] fix: fix `vars!` --- src/utils.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 85d8308c2d..049b4d5e6b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -350,20 +350,20 @@ function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end function vars!(vars, O; op = Differential) - if isvariable(O) && !(istree(O) && operation(O) === getindex) + if isvariable(O) return push!(vars, O) end - !istree(O) && return vars + + operation(O) isa op && return push!(vars, O) + if operation(O) === (getindex) arr = first(arguments(O)) - return vars!(vars, arr) + istree(arr) && operation(arr) isa op && return push!(vars, O) + isvariable(arr) && return push!(vars, O) end - operation(O) isa op && return push!(vars, O) - isvariable(operation(O)) && push!(vars, O) - for arg in arguments(O) vars!(vars, arg; op = op) end From 8d7c677cec44944aee4021f0ee6807eee2347157 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 17:12:27 +0530 Subject: [PATCH 17/29] fix: refactor IndexCache for non-scalarized unknowns --- src/systems/abstractsystem.jl | 23 ++++++++++++----------- src/systems/index_cache.jl | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 5d7010526c..7bb3c84818 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -192,7 +192,9 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym) ic = get_index_cache(sys) h = getsymbolhash(sym) return haskey(ic.unknown_idx, h) || - haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) + haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) || + (istree(sym) && operation(sym) === getindex && + is_variable(sys, first(arguments(sym)))) else return any(isequal(sym), variable_symbols(sys)) || hasname(sym) && is_variable(sys, getname(sym)) @@ -213,16 +215,15 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym) if has_index_cache(sys) && get_index_cache(sys) !== nothing ic = get_index_cache(sys) h = getsymbolhash(sym) - return if haskey(ic.unknown_idx, h) - ic.unknown_idx[h] - else - h = getsymbolhash(default_toterm(sym)) - if haskey(ic.unknown_idx, h) - ic.unknown_idx[h] - else - nothing - end - end + haskey(ic.unknown_idx, h) && return ic.unknown_idx[h] + + h = getsymbolhash(default_toterm(sym)) + haskey(ic.unknown_idx, h) && return ic.unknown_idx[h] + sym = unwrap(sym) + istree(sym) && operation(sym) === getindex || return nothing + idx = variable_index(sys, first(arguments(sym))) + idx === nothing && return nothing + return idx[arguments(sym)[(begin + 1):end]...] end idx = findfirst(isequal(sym), variable_symbols(sys)) if idx === nothing && hasname(sym) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 8d1c726365..fa926676cf 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -21,7 +21,7 @@ end const IndexMap = Dict{UInt, Tuple{Int, Int}} struct IndexCache - unknown_idx::Dict{UInt, Int} + unknown_idx::Dict{UInt, Union{Int, UnitRange{Int}}} discrete_idx::IndexMap param_idx::IndexMap constant_idx::IndexMap @@ -36,10 +36,17 @@ end function IndexCache(sys::AbstractSystem) unks = solved_unknowns(sys) - unk_idxs = Dict{UInt, Int}() - for (i, sym) in enumerate(unks) - h = getsymbolhash(sym) - unk_idxs[h] = i + unk_idxs = Dict{UInt, Union{Int, UnitRange{Int}}}() + let idx = 1 + for sym in unks + h = getsymbolhash(sym) + if Symbolics.isarraysymbolic(sym) + unk_idxs[h] = idx:(idx + length(sym) - 1) + else + unk_idxs[h] = idx + end + idx += length(sym) + end end disc_buffers = Dict{DataType, Set{BasicSymbolic}}() From 9e2c9bc7e4c6a7a3c9b852e67b52cd4204b720be Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 17:12:46 +0530 Subject: [PATCH 18/29] fix: do not call flatten_equations in JumpSystem --- src/systems/jumps/jumpsystem.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index abe6648ea9..9361b8f71c 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -151,7 +151,6 @@ function JumpSystem(eqs, iv, unknowns, ps; kwargs...) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - eqs = flatten_equations(eqs) sysnames = nameof.(systems) if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) From a6add74f47876dce2211339b7afb635fdf552087 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 17:13:08 +0530 Subject: [PATCH 19/29] fix: handle broadcasted equations and array variables in ODESystem constructor --- src/systems/diffeqs/odesystem.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index cfd707c92b..84d9de0022 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -218,7 +218,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) - dvs′ = value.(dvs) + dvs′ = value.(symbolic_type(dvs) === NotSymbolic() ? dvs : [dvs]) dvs′ = filter(x -> !isdelay(x, iv), dvs′) if !(isempty(default_u0) && isempty(default_p)) Base.depwarn( @@ -253,6 +253,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; end function ODESystem(eqs, iv; kwargs...) + eqs = collect(eqs) # NOTE: this assumes that the order of algebraic equations doesn't matter diffvars = OrderedSet() allunknowns = OrderedSet() From a41a64fcbd240e69f2f0dc8f3c713598463de2a3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 17:13:22 +0530 Subject: [PATCH 20/29] fix: use variable_index in calculate_massmatrix --- src/systems/diffeqs/abstractodesystem.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index edd0cc0e59..daff7d5a86 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -247,11 +247,10 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify = false) eqs = [eq for eq in equations(sys)] dvs = unknowns(sys) M = zeros(length(eqs), length(eqs)) - unknown2idx = Dict(s => i for (i, s) in enumerate(dvs)) for (i, eq) in enumerate(eqs) if istree(eq.lhs) && operation(eq.lhs) isa Differential st = var_from_nested_derivative(eq.lhs)[1] - j = unknown2idx[st] + j = variable_index(sys, st) M[i, j] = 1 else _iszero(eq.lhs) || From c6c96ddee6b23bc208ac5bdce9444165e620fd03 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 18:57:23 +0530 Subject: [PATCH 21/29] fix: do not scalarize in system constructors --- src/systems/diffeqs/odesystem.jl | 15 +-------------- src/systems/diffeqs/sdesystem.jl | 8 -------- src/systems/jumps/jumpsystem.jl | 1 + src/systems/systemstructure.jl | 3 ++- 4 files changed, 4 insertions(+), 23 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 84d9de0022..9b19e1c26c 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -202,23 +202,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." - deqs = mapreduce(vcat, deqs; init = Equation[]) do eq - islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) - isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) - if islhsarr || isrhsarr - islhsarr && isrhsarr || - error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar") - size(eq.lhs) == size(eq.rhs) || - error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))") - return collect(eq.lhs) .~ collect(eq.rhs) - else - eq - end - end iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) - dvs′ = value.(symbolic_type(dvs) === NotSymbolic() ? dvs : [dvs]) + dvs′ = value.(dvs) dvs′ = filter(x -> !isdelay(x, iv), dvs′) if !(isempty(default_u0) && isempty(default_p)) Base.depwarn( diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index d7a001f937..b021a201fe 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -171,14 +171,6 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv gui_metadata = nothing) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - deqs = flatten_equations(deqs) - neqs = mapreduce(vcat, neqs) do expr - if expr isa AbstractArray || Symbolics.isarraysymbolic(expr) - collect(expr) - else - expr - end - end iv′ = value(iv) dvs′ = value.(dvs) ps′ = value.(ps) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 9361b8f71c..0ce14211dc 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -151,6 +151,7 @@ function JumpSystem(eqs, iv, unknowns, ps; kwargs...) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) + eqs = scalarize.(eqs) sysnames = nameof.(systems) if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 3f00410b10..ddc54d4998 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -251,7 +251,8 @@ function TearingState(sys; quick_cancel = false, check = true) sys = flatten(sys) ivs = independent_variables(sys) iv = length(ivs) == 1 ? ivs[1] : nothing - eqs = copy(equations(sys)) + # scalarize array equations, without scalarizing arguments to registered functions + eqs = flatten_equations(copy(equations(sys))) neqs = length(eqs) dervaridxs = OrderedSet{Int}() var2idx = Dict{Any, Int}() From d7265c1fd1864f327cda23f42b4a1ad13a493d73 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 18:57:35 +0530 Subject: [PATCH 22/29] test: fix mass matrix tests --- test/mass_matrix.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/mass_matrix.jl b/test/mass_matrix.jl index b67f2da870..5183b4ab3f 100644 --- a/test/mass_matrix.jl +++ b/test/mass_matrix.jl @@ -8,7 +8,7 @@ eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3], D(y[2]) ~ k[1] * y[1] - k[3] * y[2] * y[3] - k[2] * y[2]^2, 0 ~ y[1] + y[2] + y[3] - 1] -@named sys = ODESystem(eqs, t, y, [k]) +@named sys = ODESystem(eqs, t, collect(y), [k]) sys = complete(sys) @test_throws ArgumentError ODESystem(eqs, y[1]) M = calculate_massmatrix(sys) @@ -16,7 +16,7 @@ M = calculate_massmatrix(sys) 0 1 0 0 0 0] -prob_mm = ODEProblem(sys, [1.0, 0.0, 0.0], (0.0, 1e5), +prob_mm = ODEProblem(sys, [y => [1.0, 0.0, 0.0]], (0.0, 1e5), [k => [0.04, 3e7, 1e4]]) sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8) @@ -40,6 +40,6 @@ sol2 = solve(prob_mm2, Rodas5(), reltol = 1e-8, abstol = 1e-8, tstops = sol.t, # Test mass matrix in the identity case eqs = [D(y[1]) ~ y[1], D(y[2]) ~ y[2], D(y[3]) ~ y[3]] -@named sys = ODESystem(eqs, t, y, k) +@named sys = ODESystem(eqs, t, collect(y), [k]) @test calculate_massmatrix(sys) === I From 1275d6ea44026628248764f70db299bcb65d2915 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 Feb 2024 18:58:41 +0530 Subject: [PATCH 23/29] fixup! fix: fix `vars!` --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 049b4d5e6b..5fa79530aa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -360,7 +360,7 @@ function vars!(vars, O; op = Differential) if operation(O) === (getindex) arr = first(arguments(O)) istree(arr) && operation(arr) isa op && return push!(vars, O) - isvariable(arr) && return push!(vars, O) + isvariable(arr) && return push!(vars, O) end isvariable(operation(O)) && push!(vars, O) From 3e0aea0d917d61a398f27bdf15a9215bd2ac742c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 Feb 2024 16:10:22 +0530 Subject: [PATCH 24/29] fix: fix IndexCache to not put matrices as nonnumeric parameters --- src/systems/index_cache.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index fa926676cf..94b610779f 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -97,9 +97,8 @@ function IndexCache(sys::AbstractSystem) ctype = concrete_symtype(p) haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue - insert_by_type!( - if ctype <: Real || ctype <: Vector{<:Real} + if ctype <: Real || ctype <: AbstractArray{<:Real} if is_discrete_domain(p) disc_buffers elseif istunable(p, true) && size(p) !== Symbolics.Unknown() @@ -240,5 +239,5 @@ end concrete_symtype(x::BasicSymbolic) = concrete_symtype(symtype(x)) concrete_symtype(::Type{Real}) = Float64 concrete_symtype(::Type{Integer}) = Int -concrete_symtype(::Type{Vector{T}}) where {T} = Vector{concrete_symtype(T)} +concrete_symtype(::Type{A}) where {T, N, A<:Array{T, N}} = Array{concrete_symtype(T), N} concrete_symtype(::Type{T}) where {T} = T From 1218152ae5d1f78f5bfdfd8a8a1251a5cf0e8651 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 Feb 2024 16:10:42 +0530 Subject: [PATCH 25/29] feat: add copy method for MTKParameters --- src/systems/parameter_buffer.jl | 37 ++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index d9a2bc797e..69b34450e0 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -39,7 +39,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals for (sym, _) in p if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin]) - # error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`") + error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`") end end @@ -121,7 +121,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals end function buffer_to_arraypartition(buf) - return ArrayPartition((eltype(v) isa AbstractArray ? buffer_to_arraypartition(v) : v for v in buf)...) + return ArrayPartition(Tuple(eltype(v) <: AbstractArray ? buffer_to_arraypartition(v) : v for v in buf)) end function split_into_buffers(raw::AbstractArray, buf; recurse = true) @@ -146,13 +146,19 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) (SciMLStructures.Discrete, :discrete) (SciMLStructures.Constants, :constant)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) - function repack(_) # aliases, so we don't need to use the parameter - if p.dependent_update_iip !== nothing - p.dependent_update_iip(ArrayPartition(p.dependent), p...) + as_vector = buffer_to_arraypartition(p.$field) + repack = let as_vector = as_vector, p = p + function (new_val) + if new_val !== as_vector + p.$field = split_into_buffers(new_val, p.$field) + end + if p.dependent_update_iip !== nothing + p.dependent_update_iip(ArrayPartition(p.dependent), p...) + end + p end - p end - return buffer_to_arraypartition(p.$field), repack, true + return as_vector, repack, true end @eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals) @@ -176,6 +182,23 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) end end +function Base.copy(p::MTKParameters) + tunable = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.tunable) + discrete = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.discrete) + constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant) + dependent = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.dependent) + nonnumeric = copy.(p.nonnumeric) + return MTKParameters( + tunable, + discrete, + constant, + dependent, + nonnumeric, + p.dependent_update_iip, + p.dependent_update_oop, + ) +end + function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex) @unpack portion, idx = pind i, j, k... = idx From 9d5c211890ed20d75ccc4674174b7eacf24cca3e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 21 Feb 2024 12:11:23 -0500 Subject: [PATCH 26/29] Skip partial_state_selection test --- test/structural_transformation/index_reduction.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/structural_transformation/index_reduction.jl b/test/structural_transformation/index_reduction.jl index 053371d835..8dc06d680c 100644 --- a/test/structural_transformation/index_reduction.jl +++ b/test/structural_transformation/index_reduction.jl @@ -115,8 +115,10 @@ prob_auto = ODEProblem(new_sys, u0, (0.0, 10.0), p) sol = solve(prob_auto, Rodas5()); #plot(sol, idxs=(D(x), y)) -let pss_pendulum2 = partial_state_selection(pendulum2) - @test length(equations(pss_pendulum2)) <= 6 +@test_skip begin + let pss_pendulum2 = partial_state_selection(pendulum2) + length(equations(pss_pendulum2)) <= 6 + end end eqs = [D(x) ~ w, From 7ce0f57d3c7d4397264a1f4bd7b236b5c76e5b48 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 21 Feb 2024 22:03:32 -0500 Subject: [PATCH 27/29] format --- src/systems/index_cache.jl | 2 +- src/systems/parameter_buffer.jl | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 94b610779f..80d4d2b533 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -239,5 +239,5 @@ end concrete_symtype(x::BasicSymbolic) = concrete_symtype(symtype(x)) concrete_symtype(::Type{Real}) = Float64 concrete_symtype(::Type{Integer}) = Int -concrete_symtype(::Type{A}) where {T, N, A<:Array{T, N}} = Array{concrete_symtype(T), N} +concrete_symtype(::Type{A}) where {T, N, A <: Array{T, N}} = Array{concrete_symtype(T), N} concrete_symtype(::Type{T}) where {T} = T diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 69b34450e0..897225c47f 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -121,7 +121,8 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals end function buffer_to_arraypartition(buf) - return ArrayPartition(Tuple(eltype(v) <: AbstractArray ? buffer_to_arraypartition(v) : v for v in buf)) + return ArrayPartition(Tuple(eltype(v) <: AbstractArray ? buffer_to_arraypartition(v) : + v for v in buf)) end function split_into_buffers(raw::AbstractArray, buf; recurse = true) @@ -148,7 +149,7 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) as_vector = buffer_to_arraypartition(p.$field) repack = let as_vector = as_vector, p = p - function (new_val) + function (new_val) if new_val !== as_vector p.$field = split_into_buffers(new_val, p.$field) end @@ -195,7 +196,7 @@ function Base.copy(p::MTKParameters) dependent, nonnumeric, p.dependent_update_iip, - p.dependent_update_oop, + p.dependent_update_oop ) end From 703da35147b498d4e27eb05bb3a2fa72b1e46dd6 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 21 Feb 2024 22:18:01 -0500 Subject: [PATCH 28/29] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5a088c2c9e..1812bfbc3b 100644 --- a/Project.toml +++ b/Project.toml @@ -106,7 +106,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" SymbolicIndexingInterface = "0.3.1" SymbolicUtils = "1.0" -Symbolics = "5.7" +Symbolics = "5.21" URIs = "1" UnPack = "0.1, 1.0" Unitful = "1.1" From 3ede8ffeab1d1137c483dd2805d6cddad40f5514 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 21 Feb 2024 22:54:28 -0500 Subject: [PATCH 29/29] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1812bfbc3b..ae57ed3c52 100644 --- a/Project.toml +++ b/Project.toml @@ -106,7 +106,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" SymbolicIndexingInterface = "0.3.1" SymbolicUtils = "1.0" -Symbolics = "5.21" +Symbolics = "5.20.1" URIs = "1" UnPack = "0.1, 1.0" Unitful = "1.1"