From 2107adc4c41e2d71942f80e0467409fd60c5e931 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 17 Dec 2023 13:11:27 +0530 Subject: [PATCH] feat: add `all_solvable_symbols` and `all_symbols` --- docs/src/api.md | 2 ++ docs/src/tutorial.md | 21 +++++++++++++++++++-- src/SymbolicIndexingInterface.jl | 3 ++- src/interface.jl | 16 ++++++++++++++++ src/symbol_cache.jl | 2 ++ test/example_test.jl | 10 ++++++++++ test/symbol_cache_test.jl | 7 +++++++ 7 files changed, 58 insertions(+), 3 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 065a1042..dc9b0d1a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,6 +14,8 @@ is_observed observed is_time_dependent constant_structure +all_solvable_symbols +all_symbols parameter_values getp setp diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index b0ede675..5a7f7973 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -41,6 +41,8 @@ struct ExampleSolution state_index::Dict{Symbol,Int} parameter_index::Dict{Symbol,Int} independent_variable::Union{Symbol,Nothing} + # mapping from observed variable to Expr to calculate its value + observed::Dict{Symbol,Expr} u::Vector{Vector{Float64}} p::Vector{Float64} t::Vector{Float64} @@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu sys.independent_variable === nothing ? [] : [sys.independent_variable] end -# this types accepts `Expr` for observed expressions involving state/parameter +# this type accepts `Expr` for observed expressions involving state/parameter/observed # variables -SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr +SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym) function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr) if is_time_dependent(sys) @@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution) end SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true + +function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution) + return vcat( + collect(keys(sys.state_index)), + collect(keys(sys.observed)), + ) +end + +function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution) + return vcat( + all_solvable_symbols(sys), + collect(keys(sys.parameter_index)), + sys.independent_variable === nothing ? Symbol[] : sys.independent_variable + ) +end ``` Note that the method definitions are all assuming `constant_structure(p) == true`. diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7a33e0d5..71f5b4fe 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -5,7 +5,8 @@ include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, - observed, is_time_dependent, constant_structure, symbolic_container + observed, is_time_dependent, constant_structure, symbolic_container, all_solvable_symbols, + all_symbols include("interface.jl") export SymbolCache diff --git a/src/interface.jl b/src/interface.jl index bcce0b20..bd43c53b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -110,3 +110,19 @@ Check if `sys` has a constant structure. Constant structure systems do not chang number of variables or parameters over time. """ constant_structure(sys) = constant_structure(symbolic_container(sys)) + +""" + all_solvable_symbols(sys) + +Return an array of all symbols in the system that can be solved for. This includes +observed variables, but not parameters or independent variables. +""" +all_solvable_symbols(sys) = all_solvable_symbols(symbolic_container(sys)) + +""" + all_symbols(sys) + +Return an array of all symbols in the system. This includes parameters and independent +variables. +""" +all_symbols(sys) = all_symbols(symbolic_container(sys)) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index c2d89fac..980ba1f0 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -71,6 +71,8 @@ function is_time_dependent(sc::SymbolCache) end end constant_structure(::SymbolCache) = true +all_solvable_symbols(sc::SymbolCache) = variable_symbols(sc) +all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc)) function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), diff --git a/test/example_test.jl b/test/example_test.jl index 260e8aa1..8f4dc0df 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not end SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t) SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static +SymbolicIndexingInterface.all_solvable_symbols(sys::SystemMockup) = sys.vars +function SymbolicIndexingInterface.all_symbols(sys::SystemMockup) + vcat(sys.vars, sys.params, independent_variable_symbols(sys)) +end sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test variable_symbols(sys) == [:x, :y, :z] @test parameter_symbols(sys) == [:a, :b, :c] @test independent_variable_symbols(sys) == [:t] +@test all_solvable_symbols(sys) == [:x, :y, :z] +@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z] sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @test variable_symbols(sys) == [:x, :y, :z] @test parameter_symbols(sys) == [:a, :b, :c] @test independent_variable_symbols(sys) == [] +@test all_solvable_symbols(sys) == [:x, :y, :z] +@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z] sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t) @test !constant_structure(sys) @@ -115,3 +123,5 @@ end @test variable_symbols(sys, 3) == [:x, :y, :z] @test parameter_symbols(sys) == [:a, :b, :c] @test independent_variable_symbols(sys) == [:t] +@test all_solvable_symbols(sys) == [:x, :y, :z] +@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z] diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 8754b4f1..d63a510c 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -14,9 +14,12 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) @test variable_symbols(sc) == [:x, :y, :z] @test parameter_symbols(sc) == [:a, :b] @test independent_variable_symbols(sc) == [:t] +@test all_solvable_symbols(sc) == [:x, :y, :z] +@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z] sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) +@test sort(all_symbols(sc)) == [:a, :b, :x, :y] # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) @@ -30,12 +33,16 @@ sc = SymbolCache() @test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t])) @test independent_variable_symbols(sc) == [] @test !is_time_dependent(sc) +@test all_solvable_symbols(sc) == [] +@test all_symbols(sc) == [] sc = SymbolCache(nothing, nothing, :t) @test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b])) @test is_independent_variable(sc, :t) @test independent_variable_symbols(sc) == [:t] @test is_time_dependent(sc) +@test all_solvable_symbols(sc) == [] +@test all_symbols(sc) == [:t] sc2 = copy(sc) @test sc.variables == sc2.variables