Skip to content

Commit

Permalink
feat: add current_state, refactor observed, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 27, 2023
1 parent 52fc726 commit 9885731
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type
include("trait.jl")

export is_variable, has_static_variable, variable_index, is_parameter, parameter_index,
is_independent_variable, is_observed, observed, is_time_dependent, constant_structure
is_independent_variable, current_state, is_observed, observed, is_time_dependent,
constant_structure
include("interface.jl")

@static if !isdefined(Base, :get_extension)
Expand Down
13 changes: 9 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ Check whether the given `sym` is an independent variable in `sys`.
"""
function is_independent_variable end

"""
current_state(sys, i)
Return a vector of symbols representing the state variables at timestep `i`.
"""
function current_state end

"""
is_observed(sys, sym)
Expand All @@ -50,14 +57,12 @@ Check whether the given `sym` is an observed value in `sys`.
function is_observed end

"""
observed(sys, sym, [symbolic_states])
observed(sys, sym)
Return the observed function of the given `sym` in `sys`. The returned function should
have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and
parameter vector. If `istimedependent(sys) == true`, the function should accept
the current time `t` as its third parameter. If `has_static_variable(sys) == false` then
`observed` must accept a third parameter `symbolic_states` indicating the order of symbolic
variables in `u`.
the current time `t` as its third parameter.
See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref)
"""
Expand Down
41 changes: 17 additions & 24 deletions test/example_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = not
if !has_static_variable(sys) && t === nothing
error("timestep must be present")
end
findfirst(isequal(sym), sys.vars)
findfirst(isequal(sym), current_state(sys, t))
end
function SymbolicIndexingInterface.current_state(sys::SymbolCache, i)
return has_static_variable(sys) ? sys.vars : circshift(sys.vars, i)
end
SymbolicIndexingInterface.is_parameter(sys::SymbolCache, sym) = sym in sys.params
function SymbolicIndexingInterface.parameter_index(sys::SymbolCache, sym)
Expand All @@ -22,19 +25,11 @@ end
function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym)
is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym)
end
function SymbolicIndexingInterface.observed(sys::SymbolCache,
sym,
symbolic_states = nothing)
if !has_static_variable(sys) && symbolic_states === nothing
error("Symbolic states must be present")
end
if has_static_variable(sys)
symbolic_states = sys.vars
end
idx = findfirst(isequal(sym), symbolic_states)
if idx !== nothing
return is_time_dependent(sys) ? (u, p, t) -> [u[idx] * i for i in t] :
(u, p) -> [u[idx] * i for i in 1:5]
function SymbolicIndexingInterface.observed(sys::SymbolCache, sym)
if is_variable(sys, sym)
return is_time_dependent(sys) ?
(u, p, t) -> [u[variable_index(sys, sym, i)] * i for i in t] :
(u, p) -> [u[variable_index(sys, sym)] * i for i in 1:5]
end
idx = parameter_index(sys, sym)
if idx !== nothing
Expand Down Expand Up @@ -87,15 +82,13 @@ sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t)
@test !has_static_variable(sys)
for variable in [:x, :y, :z, :a, :b, :c, :t]
@test_throws ErrorException variable_index(sys, variable)
@test_throws ErrorException observed(sys, variable)
end
@test all(variable_index.((sys,), [:z, :y, :x], 1) .== [3, 2, 1])
@test all(variable_index.((sys,), [:z, :y, :x], 1) .== [1, 3, 2])
@test all(variable_index.((sys,), [:a, :b, :c, :t], 1) .== nothing)
variable_order = [:x, :y, :z]
@test all(observed(sys, :x, variable_order)(1:3, 4:6, 1:4) .== [1, 2, 3, 4])
@test all(observed(sys, :y, variable_order)(1:4, 4:6, 1:4) .== [2, 4, 6, 8])
@test all(observed(sys, :z, variable_order)(1:4, 4:6, 1:4) .== [3, 6, 9, 12])
@test observed(sys, :a, variable_order)(1:3, 4:6, 1:4) == 4
@test observed(sys, :b, variable_order)(1:3, 4:6, 1:4) == 5
@test observed(sys, :c, variable_order)(1:3, 4:6, 1:4) == 6
@test observed(sys, :t, variable_order)(1:3, 4:6, 1:4) == 1:4
@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [2, 6, 3, 8])
@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [3, 2, 6, 12])
@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [1, 4, 9, 4])
@test observed(sys, :a)(1:3, 4:6, 1:4) == 4
@test observed(sys, :b)(1:3, 4:6, 1:4) == 5
@test observed(sys, :c)(1:3, 4:6, 1:4) == 6
@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4

0 comments on commit 9885731

Please sign in to comment.