From 8f786f39db1533d5af520e4798c953f93b85da5d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Nov 2023 14:10:04 +0530 Subject: [PATCH] feat: add functions to get vector of symbols from system --- docs/src/api.md | 4 +++- src/SymbolicIndexingInterface.jl | 6 +++--- src/interface.jl | 26 ++++++++++++++++++++++---- src/symbol_cache.jl | 12 ++++++------ test/example_test.jl | 26 +++++++++++++++++++++++--- test/symbol_cache_test.jl | 7 +++++-- 6 files changed, 62 insertions(+), 19 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index ec93493..f08f7a3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -3,10 +3,12 @@ ```@docs is_variable variable_index +variable_symbols is_parameter parameter_index +parameter_symbols is_independent_variable -current_state +independent_variable_symbols is_observed observed is_time_dependent diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index e364d46..54bd832 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -5,9 +5,9 @@ using Requires export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type include("trait.jl") -export is_variable, variable_index, is_parameter, parameter_index, - is_independent_variable, current_state, is_observed, observed, is_time_dependent, - constant_structure +export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, + parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, + observed, is_time_dependent, constant_structure include("interface.jl") export SymbolCache diff --git a/src/interface.jl b/src/interface.jl index 3030c1e..af56adf 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -14,6 +14,15 @@ additional parameter `i`. """ function variable_index end +""" + variable_symbols(sys, [i]) + +Return a vector of the symbolic variables being solved for in the system `sys`. If +`constant_structure(sys) == false` this accepts an additional parameter indicating +the current time index. The returned vector should not be mutated. +""" +function variable_symbols end + """ is_parameter(sys, sym) @@ -28,19 +37,28 @@ Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. """ function parameter_index end +""" + parameter_symbols(sys) + +Return a vector of the symbolic parameters of the given system `sys`. The returned +vector should not be mutated. +""" +function parameter_symbols end + """ is_independent_variable(sys, sym) -Check whether the given `sym` is an independent variable in `sys`. +Check whether the given `sym` is an independent variable in `sys`. The returned vector +should not be mutated. """ function is_independent_variable end """ - current_state(sys, i) + independent_variable_symbols(sys) -Return a vector of symbols representing the state variables at timestep `i`. +Return a vector of the symbolic independent variables of the given system `sys`. """ -function current_state end +function independent_variable_symbols end """ is_observed(sys, sym) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 164067c..06165d2 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -20,18 +20,18 @@ end is_variable(sc::SymbolCache, sym) = sym in sc.variables variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables) +variable_symbols(sc::SymbolCache, i = nothing) = sc.variables is_parameter(sc::SymbolCache, sym) = sym in sc.parameters parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters) +parameter_symbols(sc::SymbolCache) = sc.parameters is_independent_variable(sc::SymbolCache, sym) = sym in sc.independent_variables -current_state(sc::SymbolCache) = sc.variables +independent_variable_symbols(sc::SymbolCache) = sc.independent_variables is_observed(sc::SymbolCache, sym) = false is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables) constant_structure(::SymbolCache) = true function Base.copy(sc::SymbolCache) - return SymbolCache( - sc.variables === nothing ? nothing : copy(sc.variables), + return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), sc.parameters === nothing ? nothing : copy(sc.parameters), - sc.independent_variables === nothing ? nothing : copy(sc.independent_variables), - ) -end \ No newline at end of file + sc.independent_variables === nothing ? nothing : copy(sc.independent_variables)) +end diff --git a/test/example_test.jl b/test/example_test.jl index f5cb633..260e8aa 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -10,18 +10,26 @@ function SymbolicIndexingInterface.variable_index(sys::SystemMockup, sym, t = no if !constant_structure(sys) && t === nothing error("time index must be present") end - findfirst(isequal(sym), current_state(sys, t)) + findfirst(isequal(sym), variable_symbols(sys, t)) end -function SymbolicIndexingInterface.current_state(sys::SystemMockup, i) +function SymbolicIndexingInterface.variable_symbols(sys::SystemMockup, i = nothing) return constant_structure(sys) ? sys.vars : circshift(sys.vars, i) end SymbolicIndexingInterface.is_parameter(sys::SystemMockup, sym) = sym in sys.params function SymbolicIndexingInterface.parameter_index(sys::SystemMockup, sym) findfirst(isequal(sym), sys.params) end +SymbolicIndexingInterface.parameter_symbols(sys::SystemMockup) = sys.params function SymbolicIndexingInterface.is_independent_variable(sys::SystemMockup, sym) sys.indepvar !== nothing && isequal(sym, sys.indepvar) end +function SymbolicIndexingInterface.independent_variable_symbols(sys::SystemMockup) + if sys.indepvar === nothing + return [] + else + return [sys.indepvar] + end +end function SymbolicIndexingInterface.is_observed(sys::SystemMockup, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end @@ -29,7 +37,7 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not if !constant_structure(sys) && states === nothing error("States required") end - states = states isa Vector ? states : current_state(sys, states) + states = states isa Vector ? states : variable_symbols(sys, states) if is_variable(sys, sym) return is_time_dependent(sys) ? (u, p, t) -> u[findfirst(isequal(sym), states)] : @@ -68,6 +76,9 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5 @test is_time_dependent(sys) @test constant_structure(sys) +@test variable_symbols(sys) == [:x, :y, :z] +@test parameter_symbols(sys) == [:a, :b, :c] +@test independent_variable_symbols(sys) == [:t] sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @@ -79,6 +90,9 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @test observed(sys, :b)(1:3, 4:6) == 5 @test observed(sys, :c)(1:3, 4:6) == 6 @test constant_structure(sys) +@test variable_symbols(sys) == [:x, :y, :z] +@test parameter_symbols(sys) == [:a, :b, :c] +@test independent_variable_symbols(sys) == [] sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t) @test !constant_structure(sys) @@ -95,3 +109,9 @@ end @test observed(sys, :b, 2)(1:3, 4:6, 1.5) == 5 @test observed(sys, :c, 2)(1:3, 4:6, 1.5) == 6 @test observed(sys, :t, 2)(1:3, 4:6, 1.5) == 1.5 +@test_throws Exception variable_symbols(sys) +@test variable_symbols(sys, 1) == [:z, :x, :y] +@test variable_symbols(sys, 2) == [:y, :z, :x] +@test variable_symbols(sys, 3) == [:x, :y, :z] +@test parameter_symbols(sys) == [:a, :b, :c] +@test independent_variable_symbols(sys) == [:t] diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index d5b3dfd..be1b6d9 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -8,17 +8,20 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) @test parameter_index.((sc,), [:a, :b, :x]) == [1, 2, nothing] @test is_independent_variable(sc, :t) @test all(.!is_independent_variable.((sc,), [:x, :y, :z, :a, :b, :q])) -@test current_state(sc) == [:x, :y, :z] @test all(.!is_observed.((sc,), [:x, :y, :z, :a, :b, :t, :q])) @test is_time_dependent(sc) @test constant_structure(sc) +@test variable_symbols(sc) == [:x, :y, :z] +@test parameter_symbols(sc) == [:a, :b] +@test independent_variable_symbols(sc) == [:t] sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) +@test independent_variable_symbols(sc) == [] sc2 = copy(sc) @test sc.variables == sc2.variables @test sc.parameters == sc2.parameters -@test sc.independent_variables == sc2.independent_variables \ No newline at end of file +@test sc.independent_variables == sc2.independent_variables