Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add all_solvable_symbols and all_symbols #18

Merged
merged 6 commits into from
Dec 27, 2023
4 changes: 3 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
Documenter = "0.27"
Expand Down
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

pages = [
"Home" => "index.md",
"API" => "api.md",
"Tutorial" => "tutorial.md",
"Usage" => "usage.md",
"API" => "api.md",
]
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ is_observed
observed
is_time_dependent
constant_structure
all_variable_symbols
all_symbols
solvedvariables
allvariables
parameter_values
getp
setp
Expand Down
21 changes: 19 additions & 2 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct ExampleSolution
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Union{Symbol,Nothing}
# mapping from observed variable to Expr to calculate its value
observed::Dict{Symbol,Expr}
u::Vector{Vector{Float64}}
p::Vector{Float64}
t::Vector{Float64}
Expand Down Expand Up @@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu
sys.independent_variable === nothing ? [] : [sys.independent_variable]
end

# this types accepts `Expr` for observed expressions involving state/parameter
# this type accepts `Expr` for observed expressions involving state/parameter/observed
# variables
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)

function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr)
if is_time_dependent(sys)
Expand All @@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution)
end

SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true

function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution)
return vcat(
collect(keys(sys.state_index)),
collect(keys(sys.observed)),
)
end

function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution)
return vcat(
all_solvable_symbols(sys),
collect(keys(sys.parameter_index)),
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
end
```

Note that the method definitions are all assuming `constant_structure(p) == true`.
Expand Down
131 changes: 131 additions & 0 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Using the SymbolicIndexingInterface

This tutorial will cover ways to use the interface for types that implement it.
Consider the following example:

```@example Usage
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface

@parameters σ ρ β
@variables t x(t) y(t) z(t) w(t)
D = Differential(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z]

@named sys = ODESystem(eqs)
sys = structural_simplify(sys)
```

The system has 4 state variables, 3 parameters and one observed variable:
```@example Usage
observed(sys)
```

Solving the system,
```@example Usage
u0 = [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]

p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())
```

We can obtain the timeseries of any time-dependent variable using `getindex`
```@example Usage
sol[x]
```

This also works for arrays or tuples of variables, including observed quantities and
independent variables:
```@example Usage
sol[[x, y]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, is this using the optimized call now, i.e. a single call to the observed function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's broadcasted over the array of symbols, so each one is computed individually. A single call to observed behaves differently:

julia> SymbolicIndexingInterface.observed(sol, [x, y])(sol.u, sol.prob.p, sol.t)
2-element Vector{Vector{Float64}}:
 [1.6563910957343146, 0.13029732534611174, 0.0008464888325952026, 1.023645017393648]
 [1.9960454103704428, 0.0014132521265988688, 9.986094350108851e-8, 1.0002823510089436]
julia> SymbolicIndexingInterface.observed(sol, (x, y))(sol.u, sol.prob.p, sol.t)
(x(t), y(t))

For state variables (whose values are in sol.u) I'd imagine broadcasting is faster than the observed method? It avoids a generated function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For state variables it's faster, then for observed it's faster to clump them up. Optimizing this a bit is probably a good idea in the future.

```

```@example Usage
sol[(t, w)]
```

If necessary, `Symbol`s can be used to refer to variables. This is only valid for
symbolic variables for which [`hasname`](@ref) returns `true`. The `Symbol` used must
match the one returned by [`getname`](@ref) for the variable.
```@example Usage
hasname(x)
```

```@example Usage
getname(x)
```

```@example Usage
sol[(:x, :w)]
```

Note how when indexing with an array, the returned type is a `Vector{Array{Float64}}`,
and when using a `Tuple`, the returned type is `Vector{Tuple{Float64, Float64}}`.
To obtain the value of all state variables, we can use the shorthand:
```@example Usage
sol[solvedvariables] # equivalent to sol[variable_symbols(sol)]
```

This does not include the observed variable `w`. To include observed variables in the
output, the following shorthand is used:
```@example Usage
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
```

Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).

```@example Usage
σ_getter = getp(sys, σ)
σ_getter(sol) # can also pass `prob`
```

Note that this also supports arrays/tuples of parameter symbols:

```@example Usage
σ_ρ_getter = getp(sys, (σ, ρ))
σ_ρ_getter(sol)
```

Now suppose the system has to be solved with a different value of the parameter `β`.

```@example Usage
β_setter = setp(sys, β)
β_setter(prob, 3)
```

The updated parameter values can be checked using [`parameter_values`](@ref).

```@example Usage
parameter_values(prob)
```

Solving the new system, note that the parameter getter functions still work on the new
solution object.

```@example Usage
sol2 = solve(prob, Tsit5())
σ_getter(sol)
```

```@example Usage
σ_ρ_getter(sol)
```

To set the entire parameter vector at once, [`parameter_values`](@ref) can be used
(note the usage of broadcasted assignment).

```@example Usage
parameter_values(prob) .= [29.0, 11.0, 2.5]
parameter_values(prob)
```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also mention symbolic interpolation indexing, i.e. sol(t; idxs=x), plot indexing plot(sol, idxs=(x,y))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ include("trait.jl")

export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
observed, is_time_dependent, constant_structure, symbolic_container
observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols,
all_symbols, solvedvariables, allvariables
include("interface.jl")

export SymbolCache
Expand Down
47 changes: 47 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
Return a vector of the symbolic variables being solved for in the system `sys`. If
`constant_structure(sys) == false` this accepts an additional parameter indicating
the current time index. The returned vector should not be mutated.

For types that implement `Base.getindex` with symbolic indices using this interface,
The shorthand `sys[solvedvariables]` can be used as shorthand for
`sys[variable_symbols(sys)]`. See: [`solvedvariables`](@ref).
"""
variable_symbols(sys) = variable_symbols(symbolic_container(sys))
variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i)
Expand Down Expand Up @@ -110,3 +114,46 @@
number of variables or parameters over time.
"""
constant_structure(sys) = constant_structure(symbolic_container(sys))

"""
all_variable_symbols(sys)

Return a vector of variable symbols in the system, including observed quantities.

For types that implement `Base.getindex` with symbolic indices using this interface,
The shorthand `sys[allvariables]` can be used as shorthand for
`sys[all_variable_symbols(sys)]`. See: [`allvariables`](@ref).
"""
all_variable_symbols(sys) = all_variable_symbols(symbolic_container(sys))

Check warning on line 127 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L127

Added line #L127 was not covered by tests

"""
all_symbols(sys)

Return an array of all symbols in the system. This includes parameters and independent
variables.
"""
all_symbols(sys) = all_symbols(symbolic_container(sys))

Check warning on line 135 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L135

Added line #L135 was not covered by tests

struct SolvedVariables end

"""
const solvedvariables = SolvedVariables()

This singleton is used as a shortcut to allow indexing all solution variables
(excluding observed quantities). It has a [`symbolic_type`](@ref) of
[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref).
"""
const solvedvariables = SolvedVariables()
symbolic_type(::Type{SolvedVariables}) = ScalarSymbolic()

struct AllVariables end

"""
const allvariables = AllVariables()

This singleton is used as a shortcut to allow indexing all solution variables
(including observed quantities). It has a [`symbolic_type`](@ref) of
[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref).
"""
const allvariables = AllVariables()
symbolic_type(::Type{AllVariables}) = ScalarSymbolic()
2 changes: 2 additions & 0 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
end
end
constant_structure(::SymbolCache) = true
all_variable_symbols(sc::SymbolCache) = variable_symbols(sc)
all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))

Check warning on line 75 in src/symbol_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/symbol_cache.jl#L74-L75

Added lines #L74 - L75 were not covered by tests

function Base.copy(sc::SymbolCache)
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
Expand Down
10 changes: 10 additions & 0 deletions test/example_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not
end
SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t)
SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static
SymbolicIndexingInterface.all_variable_symbols(sys::SystemMockup) = sys.vars
function SymbolicIndexingInterface.all_symbols(sys::SystemMockup)
vcat(sys.vars, sys.params, independent_variable_symbols(sys))
end

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)

Expand All @@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
@test all_variable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]

sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)

Expand All @@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
@test variable_symbols(sys) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == []
@test all_variable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z]

sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
@test !constant_structure(sys)
Expand All @@ -115,3 +123,5 @@ end
@test variable_symbols(sys, 3) == [:x, :y, :z]
@test parameter_symbols(sys) == [:a, :b, :c]
@test independent_variable_symbols(sys) == [:t]
@test all_variable_symbols(sys) == [:x, :y, :z]
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
2 changes: 2 additions & 0 deletions test/fallback_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ all_syms = [:x, :y, :z, :a, :b, :t]
@test variable_symbols(sys) == variable_symbols(sc)
@test parameter_symbols(sys) == parameter_symbols(sc)
@test independent_variable_symbols(sys) == independent_variable_symbols(sc)
@test all_variable_symbols(sys) == variable_symbols(sc)
@test all_symbols(sys) == all_symbols(sc)
7 changes: 7 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
@test variable_symbols(sc) == [:x, :y, :z]
@test parameter_symbols(sc) == [:a, :b]
@test independent_variable_symbols(sc) == [:t]
@test all_variable_symbols(sc) == [:x, :y, :z]
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
# make sure the constructor works
@test_nowarn SymbolCache([:x, :y])

Expand All @@ -30,12 +33,16 @@ sc = SymbolCache()
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t]))
@test independent_variable_symbols(sc) == []
@test !is_time_dependent(sc)
@test all_variable_symbols(sc) == []
@test all_symbols(sc) == []

sc = SymbolCache(nothing, nothing, :t)
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
@test is_independent_variable(sc, :t)
@test independent_variable_symbols(sc) == [:t]
@test is_time_dependent(sc)
@test all_variable_symbols(sc) == []
@test all_symbols(sc) == [:t]

sc2 = copy(sc)
@test sc.variables == sc2.variables
Expand Down
Loading