diff --git a/docs/Project.toml b/docs/Project.toml index c81e0b6..1e8f527 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,9 @@ [deps] -SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [compat] Documenter = "0.27" diff --git a/docs/pages.jl b/docs/pages.jl index bb03773..62447c9 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -2,6 +2,7 @@ pages = [ "Home" => "index.md", - "API" => "api.md", "Tutorial" => "tutorial.md", + "Usage" => "usage.md", + "API" => "api.md", ] diff --git a/docs/src/api.md b/docs/src/api.md index 065a104..bc60cac 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,6 +14,10 @@ is_observed observed is_time_dependent constant_structure +all_variable_symbols +all_symbols +solvedvariables +allvariables parameter_values getp setp diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index b0ede67..5a7f797 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -41,6 +41,8 @@ struct ExampleSolution state_index::Dict{Symbol,Int} parameter_index::Dict{Symbol,Int} independent_variable::Union{Symbol,Nothing} + # mapping from observed variable to Expr to calculate its value + observed::Dict{Symbol,Expr} u::Vector{Vector{Float64}} p::Vector{Float64} t::Vector{Float64} @@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu sys.independent_variable === nothing ? [] : [sys.independent_variable] end -# this types accepts `Expr` for observed expressions involving state/parameter +# this type accepts `Expr` for observed expressions involving state/parameter/observed # variables -SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr +SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym) function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr) if is_time_dependent(sys) @@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution) end SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true + +function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution) + return vcat( + collect(keys(sys.state_index)), + collect(keys(sys.observed)), + ) +end + +function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution) + return vcat( + all_solvable_symbols(sys), + collect(keys(sys.parameter_index)), + sys.independent_variable === nothing ? Symbol[] : sys.independent_variable + ) +end ``` Note that the method definitions are all assuming `constant_structure(p) == true`. diff --git a/docs/src/usage.md b/docs/src/usage.md new file mode 100644 index 0000000..d45d504 --- /dev/null +++ b/docs/src/usage.md @@ -0,0 +1,147 @@ +# Using the SymbolicIndexingInterface + +This tutorial will cover ways to use the interface for types that implement it. +Consider the following example: + +```@example Usage +using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots + +@parameters σ ρ β +@variables t x(t) y(t) z(t) w(t) +D = Differential(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z, + w ~ x + y + z] + +@named sys = ODESystem(eqs) +sys = structural_simplify(sys) +``` + +The system has 4 state variables, 3 parameters and one observed variable: +```@example Usage +observed(sys) +``` + +Solving the system, +```@example Usage +u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + +tspan = (0.0, 100.0) +prob = ODEProblem(sys, u0, tspan, p, jac = true) +sol = solve(prob, Tsit5()) +``` + +We can obtain the timeseries of any time-dependent variable using `getindex` +```@example Usage +sol[x] +``` + +This also works for arrays or tuples of variables, including observed quantities and +independent variables, for interpolating solutions, and plotting: +```@example Usage +sol[[x, y]] +``` + +```@example Usage +sol[(t, w)] +``` + +```@example Usage +sol(1.3, idxs=x) +``` + +```@example Usage +sol(1.3, idxs=[x, w]) +``` + +```@example Usage +sol(1.3, idxs=[:y, :z]) +``` + +```@example Usage +plot(sol, idxs=x) +``` + +If necessary, `Symbol`s can be used to refer to variables. This is only valid for +symbolic variables for which [`hasname`](@ref) returns `true`. The `Symbol` used must +match the one returned by [`getname`](@ref) for the variable. +```@example Usage +hasname(x) +``` + +```@example Usage +getname(x) +``` + +```@example Usage +sol[(:x, :w)] +``` + +Note how when indexing with an array, the returned type is a `Vector{Array{Float64}}`, +and when using a `Tuple`, the returned type is `Vector{Tuple{Float64, Float64}}`. +To obtain the value of all state variables, we can use the shorthand: +```@example Usage +sol[solvedvariables] # equivalent to sol[variable_symbols(sol)] +``` + +This does not include the observed variable `w`. To include observed variables in the +output, the following shorthand is used: +```@example Usage +sol[allvariables] # equivalent to sol[all_variable_symbols(sol)] +``` + +Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref). + +```@example Usage +σ_getter = getp(sys, σ) +σ_getter(sol) # can also pass `prob` +``` + +Note that this also supports arrays/tuples of parameter symbols: + +```@example Usage +σ_ρ_getter = getp(sys, (σ, ρ)) +σ_ρ_getter(sol) +``` + +Now suppose the system has to be solved with a different value of the parameter `β`. + +```@example Usage +β_setter = setp(sys, β) +β_setter(prob, 3) +``` + +The updated parameter values can be checked using [`parameter_values`](@ref). + +```@example Usage +parameter_values(prob) +``` + +Solving the new system, note that the parameter getter functions still work on the new +solution object. + +```@example Usage +sol2 = solve(prob, Tsit5()) +σ_getter(sol) +``` + +```@example Usage +σ_ρ_getter(sol) +``` + +To set the entire parameter vector at once, [`parameter_values`](@ref) can be used +(note the usage of broadcasted assignment). + +```@example Usage +parameter_values(prob) .= [29.0, 11.0, 2.5] +parameter_values(prob) +``` diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7a33e0d..32b6f9e 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -5,7 +5,8 @@ include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, - observed, is_time_dependent, constant_structure, symbolic_container + observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, + all_symbols, solvedvariables, allvariables include("interface.jl") export SymbolCache diff --git a/src/interface.jl b/src/interface.jl index bcce0b2..67fab0e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -33,6 +33,10 @@ variable_index(sys, sym, i) = variable_index(symbolic_container(sys), sym, i) Return a vector of the symbolic variables being solved for in the system `sys`. If `constant_structure(sys) == false` this accepts an additional parameter indicating the current time index. The returned vector should not be mutated. + +For types that implement `Base.getindex` with symbolic indices using this interface, +The shorthand `sys[solvedvariables]` can be used as shorthand for +`sys[variable_symbols(sys)]`. See: [`solvedvariables`](@ref). """ variable_symbols(sys) = variable_symbols(symbolic_container(sys)) variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i) @@ -110,3 +114,46 @@ Check if `sys` has a constant structure. Constant structure systems do not chang number of variables or parameters over time. """ constant_structure(sys) = constant_structure(symbolic_container(sys)) + +""" + all_variable_symbols(sys) + +Return a vector of variable symbols in the system, including observed quantities. + +For types that implement `Base.getindex` with symbolic indices using this interface, +The shorthand `sys[allvariables]` can be used as shorthand for +`sys[all_variable_symbols(sys)]`. See: [`allvariables`](@ref). +""" +all_variable_symbols(sys) = all_variable_symbols(symbolic_container(sys)) + +""" + all_symbols(sys) + +Return an array of all symbols in the system. This includes parameters and independent +variables. +""" +all_symbols(sys) = all_symbols(symbolic_container(sys)) + +struct SolvedVariables end + +""" + const solvedvariables = SolvedVariables() + +This singleton is used as a shortcut to allow indexing all solution variables +(excluding observed quantities). It has a [`symbolic_type`](@ref) of +[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref). +""" +const solvedvariables = SolvedVariables() +symbolic_type(::Type{SolvedVariables}) = ScalarSymbolic() + +struct AllVariables end + +""" + const allvariables = AllVariables() + +This singleton is used as a shortcut to allow indexing all solution variables +(including observed quantities). It has a [`symbolic_type`](@ref) of +[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref). +""" +const allvariables = AllVariables() +symbolic_type(::Type{AllVariables}) = ScalarSymbolic() diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index c2d89fa..8df5dda 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -71,6 +71,8 @@ function is_time_dependent(sc::SymbolCache) end end constant_structure(::SymbolCache) = true +all_variable_symbols(sc::SymbolCache) = variable_symbols(sc) +all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc)) function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), diff --git a/test/example_test.jl b/test/example_test.jl index 260e8aa..965d1b4 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not end SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t) SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static +SymbolicIndexingInterface.all_variable_symbols(sys::SystemMockup) = sys.vars +function SymbolicIndexingInterface.all_symbols(sys::SystemMockup) + vcat(sys.vars, sys.params, independent_variable_symbols(sys)) +end sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test variable_symbols(sys) == [:x, :y, :z] @test parameter_symbols(sys) == [:a, :b, :c] @test independent_variable_symbols(sys) == [:t] +@test all_variable_symbols(sys) == [:x, :y, :z] +@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z] sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) @test variable_symbols(sys) == [:x, :y, :z] @test parameter_symbols(sys) == [:a, :b, :c] @test independent_variable_symbols(sys) == [] +@test all_variable_symbols(sys) == [:x, :y, :z] +@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z] sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t) @test !constant_structure(sys) @@ -115,3 +123,5 @@ end @test variable_symbols(sys, 3) == [:x, :y, :z] @test parameter_symbols(sys) == [:a, :b, :c] @test independent_variable_symbols(sys) == [:t] +@test all_variable_symbols(sys) == [:x, :y, :z] +@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z] diff --git a/test/fallback_test.jl b/test/fallback_test.jl index 4c354bf..f390178 100644 --- a/test/fallback_test.jl +++ b/test/fallback_test.jl @@ -22,3 +22,5 @@ all_syms = [:x, :y, :z, :a, :b, :t] @test variable_symbols(sys) == variable_symbols(sc) @test parameter_symbols(sys) == parameter_symbols(sc) @test independent_variable_symbols(sys) == independent_variable_symbols(sc) +@test all_variable_symbols(sys) == variable_symbols(sc) +@test all_symbols(sys) == all_symbols(sc) diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 8754b4f..39e388d 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -14,9 +14,12 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) @test variable_symbols(sc) == [:x, :y, :z] @test parameter_symbols(sc) == [:a, :b] @test independent_variable_symbols(sc) == [:t] +@test all_variable_symbols(sc) == [:x, :y, :z] +@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z] sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) +@test sort(all_symbols(sc)) == [:a, :b, :x, :y] # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) @@ -30,12 +33,16 @@ sc = SymbolCache() @test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t])) @test independent_variable_symbols(sc) == [] @test !is_time_dependent(sc) +@test all_variable_symbols(sc) == [] +@test all_symbols(sc) == [] sc = SymbolCache(nothing, nothing, :t) @test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b])) @test is_independent_variable(sc, :t) @test independent_variable_symbols(sc) == [:t] @test is_time_dependent(sc) +@test all_variable_symbols(sc) == [] +@test all_symbols(sc) == [:t] sc2 = copy(sc) @test sc.variables == sc2.variables