From 5cea80e6a0ad584fc26a8be1c7c2868c4f4940fb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 13 Nov 2024 12:07:23 +0530 Subject: [PATCH] feat: add support for directly generating tuple observed functions --- src/index_provider_interface.jl | 18 ++++++++++++++++++ src/parameter_indexing.jl | 7 +++++-- src/state_indexing.jl | 7 ++++--- test/state_indexing_test.jl | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl index f2932d5..bf7aa0b 100644 --- a/src/index_provider_interface.jl +++ b/src/index_provider_interface.jl @@ -212,6 +212,24 @@ See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_struct observed(indp, sym) = observed(symbolic_container(indp), sym) observed(indp, sym, states) = observed(symbolic_container(indp), sym, states) +""" + supports_tuple_observed(indp) + +Check if the given index provider supports generating observed functions for tuples of +symbolic variables. Falls back using `symbolic_container`, and returns `false` by +default. + +See also: [`observed`](@ref), [`parameter_observed`](@ref), [`symbolic_container`](@ref). +""" +function supports_tuple_observed(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && + (sc = symbolic_container(indp)) !== indp + supports_tuple_observed(sc) + else + false + end +end + """ is_time_dependent(indp) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 6333cea..12fe0a5 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -599,12 +599,14 @@ for (t1, t2) in [ # `getp` errors on older MTK that doesn't support `parameter_observed`. getters = getp.((sys,), p) num_observed = count(is_observed_getter, getters) + supports_tuple = supports_tuple_observed(sys) p_arr = p isa Tuple ? collect(p) : p if num_observed == 0 return MultipleParametersGetter(getters) else - pofn = parameter_observed(sys, p_arr) + pofn = supports_tuple ? parameter_observed(sys, p) : + parameter_observed(sys, p_arr) if pofn === nothing return MultipleParametersGetter.(getters) end @@ -615,7 +617,8 @@ for (t1, t2) in [ else getter = GetParameterObservedNoTime(pofn) end - return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter + return p isa Tuple && !supports_tuple ? + AsParameterTupleWrapper{length(p)}(getter) : getter end end end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index b434e4c..243f7f2 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -252,6 +252,7 @@ for (t1, t2) in [ return MultipleGetters(ContinuousTimeseries(), sym) end sym_arr = sym isa Tuple ? collect(sym) : sym + supports_tuple = supports_tuple_observed(sys) num_observed = 0 for s in sym num_observed += is_observed(sys, s) @@ -261,7 +262,7 @@ for (t1, t2) in [ if num_observed == 0 || num_observed == 1 && sym isa Tuple return MultipleGetters(nothing, getsym.((sys,), sym)) else - obs = observed(sys, sym_arr) + obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr) getter = TimeIndependentObservedFunction(obs) if sym isa Tuple getter = AsTupleWrapper{length(sym)}(getter) @@ -283,13 +284,13 @@ for (t1, t2) in [ getters = getsym.((sys,), sym) return MultipleGetters(ts_idxs, getters) else - obs = observed(sys, sym_arr) + obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr) getter = if is_time_dependent(sys) TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs) else TimeIndependentObservedFunction(obs) end - if sym isa Tuple + if sym isa Tuple && !supports_tuple getter = AsTupleWrapper{length(sym)}(getter) end return getter diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 622e5b1..ade57bb 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -346,3 +346,35 @@ getter = getsym(sys, :(x + y)) @test getter(fi) ≈ 2.8 @test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] @test getter(fs, 1) ≈ 2.8 + +struct TupleObservedWrapper{S} + sys::S +end +SymbolicIndexingInterface.symbolic_container(t::TupleObservedWrapper) = t.sys +SymbolicIndexingInterface.supports_tuple_observed(::TupleObservedWrapper) = true + +@testset "Tuple observed" begin + sc = SymbolCache([:x, :y, :z], [:a, :b, :c]) + sys = TupleObservedWrapper(sc) + ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]) + getter = getsym(sys, (:(x + y), :(y + z))) + @test all(getter(ps) .≈ (3.0, 5.0)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + getter = getsym(sys, (:(a + b), :(b + c))) + @test all(getter(ps) .≈ (0.3, 0.5)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + + sc = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) + sys = TupleObservedWrapper(sc) + ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.1) + getter = getsym(sys, (:(x + y), :(y + t))) + @test all(getter(ps) .≈ (3.0, 2.1)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + getter = getsym(sys, (:(a + b), :(b + c))) + @test all(getter(ps) .≈ (0.3, 0.5)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) +end