Skip to content

Commit

Permalink
feat: add all_solvable_symbols and all_symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 26, 2023
1 parent 0e67d45 commit 2107adc
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ is_observed
observed
is_time_dependent
constant_structure
all_solvable_symbols
all_symbols
parameter_values
getp
setp
Expand Down
21 changes: 19 additions & 2 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct ExampleSolution
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Union{Symbol,Nothing}
# mapping from observed variable to Expr to calculate its value
observed::Dict{Symbol,Expr}
u::Vector{Vector{Float64}}
p::Vector{Float64}
t::Vector{Float64}
Expand Down Expand Up @@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu
sys.independent_variable === nothing ? [] : [sys.independent_variable]
end

# this types accepts `Expr` for observed expressions involving state/parameter
# this type accepts `Expr` for observed expressions involving state/parameter/observed
# variables
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)

function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr)
if is_time_dependent(sys)
Expand All @@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution)
end

SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true

function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution)
return vcat(
collect(keys(sys.state_index)),
collect(keys(sys.observed)),
)
end

function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution)
return vcat(
all_solvable_symbols(sys),
collect(keys(sys.parameter_index)),
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
end
```

Note that the method definitions are all assuming `constant_structure(p) == true`.
Expand Down
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ include("trait.jl")

export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
observed, is_time_dependent, constant_structure, symbolic_container
observed, is_time_dependent, constant_structure, symbolic_container, all_solvable_symbols,
all_symbols
include("interface.jl")

export SymbolCache
Expand Down
16 changes: 16 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,19 @@ Check if `sys` has a constant structure. Constant structure systems do not chang
number of variables or parameters over time.
"""
constant_structure(sys) = constant_structure(symbolic_container(sys))

"""
all_solvable_symbols(sys)
Return an array of all symbols in the system that can be solved for. This includes
observed variables, but not parameters or independent variables.
"""
all_solvable_symbols(sys) = all_solvable_symbols(symbolic_container(sys))

Check warning on line 120 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L120

Added line #L120 was not covered by tests

"""
all_symbols(sys)
Return an array of all symbols in the system. This includes parameters and independent
variables.
"""
all_symbols(sys) = all_symbols(symbolic_container(sys))

Check warning on line 128 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L128

Added line #L128 was not covered by tests
2 changes: 2 additions & 0 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ function is_time_dependent(sc::SymbolCache)
end
end
constant_structure(::SymbolCache) = true
all_solvable_symbols(sc::SymbolCache) = variable_symbols(sc)
all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))

Check warning on line 75 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L74-L75

Added lines #L74 - L75 were not covered by tests

function Base.copy(sc::SymbolCache)
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
Expand Down
10 changes: 10 additions & 0 deletions test/example_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not
end
SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t)
SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static
SymbolicIndexingInterface.all_solvable_symbols(sys::SystemMockup) = sys.vars
function SymbolicIndexingInterface.all_symbols(sys::SystemMockup)
vcat(sys.vars, sys.params, independent_variable_symbols(sys))
end

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)

Expand All @@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
@test all_solvable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)

Expand All @@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == []
@test all_solvable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z]

sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
@test !constant_structure(sys)
Expand All @@ -115,3 +123,5 @@ end
@test variable_symbols(sys, 3) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
@test all_solvable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
7 changes: 7 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
@test variable_symbols(sc) == [:x, :y, :z]
@test parameter_symbols(sc) == [:a, :b]
@test independent_variable_symbols(sc) == [:t]
@test all_solvable_symbols(sc) == [:x, :y, :z]
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
# make sure the constructor works
@test_nowarn SymbolCache([:x, :y])

Expand All @@ -30,12 +33,16 @@ sc = SymbolCache()
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t]))
@test independent_variable_symbols(sc) == []
@test !is_time_dependent(sc)
@test all_solvable_symbols(sc) == []
@test all_symbols(sc) == []

sc = SymbolCache(nothing, nothing, :t)
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
@test is_independent_variable(sc, :t)
@test independent_variable_symbols(sc) == [:t]
@test is_time_dependent(sc)
@test all_solvable_symbols(sc) == []
@test all_symbols(sc) == [:t]

sc2 = copy(sc)
@test sc.variables == sc2.variables
Expand Down

0 comments on commit 2107adc

Please sign in to comment.