From decfb515d837791f128e3a0ef03710410bb3d235 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 Feb 2024 18:10:49 +0530 Subject: [PATCH] feat: flatten equations to avoid scalarizing array arguments --- src/systems/diffeqs/abstractodesystem.jl | 27 ++++++++++++++++++++++++ src/systems/diffeqs/odesystem.jl | 17 +++++++++++++-- src/systems/diffeqs/sdesystem.jl | 9 +++++++- src/systems/jumps/jumpsystem.jl | 2 +- src/systems/parameter_buffer.jl | 3 ++- test/odesystem.jl | 18 ++++++++++++++-- 6 files changed, 69 insertions(+), 7 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index a9e726bb62..edd0cc0e59 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -807,6 +807,17 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false) defs = mergedefaults(defs, parammap, ps) end defs = mergedefaults(defs, u0map, dvs) + for (k, v) in defs + if Symbolics.isarraysymbolic(k) + ks = scalarize(k) + length(ks) == length(v) || error("$k has default value $v with unmatched size") + for (kk, vv) in zip(ks, v) + if !haskey(defs, kk) + defs[kk] = vv + end + end + end + end if symbolic_u0 u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) @@ -1415,3 +1426,19 @@ function isisomorphic(sys1::AbstractODESystem, sys2::AbstractODESystem) end return false end + +function flatten_equations(eqs) + mapreduce(vcat, eqs; init = Equation[]) do eq + islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) + isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) + if islhsarr || isrhsarr + islhsarr && isrhsarr || + error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar") + size(eq.lhs) == size(eq.rhs) || + error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))") + return collect(eq.lhs) .~ collect(eq.rhs) + else + eq + end + end +end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 55155151a2..cfd707c92b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -202,7 +202,19 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." - deqs = reduce(vcat, scalarize(deqs); init = Equation[]) + deqs = mapreduce(vcat, deqs; init = Equation[]) do eq + islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) + isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) + if islhsarr || isrhsarr + islhsarr && isrhsarr || + error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar") + size(eq.lhs) == size(eq.rhs) || + error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))") + return collect(eq.lhs) .~ collect(eq.rhs) + else + eq + end + end iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) @@ -284,7 +296,8 @@ function ODESystem(eqs, iv; kwargs...) for p in ps if istree(p) && operation(p) === getindex par = arguments(p)[begin] - if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && all(par[i] in ps for i in eachindex(par)) + if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && + all(par[i] in ps for i in eachindex(par)) push!(new_ps, par) else push!(new_ps, p) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index c61c83ceda..d7a001f937 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -171,7 +171,14 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv gui_metadata = nothing) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - deqs = scalarize(deqs) + deqs = flatten_equations(deqs) + neqs = mapreduce(vcat, neqs) do expr + if expr isa AbstractArray || Symbolics.isarraysymbolic(expr) + collect(expr) + else + expr + end + end iv′ = value(iv) dvs′ = value.(dvs) ps′ = value.(ps) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index d98078324a..abe6648ea9 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -151,7 +151,7 @@ function JumpSystem(eqs, iv, unknowns, ps; kwargs...) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - eqs = scalarize(eqs) + eqs = flatten_equations(eqs) sysnames = nameof.(systems) if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 9bcbc2a6e8..d9a2bc797e 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -37,7 +37,8 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals end for (sym, _) in p - if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin]) + if istree(sym) && operation(sym) === getindex && + is_parameter(sys, arguments(sym)[begin]) # error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`") end end diff --git a/test/odesystem.jl b/test/odesystem.jl index dd89428f16..efea02b628 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -520,8 +520,7 @@ using SymbolicUtils.Code using Symbolics: unwrap, wrap, @register_symbolic foo(a, ms::AbstractVector) = a + sum(ms) @register_symbolic foo(a, ms::AbstractVector) -@variables t x(t) ms(t)[1:3] -D = Differential(t) +@variables x(t) ms(t)[1:3] eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)] @named sys = ODESystem(eqs, t, [x; ms], []) @named emptysys = ODESystem(Equation[], t) @@ -529,6 +528,21 @@ eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)] prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0)) @test_nowarn solve(prob, Tsit5()) +# array equations +bar(x, p) = p * x +@register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin + size = size(x) + eltype = promote_type(eltype(x), eltype(p)) +end +@parameters p[1:3, 1:3] +eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)] +@named sys = ODESystem(eqs, t) +@named emptysys = ODESystem(Equation[], t) +@mtkbuild outersys = compose(emptysys, sys) +prob = ODEProblem( + outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)]) +@test_nowarn solve(prob, Tsit5()) + # x/x @variables x(t) @named sys = ODESystem([D(x) ~ x / x], t)