diff --git a/src/interface.jl b/src/interface.jl index f07de3b9..7609b65b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -16,7 +16,7 @@ function has_static_variable end variable_index(sys, sym, [i]) Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. If -[`has_static_variable`](@ref) is `false`, this accepts the current timestep as an +[`has_static_variable`](@ref) is `false`, this accepts the current time index as an additional parameter `i`. """ function variable_index end @@ -57,12 +57,14 @@ Check whether the given `sym` is an observed value in `sys`. function is_observed end """ - observed(sys, sym) + observed(sys, sym, [states]) Return the observed function of the given `sym` in `sys`. The returned function should have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and parameter vector. If `istimedependent(sys) == true`, the function should accept -the current time `t` as its third parameter. +the current time `t` as its third parameter. If `constant_structure(sys) == false`, +accept a third parameter which can either be a vector of symbols indicating the order +of states or a time index which identifies the order of states. See also: [`is_time_dependent`](@ref), [`has_static_variable`](@ref) """ diff --git a/test/example_test.jl b/test/example_test.jl index e1e8a61f..2adccf0a 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -8,7 +8,7 @@ end SymbolicIndexingInterface.is_variable(sys::SymbolCache, sym) = sym in sys.vars function SymbolicIndexingInterface.variable_index(sys::SymbolCache, sym, t = nothing) if !has_static_variable(sys) && t === nothing - error("timestep must be present") + error("time index must be present") end findfirst(isequal(sym), current_state(sys, t)) end @@ -25,18 +25,22 @@ end function SymbolicIndexingInterface.is_observed(sys::SymbolCache, sym) is_variable(sys, sym) || is_parameter(sys, sym) || is_independent_variable(sys, sym) end -function SymbolicIndexingInterface.observed(sys::SymbolCache, sym) +function SymbolicIndexingInterface.observed(sys::SymbolCache, sym, states = nothing) + if !has_static_variable(sys) && states === nothing + error("States required") + end + states = states isa Vector ? states : current_state(sys, states) if is_variable(sys, sym) return is_time_dependent(sys) ? - (u, p, t) -> [u[variable_index(sys, sym, i)] * i for i in t] : - (u, p) -> [u[variable_index(sys, sym)] * i for i in 1:5] + (u, p, t) -> u[findfirst(isequal(sym), states)] : + (u, p) -> u[findfirst(isequal(sym), states)] end idx = parameter_index(sys, sym) if idx !== nothing return is_time_dependent(sys) ? (u, p, t) -> p[idx] : (u, p) -> p[idx] end if is_independent_variable(sys, sym) - return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1:5 + return is_time_dependent(sys) ? (u, p, t) -> t : (u, p) -> 1 end end SymbolicIndexingInterface.is_time_dependent(sys::SymbolCache) = isequal(sys.indepvar, :t) @@ -56,13 +60,13 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) @test is_independent_variable(sys, :t) @test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) @test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t])) -@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [1, 2, 3, 4]) -@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [2, 4, 6, 8]) -@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [3, 6, 9, 12]) -@test observed(sys, :a)(1:3, 4:6, 1:4) == 4 -@test observed(sys, :b)(1:3, 4:6, 1:4) == 5 -@test observed(sys, :c)(1:3, 4:6, 1:4) == 6 -@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4 +@test all(observed(sys, :x)(1:3, 4:6, 1.5) .== 1) +@test all(observed(sys, :y)(1:3, 4:6, 1.5) .== 2) +@test all(observed(sys, :z)(1:3, 4:6, 1.5) .== 3) +@test observed(sys, :a)(1:3, 4:6, 1.5) == 4 +@test observed(sys, :b)(1:3, 4:6, 1.5) == 5 +@test observed(sys, :c)(1:3, 4:6, 1.5) == 6 +@test observed(sys, :t)(1:3, 4:6, 1.5) == 1.5 @test is_time_dependent(sys) @test constant_structure(sys) @test has_static_variable(sys) @@ -70,9 +74,9 @@ sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], :t) sys = SymbolCache(true, [:x, :y, :z], [:a, :b, :c], nothing) @test !is_time_dependent(sys) -@test all(observed(sys, :x)(1:3, 4:6) .== [1, 2, 3, 4, 5]) -@test all(observed(sys, :y)(1:3, 4:6) .== [2, 4, 6, 8, 10]) -@test all(observed(sys, :z)(1:3, 4:6) .== [3, 6, 9, 12, 15]) +@test all(observed(sys, :x)(1.0:3.0, 4:6) .== 1.0) +@test all(observed(sys, :y)(1.0:3.0, 4:6) .== 2.0) +@test all(observed(sys, :z)(1.0:3.0, 4:6) .== 3.0) @test observed(sys, :a)(1:3, 4:6) == 4 @test observed(sys, :b)(1:3, 4:6) == 5 @test observed(sys, :c)(1:3, 4:6) == 6 @@ -82,13 +86,14 @@ sys = SymbolCache(false, [:x, :y, :z], [:a, :b, :c], :t) @test !has_static_variable(sys) for variable in [:x, :y, :z, :a, :b, :c, :t] @test_throws ErrorException variable_index(sys, variable) + @test_throws ErrorException observed(sys, variable) end @test all(variable_index.((sys,), [:z, :y, :x], 1) .== [1, 3, 2]) @test all(variable_index.((sys,), [:a, :b, :c, :t], 1) .== nothing) -@test all(observed(sys, :x)(1:3, 4:6, 1:4) .== [2, 6, 3, 8]) -@test all(observed(sys, :y)(1:4, 4:6, 1:4) .== [3, 2, 6, 12]) -@test all(observed(sys, :z)(1:4, 4:6, 1:4) .== [1, 4, 9, 4]) -@test observed(sys, :a)(1:3, 4:6, 1:4) == 4 -@test observed(sys, :b)(1:3, 4:6, 1:4) == 5 -@test observed(sys, :c)(1:3, 4:6, 1:4) == 6 -@test observed(sys, :t)(1:3, 4:6, 1:4) == 1:4 +@test all(observed(sys, :x, 2)(1:3, 4:6, 1.5) .== 3) +@test all(observed(sys, :y, 2)(1:3, 4:6, 1.5) .== 1) +@test all(observed(sys, :z, 2)(1:3, 4:6, 1.5) .== 2) +@test observed(sys, :a, 2)(1:3, 4:6, 1.5) == 4 +@test observed(sys, :b, 2)(1:3, 4:6, 1.5) == 5 +@test observed(sys, :c, 2)(1:3, 4:6, 1.5) == 6 +@test observed(sys, :t, 2)(1:3, 4:6, 1.5) == 1.5