Skip to content

Commit

Permalink
Support array states in ODEProblem/ODEFunction
Browse files Browse the repository at this point in the history
Co-authored-by: Aayush Sabharwal <[email protected]>
  • Loading branch information
YingboMa and AayushSabharwal committed Feb 16, 2024
1 parent 61e08f3 commit aa16e5e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
27 changes: 27 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,21 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
nothing,
isdde = false,
kwargs...)
array_vars = Dict{Any, Vector{Int}}()
for (j, x) in enumerate(dvs)
if istree(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)

Check warning on line 160 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L158-L160

Added lines #L158 - L160 were not covered by tests
end
end
subs = Dict()
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
inds = inds′

Check warning on line 166 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
end
subs[k] = term(view, Sym{Any}(Symbol("ˍ₋arg1")), inds)
end

Check warning on line 169 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L168-L169

Added lines #L168 - L169 were not covered by tests
if isdde
eqs = delay_to_function(sys)
else
Expand All @@ -164,6 +179,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
# substitute x(t) by just x
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
[eq.rhs for eq in eqs]
rhss = fast_substitute(rhss, subs)

# TODO: add an optional check on the ordering of observed equations
u = map(x -> time_varying_as_func(value(x), sys), dvs)
Expand Down Expand Up @@ -764,6 +780,17 @@ function get_u0_p(sys,
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)
for (k, v) in defs
if Symbolics.isarraysymbolic(k)
ks = scalarize(k)
length(ks) == length(v) || error("$k has default value $v with unmatched size")
for (kk, vv) in zip(ks, v)
if !haskey(defs, kk)
defs[kk] = vv

Check warning on line 789 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L785-L789

Added lines #L785 - L789 were not covered by tests
end
end

Check warning on line 791 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L791

Added line #L791 was not covered by tests
end
end

if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ end
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
end
fast_substitute(eqs::AbstractArray{<:Eq}, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(a, b) = substitute(a, b)
function fast_substitute(expr, pair::Pair)
a, b = pair
Expand Down

0 comments on commit aa16e5e

Please sign in to comment.