Skip to content

Commit

Permalink
Merge pull request #64 from SciML/as/problem-state
Browse files Browse the repository at this point in the history
feat: add `ProblemState`
  • Loading branch information
ChrisRackauckas authored Mar 28, 2024
2 parents 494548d + c1b60a6 commit da7aa70
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
- {user: SciML, repo: RecursiveArrayTools.jl, group: SymbolicIndexingInterface}
- {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface}
- {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface}
- {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ symbolic_evaluate

```@docs
SymbolCache
ProblemState
```
28 changes: 26 additions & 2 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ Consider the following example:

```@example Usage
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots
using ModelingToolkit: t_nounits as t, D_nounits as D
@parameters σ ρ β
@variables t x(t) y(t) z(t) w(t)
D = Differential(t)
@variables x(t) y(t) z(t) w(t)
eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
Expand Down Expand Up @@ -121,6 +121,30 @@ output, the following shorthand is used:
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
```

### Evaluating expressions

`getu` also generates functions for expressions if the object passed to it supports
[`observed`](@ref). For example:

```@example Usage
getu(prob, x + y + z)(prob)
```

To evaluate this function using values other than the ones contained in `prob`, we need
an object that supports [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref). SymbolicIndexingInterface provides the [`ProblemState`](@ref) type,
which has trivial implementations of the above functions. We can thus do:

```@example Usage
temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob))
getu(prob, x + y + z)(temp_state)
```

Note that providing all of the state vector, parameter object and time may not be
necessary if the function generated by `observed` does not access them. ModelingToolkit.jl
generates functions that access the parameters regardless of whether they are used in the
expression, and thus it needs to be provided to the `ProblemState`.

## Parameter Indexing: Getting and Setting Parameter Values

Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).
Expand Down
3 changes: 3 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ include("parameter_indexing.jl")
export state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export ProblemState
include("problem_state.jl")

export ParameterIndexingProxy
include("parameter_indexing_proxy.jl")

Expand Down
23 changes: 23 additions & 0 deletions src/problem_state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
struct ProblemState
function ProblemState(; u = nothing, p = nothing, t = nothing)
A struct which can be used as an argument to the function returned by [`getu`](@ref) or
[`setu`](@ref). It stores the state vector, parameter object and current time, and
forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
objects.
"""
struct ProblemState{U, P, T}
u::U
p::P
t::T
end

ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t)

state_values(ps::ProblemState) = ps.u
parameter_values(ps::ProblemState) = ps.p
current_time(ps::ProblemState) = ps.t
set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx)
set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx)
15 changes: 15 additions & 0 deletions test/problem_state_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using SymbolicIndexingInterface
using Test

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5)

for (i, sym) in enumerate(variable_symbols(sys))
@test getu(sys, sym)(prob) == prob.u[i]
end
for (i, sym) in enumerate(parameter_symbols(sys))
@test getp(sys, sym)(prob) == prob.p[i]
end
@test getu(sys, :t)(prob) == prob.t

@test getu(sys, :(x + a + t))(prob) == 1.6
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ end
@safetestset "Remake test" begin
@time include("remake_test.jl")
end
@safetestset "ProblemState test" begin
@time include("problem_state_test.jl")
end

0 comments on commit da7aa70

Please sign in to comment.