Skip to content

Commit

Permalink
fix: observed getu generation, tuple wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 30, 2024
1 parent c6bf421 commit 8ccf0f0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,14 @@ struct AsTupleWrapper{G} <: AbstractIndexer
getter::G
end

function (atw::AsTupleWrapper)(::IsTimeseriesTrait, args...)
return Tuple(atw.getter(args...))
function (atw::AsTupleWrapper)(::Timeseries, prob)
return Tuple.(atw.getter(prob))
end
function (atw::AsTupleWrapper)(::Timeseries, prob, i)
return Tuple(atw.getter(prob, i))
end
function (atw::AsTupleWrapper)(::NotTimeseries, prob)
return Tuple(atw.getter(prob))
end

for (t1, t2) in [
Expand All @@ -151,7 +157,7 @@ for (t1, t2) in [
return MultipleGetters(getters)
else
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
getter = if is_timeseries(sys)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction(obs)
else
TimeIndependentObservedFunction(obs)
Expand Down
26 changes: 25 additions & 1 deletion test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
@test get(u) == val
end

for (sym, val, check_inference) in [
(:(x + y), u[1] + u[2], true),
([:(x + y), :z], [u[1] + u[2], u[3]], false),
((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false)
]
get = getu(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == val
end

for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
Expand All @@ -101,7 +113,7 @@ end
for (sym, val, check_inference) in [
(:t, t, true),
([:x, :a, :t], [u[1], p[1], t], false),
((:x, :a, :t), (u[1], p[1], t), true)
((:x, :a, :t), (u[1], p[1], t), false)
]
get = getu(fi, sym)
if check_inference
Expand Down Expand Up @@ -182,6 +194,18 @@ for (sym, ans, check_inference) in [(:x, xvals, true)
end
end

for (sym, val, check_inference) in [
(:(x + y), xvals .+ yvals, true),
([:(x + y), :z], [xvals .+ yvals, zvals], false),
((:(x + y), :(z + y)), (xvals .+ yvals, yvals .+ zvals), false)
]
get = getu(sys, sym)
if check_inference
@inferred get(sol)
end
@test get(fi) == val
end

for (sym, val) in [(:a, p[1])
(:b, p[2])
(:c, p[3])
Expand Down

0 comments on commit 8ccf0f0

Please sign in to comment.