Skip to content

Commit

Permalink
fix: handle edge case with MTK symbolic indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 3, 2024
1 parent f9c500f commit efa68af
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -507,15 +507,25 @@ function (atw::AsParameterTupleWrapper)(
atw.getter(buffer, ts, prob, args...)
end

is_observed_getter(_) = false
is_observed_getter(::GetParameterObserved) = true
is_observed_getter(::GetParameterObservedNoTime) = true
is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.getters)

for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
num_observed = count(x -> is_observed(sys, x), p)
# We need to do it this way because if an `ODESystem` has `p[1], p[2], p[3]` as
# parameters (all scalarized) then `is_observed(sys, p[2:3]) == true`. Then,
# `getp` errors on older MTK that doesn't support `parameter_observed`.
getters = getp.((sys,), p)
num_observed = count(is_observed_getter, getters)

if num_observed == 0
return MultipleParametersGetter(getp.((sys,), p))
return MultipleParametersGetter(getters)
else
pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p)
if is_time_dependent(sys)
Expand Down

0 comments on commit efa68af

Please sign in to comment.