Skip to content

Commit

Permalink
Merge pull request #18 from SciML/as/all-symbols
Browse files Browse the repository at this point in the history
feat: add `all_solvable_symbols` and `all_symbols`
  • Loading branch information
ChrisRackauckas authored Dec 27, 2023
2 parents 0e67d45 + ae50f9d commit 0ad8caa
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 5 deletions.
5 changes: 4 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
[deps]
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
Documenter = "0.27"
Expand Down
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

pages = [
"Home" => "index.md",
"API" => "api.md",
"Tutorial" => "tutorial.md",
"Usage" => "usage.md",
"API" => "api.md",
]
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ is_observed
observed
is_time_dependent
constant_structure
all_variable_symbols
all_symbols
solvedvariables
allvariables
parameter_values
getp
setp
Expand Down
21 changes: 19 additions & 2 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct ExampleSolution
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Union{Symbol,Nothing}
# mapping from observed variable to Expr to calculate its value
observed::Dict{Symbol,Expr}
u::Vector{Vector{Float64}}
p::Vector{Float64}
t::Vector{Float64}
Expand Down Expand Up @@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu
sys.independent_variable === nothing ? [] : [sys.independent_variable]
end

# this types accepts `Expr` for observed expressions involving state/parameter
# this type accepts `Expr` for observed expressions involving state/parameter/observed
# variables
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)

function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr)
if is_time_dependent(sys)
Expand All @@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution)
end

SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true

function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution)
return vcat(
collect(keys(sys.state_index)),
collect(keys(sys.observed)),
)
end

