Skip to content

Commit

Permalink
feat: add support for parameter and indepvar symbols in getu
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 16, 2024
1 parent b35a302 commit 5ab2679
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 13 deletions.
32 changes: 25 additions & 7 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ Return a function that takes an integrator, problem or solution of `sys`, and re
the value of the symbolic `sym`. If `sym` is not an observed quantity, the returned
function can also directly be called with an array of values representing the state
vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic
expression involving symbolic quantities in the system `sys`, or an array/tuple of the
aforementioned. If the returned function is called with a timeseries object, it can also
be given a second argument representing the index at which to find the value of `sym`.
expression involving symbolic quantities in the system `sys`, a parameter symbol, or the
independent variable symbol, or an array/tuple of the aforementioned. If the returned
function is called with a timeseries object, it can also be given a second argument
representing the index at which to find the value of `sym`.
At minimum, this requires that the integrator, problem or solution implement
[`state_values`](@ref). To support symbolic expressions, the integrator or problem
Expand Down Expand Up @@ -131,6 +132,19 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return getu(sys, idx)
elseif is_parameter(sys, sym)
return let fn = getp(sys, sym)
getter(prob, args...) = fn(prob)
getter

Check warning on line 138 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L135-L138

Added lines #L135 - L138 were not covered by tests
end
elseif is_independent_variable(sys, sym)
_getter(::IsTimeseriesTrait, prob) = current_time(prob)
_getter(::Timeseries, prob, i) = current_time(prob, i)
return let _getter = _getter
getter(prob) = _getter(is_timeseries(prob), prob)
getter(prob, i) = _getter(is_timeseries(prob), prob, i)
getter

Check warning on line 146 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L140-L146

Added lines #L140 - L146 were not covered by tests
end
elseif is_observed(sys, sym)
fn = observed(sys, sym)
if is_time_dependent(sys)
Expand Down Expand Up @@ -226,11 +240,15 @@ function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym)
end

function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
idx = variable_index(sys, sym)
return function setter!(prob, val)
set_state!(prob, val, idx)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return function setter!(prob, val)
set_state!(prob, val, idx)

Check warning on line 246 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L242-L246

Added lines #L242 - L246 were not covered by tests
end
elseif is_parameter(sys, sym)
return setp(sys, sym)

Check warning on line 249 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L248-L249

Added lines #L248 - L249 were not covered by tests
end
error("Invalid symbol $sym for `setu`")

Check warning on line 251 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L251

Added line #L251 was not covered by tests
end

for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
Expand Down
76 changes: 70 additions & 6 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
using SymbolicIndexingInterface

struct FakeIntegrator{S, U}
struct FakeIntegrator{S, U, P, T}
sys::S
u::U
p::P
t::T
end

SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t

sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
u = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(u))
p = [11.0, 12.0, 13.0]
t = 0.5
fi = FakeIntegrator(sys, copy(u), copy(p), t)
# checking inference for non-concretely typed arrays will always fail
for (sym, val, newval, check_inference) in [
(:x, u[1], 4.0, true)
Expand Down Expand Up @@ -61,19 +67,60 @@ for (sym, val, newval, check_inference) in [
@test get(u) == val
end

for (sym, oldval, newval, check_inference) in [
(:a, p[1], 4.0, true)
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
([:a, :b], p[1:2], [4.0, 5.0], true)
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)
]
get = getu(fi, sym)
set! = setu(fi, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == oldval
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
end

struct FakeSolution{S, U}
for (sym, val, check_inference) in [
(:t, t, true),
([:x, :a, :t], [u[1], p[1], t], false),
((:x, :a, :t), (u[1], p[1], t), true),
]
get = getu(fi, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == val
end

struct FakeSolution{S, U, P, T}
sys::S
u::U
p::P
t::T
end

SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries()
SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys
SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u
SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p
SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t

sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
sol = FakeSolution(sys, u)
t = [1.5, 2.0]
sol = FakeSolution(sys, u, p, t)

xvals = getindex.(sol.u, 1)
yvals = getindex.(sol.u, 2)
Expand All @@ -97,6 +144,11 @@ for (sym, ans, check_inference) in [
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true)
([:x, :a], vcat.(xvals, p[1]), false)
((:y, :b), tuple.(yvals, p[2]), true)
(:t, t, true)
([:x, :a, :t], vcat.(xvals, p[1], t), false)
((:x, :a, :t), tuple.(xvals, p[1], t), true)
]
get = getu(sys, sym)
if check_inference
Expand All @@ -110,3 +162,15 @@ for (sym, ans, check_inference) in [
@test get(sol, i) == ans[i]
end
end

for (sym, val) in [
(:a, p[1])
(:b, p[2])
(:c, p[3])
([:a, :b], p[1:2])
((:c, :b), (p[3], p[2]))
]
get = getu(fi, sym)
@inferred get(fi)
@test get(fi) == val
end

0 comments on commit 5ab2679

Please sign in to comment.