Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ProblemState #64

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
- {user: SciML, repo: RecursiveArrayTools.jl, group: SymbolicIndexingInterface}
- {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface}
- {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface}
- {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ symbolic_evaluate

```@docs
SymbolCache
ProblemState
```
28 changes: 26 additions & 2 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ Consider the following example:

```@example Usage
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters σ ρ β
@variables t x(t) y(t) z(t) w(t)
D = Differential(t)
@variables x(t) y(t) z(t) w(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
Expand Down Expand Up @@ -121,6 +121,30 @@ output, the following shorthand is used:
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
```

### Evaluating expressions

`getu` also generates functions for expressions if the object passed to it supports
[`observed`](@ref). For example:

```@example Usage
getu(prob, x + y + z)(prob)
```

To evaluate this function using values other than the ones contained in `prob`, we need
an object that supports [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref). SymbolicIndexingInterface provides the [`ProblemState`](@ref) type,
which has trivial implementations of the above functions. We can thus do:

```@example Usage
temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob))
getu(prob, x + y + z)(temp_state)
```

Note that providing all of the state vector, parameter object and time may not be
necessary if the function generated by `observed` does not access them. ModelingToolkit.jl
generates functions that access the parameters regardless of whether they are used in the
expression, and thus it needs to be provided to the `ProblemState`.

## Parameter Indexing: Getting and Setting Parameter Values

Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).
Expand Down
3 changes: 3 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ include("parameter_indexing.jl")
export state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export ProblemState
include("problem_state.jl")

export ParameterIndexingProxy
include("parameter_indexing_proxy.jl")

Expand Down
23 changes: 23 additions & 0 deletions src/problem_state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
struct ProblemState
function ProblemState(; u = nothing, p = nothing, t = nothing)

A struct which can be used as an argument to the function returned by [`getu`](@ref) or
[`setu`](@ref). It stores the state vector, parameter object and current time, and
forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
objects.
"""
struct ProblemState{U, P, T}
u::U
p::P
t::T
end

ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t)

Check warning on line 17 in src/problem_state.jl

View check run for this annotation

Codecov / codecov/patch

src/problem_state.jl#L17

Added line #L17 was not covered by tests

state_values(ps::ProblemState) = ps.u
parameter_values(ps::ProblemState) = ps.p
current_time(ps::ProblemState) = ps.t
set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx)
set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx)

Check warning on line 23 in src/problem_state.jl

View check run for this annotation

Codecov / codecov/patch

src/problem_state.jl#L19-L23

Added lines #L19 - L23 were not covered by tests
15 changes: 15 additions & 0 deletions test/problem_state_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using SymbolicIndexingInterface
using Test

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5)

for (i, sym) in enumerate(variable_symbols(sys))
@test getu(sys, sym)(prob) == prob.u[i]
end
for (i, sym) in enumerate(parameter_symbols(sys))
@test getp(sys, sym)(prob) == prob.p[i]
end
@test getu(sys, :t)(prob) == prob.t

@test getu(sys, :(x + a + t))(prob) == 1.6
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ end
@safetestset "Remake test" begin
@time include("remake_test.jl")
end
@safetestset "ProblemState test" begin
@time include("problem_state_test.jl")
end
Loading