function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution)
return vcat(
all_solvable_symbols(sys),
collect(keys(sys.parameter_index)),
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
end
```

Note that the method definitions are all assuming `constant_structure(p) == true`.
Expand Down
147 changes: 147 additions & 0 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Using the SymbolicIndexingInterface

This tutorial will cover ways to use the interface for types that implement it.
Consider the following example:

```@example Usage
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots
@parameters σ ρ β
@variables t x(t) y(t) z(t) w(t)
D = Differential(t)
eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z]
@named sys = ODESystem(eqs)
sys = structural_simplify(sys)
```

The system has 4 state variables, 3 parameters and one observed variable:
```@example Usage
observed(sys)
```

Solving the system,
```@example Usage
u0 = [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]
p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3]
tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())
```

We can obtain the timeseries of any time-dependent variable using `getindex`
```@example Usage
sol[x]
```

This also works for arrays or tuples of variables, including observed quantities and
independent variables, for interpolating solutions, and plotting:
```@example Usage
sol[[x, y]]
```

```@example Usage
sol[(t, w)]
```

```@example Usage
sol(1.3, idxs=x)
```

```@example Usage
sol(1.3, idxs=[x, w])
```

```@example Usage
sol(1.3, idxs=[:y, :z])
```

```@example Usage
plot(sol, idxs=x)
```

If necessary, `Symbol`s can be used to refer to variables. This is only valid for
symbolic variables for which [`hasname`](@ref) returns `true`. The `Symbol` used must
match the one returned by [`getname`](@ref) for the variable.
```@example Usage
hasname(x)
```

```@example Usage
getname(x)
```

```@example Usage
sol[(:x, :w)]
```

Note how when indexing with an array, the returned type is a `Vector{Array{Float64}}`,
and when using a `Tuple`, the returned type is `Vector{Tuple{Float64, Float64}}`.
To obtain the value of all state variables, we can use the shorthand:
```@example Usage
sol[solvedvariables] # equivalent to sol[variable_symbols(sol)]
```

This does not include the observed variable `w`. To include observed variables in the
output, the following shorthand is used:
```@example Usage
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
```

Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).

```@example Usage
σ_getter = getp(sys, σ)
σ_getter(sol) # can also pass `prob`
```

Note that this also supports arrays/tuples of parameter symbols:

```@example Usage
σ_ρ_getter = getp(sys, (σ, ρ))
σ_ρ_getter(sol)
```

Now suppose the system has to be solved with a different value of the parameter `β`.

```@example Usage
β_setter = setp(sys, β)
β_setter(prob, 3)
```

The updated parameter values can be checked using [`parameter_values`](@ref).

```@example Usage
parameter_values(prob)
```

Solving the new system, note that the parameter getter functions still work on the new
solution object.

```@example Usage
sol2 = solve(prob, Tsit5())
σ_getter(sol)
```

```@example Usage
σ_ρ_getter(sol)
```

To set the entire parameter vector at once, [`parameter_values`](@ref) can be used
(note the usage of broadcasted assignment).

```@example Usage
parameter_values(prob) .= [29.0, 11.0, 2.5]
parameter_values(prob)
```
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ include("trait.jl")

export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
observed, is_time_dependent, constant_structure, symbolic_container
observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols,
all_symbols, solvedvariables, allvariables
include("interface.jl")

export SymbolCache
Expand Down
47 changes: 47 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ variable_index(sys, sym, i) = variable_index(symbolic_container(sys), sym, i)
Return a vector of the symbolic variables being solved for in the system `sys`. If
`constant_structure(sys) == false` this accepts an additional parameter indicating
the current time index. The returned vector should not be mutated.
For types that implement `Base.getindex` with symbolic indices using this interface,
The shorthand `sys[solvedvariables]` can be used as shorthand for
`sys[variable_symbols(sys)]`. See: [`solvedvariables`](@ref).
"""
variable_symbols(sys) = variable_symbols(symbolic_container(sys))
variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i)
Expand Down Expand Up @@ -110,3 +114,46 @@ Check if `sys` has a constant structure. Constant structure systems do not chang
number of variables or parameters over time.
"""
constant_structure(sys) = constant_structure(symbolic_container(sys))

"""
all_variable_symbols(sys)
Return a vector of variable symbols in the system, including observed quantities.
For types that implement `Base.getindex` with symbolic indices using this interface,
The shorthand `sys[allvariables]` can be used as shorthand for
`sys[all_variable_symbols(sys)]`. See: [`allvariables`](@ref).
"""
all_variable_symbols(sys) = all_variable_symbols(symbolic_container(sys))

"""
all_symbols(sys)
Return an array of all symbols in the system. This includes parameters and independent
variables.
"""
all_symbols(sys) = all_symbols(symbolic_container(sys))

struct SolvedVariables end

"""
const solvedvariables = SolvedVariables()
This singleton is used as a shortcut to allow indexing all solution variables
(excluding observed quantities). It has a [`symbolic_type`](@ref) of
[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref).
"""
const solvedvariables = SolvedVariables()
symbolic_type(::Type{SolvedVariables}) = ScalarSymbolic()

struct AllVariables end

"""
const allvariables = AllVariables()
This singleton is used as a shortcut to allow indexing all solution variables
(including observed quantities). It has a [`symbolic_type`](@ref) of
[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref).
"""
const allvariables = AllVariables()
symbolic_type(::Type{AllVariables}) = ScalarSymbolic()
2 changes: 2 additions & 0 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ function is_time_dependent(sc::SymbolCache)
end
end
constant_structure(::SymbolCache) = true
all_variable_symbols(sc::SymbolCache) = variable_symbols(sc)
all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))

function Base.copy(sc::SymbolCache)
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
Expand Down
10 changes: 10 additions & 0 deletions test/example_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not
end
SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t)
SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static
SymbolicIndexingInterface.all_variable_symbols(sys::SystemMockup) = sys.vars
function SymbolicIndexingInterface.all_symbols(sys::SystemMockup)
vcat(sys.vars, sys.params, independent_variable_symbols(sys))
end

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)

Expand All @@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
@test all_variable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)

Expand All @@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == []
@test all_variable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z]

sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
@test !constant_structure(sys)
Expand All @@ -115,3 +123,5 @@ end
@test variable_symbols(sys, 3) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
@test all_variable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
2 changes: 2 additions & 0 deletions test/fallback_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ all_syms = [:x, :y, :z, :a, :b, :t]
@test variable_symbols(sys) == variable_symbols(sc)
@test parameter_symbols(sys) == parameter_symbols(sc)
@test independent_variable_symbols(sys) == independent_variable_symbols(sc)
@test all_variable_symbols(sys) == variable_symbols(sc)
@test all_symbols(sys) == all_symbols(sc)
7 changes: 7 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
@test variable_symbols(sc) == [:x, :y, :z]
@test parameter_symbols(sc) == [:a, :b]
@test independent_variable_symbols(sc) == [:t]
@test all_variable_symbols(sc) == [:x, :y, :z]
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
# make sure the constructor works
@test_nowarn SymbolCache([:x, :y])

Expand All @@ -30,12 +33,16 @@ sc = SymbolCache()
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t]))
@test independent_variable_symbols(sc) == []
@test !is_time_dependent(sc)
@test all_variable_symbols(sc) == []
@test all_symbols(sc) == []

sc = SymbolCache(nothing, nothing, :t)
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
@test is_independent_variable(sc, :t)
@test independent_variable_symbols(sc) == [:t]
@test is_time_dependent(sc)
@test all_variable_symbols(sc) == []
@test all_symbols(sc) == [:t]

sc2 = copy(sc)
@test sc.variables == sc2.variables
Expand Down

0 comments on commit 0ad8caa

Please sign in to comment.