From efa68afe41f6b8c834cf7e403641e821b906eabb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Jun 2024 13:17:06 +0530 Subject: [PATCH] fix: handle edge case with MTK symbolic indexing --- src/parameter_indexing.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 633e512a..1191c937 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -507,15 +507,25 @@ function (atw::AsParameterTupleWrapper)( atw.getter(buffer, ts, prob, args...) end +is_observed_getter(_) = false +is_observed_getter(::GetParameterObserved) = true +is_observed_getter(::GetParameterObservedNoTime) = true +is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.getters) + for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) - num_observed = count(x -> is_observed(sys, x), p) + # We need to do it this way because if an `ODESystem` has `p[1], p[2], p[3]` as + # parameters (all scalarized) then `is_observed(sys, p[2:3]) == true`. Then, + # `getp` errors on older MTK that doesn't support `parameter_observed`. + getters = getp.((sys,), p) + num_observed = count(is_observed_getter, getters) + if num_observed == 0 - return MultipleParametersGetter(getp.((sys,), p)) + return MultipleParametersGetter(getters) else pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p) if is_time_dependent(sys)