Skip to content

Commit

Permalink
feat: update interface, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 31, 2023
1 parent 9885731 commit 15a987b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
8 changes: 5 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function has_static_variable end
variable_index(sys, sym, [i])
Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. If
[`has_static_variable`](@ref) is `false`, this accepts the current timestep as an
[`has_static_variable`](@ref) is `false`, this accepts the current time index as an
additional parameter `i`.
"""
function variable_index end
Expand Down Expand Up @@ -57,12 +57,14 @@ Check whether the given `sym` is an observed value in `sys`.
function is_observed end

"""
observed(sys, sym)
observed(sys, sym, [states])
Return the observed function of the given `sym` in `sys`. The returned function should
have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and
parameter vector. If `istimedependent(sys) == true`, the function should accept
the current time `t` as its third parameter.
the current time `t` as its third parameter. If `constant_structure(sys) == false`,
accept a third parameter which can either be a vector of symbols indicating the order
of states or a time index which identifies the order of states.
See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref)
"""
Expand Down
49 changes: 27 additions & 22 deletions test/example_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ end
SymbolicIndexingInterface.is_variable(sys::SymbolCache, sym) = sym in sys.vars
function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = nothing)
if !has_static_variable(sys) && t === nothing
error("timestep must be present")
error("time index must be present")
end
findfirst(isequal(sym), current_state(sys, t))
end
Expand All @@ -25,18 +25,22 @@ end
function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym)
is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym)
end
function SymbolicIndexingInterface.observed(sys::SymbolCache, sym)
function SymbolicIndexingInterface.observed(sys::SymbolCache, sym, states = nothing)
if !has_static_variable(sys) && states === nothing
error("States required")
end
states = states isa Vector ? states : current_state(sys, states)
if is_variable(sys, sym)
return is_time_dependent(sys) ?
(u, p, t) -> [u[variable_index(sys, sym, i)] * i for i in t] :
(u, p) -> [u[variable_index(sys, sym)] * i for i in 1:5]
(u, p, t) -> u[findfirst(isequal(sym), states)] :
(u, p) -> u[findfirst(isequal(sym), states)]
end
idx = parameter_index(sys, sym)
if idx !== nothing
return is_time_dependent(sys) ? (u, p, t) -> p[idx] : (u, p) -> p[idx]
end
if is_independent_variable(sys, sym)
return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1:5
return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1
end
end
SymbolicIndexingInterface.is_time_dependent(sys::SymbolCache) = isequal(sys.indepvar, :t)
Expand All @@ -56,23 +60,23 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t)
@test is_independent_variable(sys, :t)
@test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r]))
@test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t]))
@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [1, 2, 3, 4])
@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [2, 4, 6, 8])
@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [3, 6, 9, 12])
@test observed(sys, :a)(1:3, 4:6, 1:4) == 4
@test observed(sys, :b)(1:3, 4:6, 1:4) == 5
@test observed(sys, :c)(1:3, 4:6, 1:4) == 6
@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4
@test all(observed(sys, :x)(1:3, 4:6, 1.5) .== 1)
@test all(observed(sys, :y)(1:3, 4:6, 1.5) .== 2)
@test all(observed(sys, :z)(1:3, 4:6, 1.5) .== 3)
@test observed(sys, :a)(1:3, 4:6, 1.5) == 4
@test observed(sys, :b)(1:3, 4:6, 1.5) == 5
@test observed(sys, :c)(1:3, 4:6, 1.5) == 6
@test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5
@test is_time_dependent(sys)
@test constant_structure(sys)
@test has_static_variable(sys)

sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], nothing)

@test !is_time_dependent(sys)
@test all(observed(sys, :x)(1:3, 4:6) .== [1, 2, 3, 4, 5])
@test all(observed(sys, :y)(1:3, 4:6) .== [2, 4, 6, 8, 10])
@test all(observed(sys, :z)(1:3, 4:6) .== [3, 6, 9, 12, 15])
@test all(observed(sys, :x)(1.0:3.0, 4:6) .== 1.0)
@test all(observed(sys, :y)(1.0:3.0, 4:6) .== 2.0)
@test all(observed(sys, :z)(1.0:3.0, 4:6) .== 3.0)
@test observed(sys, :a)(1:3, 4:6) == 4
@test observed(sys, :b)(1:3, 4:6) == 5
@test observed(sys, :c)(1:3, 4:6) == 6
Expand All @@ -82,13 +86,14 @@ sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t)
@test !has_static_variable(sys)
for variable in [:x, :y, :z, :a, :b, :c, :t]
@test_throws ErrorException variable_index(sys, variable)
@test_throws ErrorException observed(sys, variable)
end
@test all(variable_index.((sys,), [:z, :y, :x], 1) .== [1, 3, 2])
@test all(variable_index.((sys,), [:a, :b, :c, :t], 1) .== nothing)
@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [2, 6, 3, 8])
@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [3, 2, 6, 12])
@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [1, 4, 9, 4])
@test observed(sys, :a)(1:3, 4:6, 1:4) == 4
@test observed(sys, :b)(1:3, 4:6, 1:4) == 5
@test observed(sys, :c)(1:3, 4:6, 1:4) == 6
@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4
@test all(observed(sys, :x, 2)(1:3, 4:6, 1.5) .== 3)
@test all(observed(sys, :y, 2)(1:3, 4:6, 1.5) .== 1)
@test all(observed(sys, :z, 2)(1:3, 4:6, 1.5) .== 2)
@test observed(sys, :a, 2)(1:3, 4:6, 1.5) == 4
@test observed(sys, :b, 2)(1:3, 4:6, 1.5) == 5
@test observed(sys, :c, 2)(1:3, 4:6, 1.5) == 6
@test observed(sys, :t, 2)(1:3, 4:6, 1.5) == 1.5

0 comments on commit 15a987b

Please sign in to comment.