Skip to content

Commit

Permalink
Merge pull request #107 from SciML/as/tuple-observed
Browse files Browse the repository at this point in the history
feat: add support for directly generating tuple observed functions
  • Loading branch information
AayushSabharwal authored Nov 13, 2024
2 parents 585be57 + 5cea80e commit 124a6b8
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 5 deletions.
18 changes: 18 additions & 0 deletions src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,24 @@ See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_struct
observed(indp, sym) = observed(symbolic_container(indp), sym)
observed(indp, sym, states) = observed(symbolic_container(indp), sym, states)

"""
supports_tuple_observed(indp)
Check if the given index provider supports generating observed functions for tuples of
symbolic variables. Falls back using `symbolic_container`, and returns `false` by
default.
See also: [`observed`](@ref), [`parameter_observed`](@ref), [`symbolic_container`](@ref).
"""
function supports_tuple_observed(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) !== indp
supports_tuple_observed(sc)
else
false
end
end

"""
is_time_dependent(indp)
Expand Down
7 changes: 5 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,14 @@ for (t1, t2) in [
# `getp` errors on older MTK that doesn't support `parameter_observed`.
getters = getp.((sys,), p)
num_observed = count(is_observed_getter, getters)
supports_tuple = supports_tuple_observed(sys)
p_arr = p isa Tuple ? collect(p) : p

if num_observed == 0
return MultipleParametersGetter(getters)
else
pofn = parameter_observed(sys, p_arr)
pofn = supports_tuple ? parameter_observed(sys, p) :
parameter_observed(sys, p_arr)
if pofn === nothing
return MultipleParametersGetter.(getters)
end
Expand All @@ -615,7 +617,8 @@ for (t1, t2) in [
else
getter = GetParameterObservedNoTime(pofn)
end
return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter
return p isa Tuple && !supports_tuple ?
AsParameterTupleWrapper{length(p)}(getter) : getter
end
end
end
Expand Down
7 changes: 4 additions & 3 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ for (t1, t2) in [
return MultipleGetters(ContinuousTimeseries(), sym)
end
sym_arr = sym isa Tuple ? collect(sym) : sym
supports_tuple = supports_tuple_observed(sys)
num_observed = 0
for s in sym
num_observed += is_observed(sys, s)
Expand All @@ -261,7 +262,7 @@ for (t1, t2) in [
if num_observed == 0 || num_observed == 1 && sym isa Tuple
return MultipleGetters(nothing, getsym.((sys,), sym))
else
obs = observed(sys, sym_arr)
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
getter = TimeIndependentObservedFunction(obs)
if sym isa Tuple
getter = AsTupleWrapper{length(sym)}(getter)
Expand All @@ -283,13 +284,13 @@ for (t1, t2) in [
getters = getsym.((sys,), sym)
return MultipleGetters(ts_idxs, getters)
else
obs = observed(sys, sym_arr)
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs)
else
TimeIndependentObservedFunction(obs)
end
if sym isa Tuple
if sym isa Tuple && !supports_tuple
getter = AsTupleWrapper{length(sym)}(getter)
end
return getter
Expand Down
32 changes: 32 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,35 @@ getter = getsym(sys, :(x + y))
@test getter(fi) 2.8
@test getter(fs) [3.0i + 2(ts[i] - 0.1) for i in 1:11]
@test getter(fs, 1) 2.8

struct TupleObservedWrapper{S}
sys::S
end
SymbolicIndexingInterface.symbolic_container(t::TupleObservedWrapper) = t.sys
SymbolicIndexingInterface.supports_tuple_observed(::TupleObservedWrapper) = true

@testset "Tuple observed" begin
sc = SymbolCache([:x, :y, :z], [:a, :b, :c])
sys = TupleObservedWrapper(sc)
ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3])
getter = getsym(sys, (:(x + y), :(y + z)))
@test all(getter(ps) .≈ (3.0, 5.0))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)
getter = getsym(sys, (:(a + b), :(b + c)))
@test all(getter(ps) .≈ (0.3, 0.5))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)

sc = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
sys = TupleObservedWrapper(sc)
ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.1)
getter = getsym(sys, (:(x + y), :(y + t)))
@test all(getter(ps) .≈ (3.0, 2.1))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)
getter = getsym(sys, (:(a + b), :(b + c)))
@test all(getter(ps) .≈ (0.3, 0.5))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)
end

0 comments on commit 124a6b8

Please sign in to comment.