Skip to content

Commit

Permalink
feat: flatten equations to avoid scalarizing array arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Feb 19, 2024
1 parent 6fbd195 commit decfb51
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 7 deletions.
27 changes: 27 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,17 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false)
defs = mergedefaults(defs, parammap, ps)

Check warning on line 807 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L802-L807

Added lines #L802 - L807 were not covered by tests
end
defs = mergedefaults(defs, u0map, dvs)
for (k, v) in defs
if Symbolics.isarraysymbolic(k)
ks = scalarize(k)
length(ks) == length(v) || error("$k has default value $v with unmatched size")
for (kk, vv) in zip(ks, v)
if !haskey(defs, kk)
defs[kk] = vv

Check warning on line 816 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L809-L816

Added lines #L809 - L816 were not covered by tests
end
end

Check warning on line 818 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L818

Added line #L818 was not covered by tests
end
end

Check warning on line 820 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L820

Added line #L820 was not covered by tests

if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)

Check warning on line 823 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L822-L823

Added lines #L822 - L823 were not covered by tests
Expand Down Expand Up @@ -1415,3 +1426,19 @@ function isisomorphic(sys1::AbstractODESystem, sys2::AbstractODESystem)
end
return false
end

function flatten_equations(eqs)
mapreduce(vcat, eqs; init = Equation[]) do eq
islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs)
isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs)
if islhsarr || isrhsarr
islhsarr && isrhsarr ||

Check warning on line 1435 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1430-L1435

Added lines #L1430 - L1435 were not covered by tests
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
size(eq.lhs) == size(eq.rhs) ||

Check warning on line 1437 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1437

Added line #L1437 was not covered by tests
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
return collect(eq.lhs) .~ collect(eq.rhs)

Check warning on line 1439 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1439

Added line #L1439 was not covered by tests
else
eq

Check warning on line 1441 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1441

Added line #L1441 was not covered by tests
end
end
end
17 changes: 15 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,19 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
deqs = reduce(vcat, scalarize(deqs); init = Equation[])
deqs = mapreduce(vcat, deqs; init = Equation[]) do eq
islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs)
isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs)
if islhsarr || isrhsarr
islhsarr && isrhsarr ||

Check warning on line 209 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L205-L209

Added lines #L205 - L209 were not covered by tests
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
size(eq.lhs) == size(eq.rhs) ||

Check warning on line 211 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L211

Added line #L211 was not covered by tests
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
return collect(eq.lhs) .~ collect(eq.rhs)

Check warning on line 213 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L213

Added line #L213 was not covered by tests
else
eq

Check warning on line 215 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L215

Added line #L215 was not covered by tests
end
end
iv′ = value(iv)
ps′ = value.(ps)
ctrl′ = value.(controls)
Expand Down Expand Up @@ -284,7 +296,8 @@ function ODESystem(eqs, iv; kwargs...)
for p in ps
if istree(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && all(par[i] in ps for i in eachindex(par))
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&

Check warning on line 299 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L295-L299

Added lines #L295 - L299 were not covered by tests
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)

Check warning on line 301 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L301

Added line #L301 was not covered by tests
else
push!(new_ps, p)

Check warning on line 303 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L303

Added line #L303 was not covered by tests
Expand Down
9 changes: 8 additions & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,14 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
gui_metadata = nothing)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
deqs = scalarize(deqs)
deqs = flatten_equations(deqs)
neqs = mapreduce(vcat, neqs) do expr
if expr isa AbstractArray || Symbolics.isarraysymbolic(expr)
collect(expr)

Check warning on line 177 in src/systems/diffeqs/sdesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/sdesystem.jl#L174-L177

Added lines #L174 - L177 were not covered by tests
else
expr

Check warning on line 179 in src/systems/diffeqs/sdesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/sdesystem.jl#L179

Added line #L179 was not covered by tests
end
end
iv′ = value(iv)
dvs′ = value.(dvs)
ps′ = value.(ps)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ function JumpSystem(eqs, iv, unknowns, ps;
kwargs...)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
eqs = scalarize(eqs)
eqs = flatten_equations(eqs)

Check warning on line 154 in src/systems/jumps/jumpsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/jumps/jumpsystem.jl#L154

Added line #L154 was not covered by tests
sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
Expand Down
3 changes: 2 additions & 1 deletion src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
end

for (sym, _) in p
if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin])
if istree(sym) && operation(sym) === getindex &&
is_parameter(sys, arguments(sym)[begin])

Check warning on line 41 in src/systems/parameter_buffer.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/parameter_buffer.jl#L39-L41

Added lines #L39 - L41 were not covered by tests
# error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
end
end

Check warning on line 44 in src/systems/parameter_buffer.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/parameter_buffer.jl#L44

Added line #L44 was not covered by tests
Expand Down
18 changes: 16 additions & 2 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,15 +520,29 @@ using SymbolicUtils.Code
using Symbolics: unwrap, wrap, @register_symbolic
foo(a, ms::AbstractVector) = a + sum(ms)
@register_symbolic foo(a, ms::AbstractVector)
@variables t x(t) ms(t)[1:3]
D = Differential(t)
@variables x(t) ms(t)[1:3]
eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)]
@named sys = ODESystem(eqs, t, [x; ms], [])
@named emptysys = ODESystem(Equation[], t)
@mtkbuild outersys = compose(emptysys, sys)
prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0))
@test_nowarn solve(prob, Tsit5())

# array equations
bar(x, p) = p * x
@register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin
size = size(x)
eltype = promote_type(eltype(x), eltype(p))
end
@parameters p[1:3, 1:3]
eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)]
@named sys = ODESystem(eqs, t)
@named emptysys = ODESystem(Equation[], t)
@mtkbuild outersys = compose(emptysys, sys)
prob = ODEProblem(
outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)])
@test_nowarn solve(prob, Tsit5())

# x/x
@variables x(t)
@named sys = ODESystem([D(x) ~ x / x], t)
Expand Down

0 comments on commit decfb51

Please sign in to comment.