Skip to content

Commit

Permalink
feat: add functions to get vector of symbols from system
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 7, 2023
1 parent 9204fc5 commit 8f786f3
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 19 deletions.
4 changes: 3 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
```@docs
is_variable
variable_index
variable_symbols
is_parameter
parameter_index
parameter_symbols
is_independent_variable
current_state
independent_variable_symbols
is_observed
observed
is_time_dependent
Expand Down
6 changes: 3 additions & 3 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ using Requires
export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type
include("trait.jl")

export is_variable, variable_index, is_parameter, parameter_index,
is_independent_variable, current_state, is_observed, observed, is_time_dependent,
constant_structure
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
observed, is_time_dependent, constant_structure
include("interface.jl")

export SymbolCache
Expand Down
26 changes: 22 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ additional parameter `i`.
"""
function variable_index end

"""
variable_symbols(sys, [i])
Return a vector of the symbolic variables being solved for in the system `sys`. If
`constant_structure(sys) == false` this accepts an additional parameter indicating
the current time index. The returned vector should not be mutated.
"""
function variable_symbols end

"""
is_parameter(sys, sym)
Expand All @@ -28,19 +37,28 @@ Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise.
"""
function parameter_index end

"""
parameter_symbols(sys)
Return a vector of the symbolic parameters of the given system `sys`. The returned
vector should not be mutated.
"""
function parameter_symbols end

"""
is_independent_variable(sys, sym)
Check whether the given `sym` is an independent variable in `sys`.
Check whether the given `sym` is an independent variable in `sys`. The returned vector
should not be mutated.
"""
function is_independent_variable end

"""
current_state(sys, i)
independent_variable_symbols(sys)
Return a vector of symbols representing the state variables at timestep `i`.
Return a vector of the symbolic independent variables of the given system `sys`.
"""
function current_state end
function independent_variable_symbols end

"""
is_observed(sys, sym)
Expand Down
12 changes: 6 additions & 6 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ end

is_variable(sc::SymbolCache, sym) = sym in sc.variables
variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables)
variable_symbols(sc::SymbolCache, i = nothing) = sc.variables
is_parameter(sc::SymbolCache, sym) = sym in sc.parameters
parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters)
parameter_symbols(sc::SymbolCache) = sc.parameters
is_independent_variable(sc::SymbolCache, sym) = sym in sc.independent_variables
current_state(sc::SymbolCache) = sc.variables
independent_variable_symbols(sc::SymbolCache) = sc.independent_variables
is_observed(sc::SymbolCache, sym) = false
is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables)
constant_structure(::SymbolCache) = true

Check warning on line 31 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L21-L31

Added lines #L21 - L31 were not covered by tests

function Base.copy(sc::SymbolCache)
return SymbolCache(
sc.variables === nothing ? nothing : copy(sc.variables),
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),

Check warning on line 34 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
sc.parameters === nothing ? nothing : copy(sc.parameters),
sc.independent_variables === nothing ? nothing : copy(sc.independent_variables),
)
end
sc.independent_variables === nothing ? nothing : copy(sc.independent_variables))
end
26 changes: 23 additions & 3 deletions test/example_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,34 @@ function SymbolicIndexingInterface.variable_index(sys::SystemMockup, sym, t = no
if !constant_structure(sys) && t === nothing
error("time index must be present")
end
findfirst(isequal(sym), current_state(sys, t))
findfirst(isequal(sym), variable_symbols(sys, t))
end
function SymbolicIndexingInterface.current_state(sys::SystemMockup, i)
function SymbolicIndexingInterface.variable_symbols(sys::SystemMockup, i = nothing)
return constant_structure(sys) ? sys.vars : circshift(sys.vars, i)
end
SymbolicIndexingInterface.is_parameter(sys::SystemMockup, sym) = sym in sys.params
function SymbolicIndexingInterface.parameter_index(sys::SystemMockup, sym)
findfirst(isequal(sym), sys.params)
end
SymbolicIndexingInterface.parameter_symbols(sys::SystemMockup) = sys.params
function SymbolicIndexingInterface.is_independent_variable(sys::SystemMockup, sym)
sys.indepvar !== nothing && isequal(sym, sys.indepvar)
end
function SymbolicIndexingInterface.independent_variable_symbols(sys::SystemMockup)
if sys.indepvar === nothing
return []
else
return [sys.indepvar]
end
end
function SymbolicIndexingInterface.is_observed(sys::SystemMockup, sym)
is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym)
end
function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = nothing)
if !constant_structure(sys) && states === nothing
error("States required")
end
states = states isa Vector ? states : current_state(sys, states)
states = states isa Vector ? states : variable_symbols(sys, states)
if is_variable(sys, sym)
return is_time_dependent(sys) ?
(u, p, t) -> u[findfirst(isequal(sym), states)] :
Expand Down Expand Up @@ -68,6 +76,9 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
@test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5
@test is_time_dependent(sys)
@test constant_structure(sys)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)

Expand All @@ -79,6 +90,9 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
@test observed(sys, :b)(1:3, 4:6) == 5
@test observed(sys, :c)(1:3, 4:6) == 6
@test constant_structure(sys)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == []

sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
@test !constant_structure(sys)
Expand All @@ -95,3 +109,9 @@ end
@test observed(sys, :b, 2)(1:3, 4:6, 1.5) == 5
@test observed(sys, :c, 2)(1:3, 4:6, 1.5) == 6
@test observed(sys, :t, 2)(1:3, 4:6, 1.5) == 1.5
@test_throws Exception variable_symbols(sys)
@test variable_symbols(sys, 1) == [:z, :x, :y]
@test variable_symbols(sys, 2) == [:y, :z, :x]
@test variable_symbols(sys, 3) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
7 changes: 5 additions & 2 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
@test parameter_index.((sc,), [:a, :b, :x]) == [1, 2, nothing]
@test is_independent_variable(sc, :t)
@test all(.!is_independent_variable.((sc,), [:x, :y, :z, :a, :b, :q]))
@test current_state(sc) == [:x, :y, :z]
@test all(.!is_observed.((sc,), [:x, :y, :z, :a, :b, :t, :q]))
@test is_time_dependent(sc)
@test constant_structure(sc)
@test variable_symbols(sc) == [:x, :y, :z]
@test parameter_symbols(sc) == [:a, :b]
@test independent_variable_symbols(sc) == [:t]

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
# make sure the constructor works
@test_nowarn SymbolCache([:x, :y])
@test independent_variable_symbols(sc) == []

sc2 = copy(sc)
@test sc.variables == sc2.variables
@test sc.parameters == sc2.parameters
@test sc.independent_variables == sc2.independent_variables
@test sc.independent_variables == sc2.independent_variables

0 comments on commit 8f786f3

Please sign in to comment.