Skip to content

Commit

Permalink
feat: add optional history function to ProblemState
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 3, 2024
1 parent 2a247d3 commit e440ff7
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/problem_state.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
"""
struct ProblemState
function ProblemState(; u = nothing, p = nothing, t = nothing)
function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing)
A value provider struct which can be used as an argument to the function returned by
[`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and
current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
objects.
A history function may be provided using the `h` keyword, which will be returned with
[`get_history_function`](@ref).
"""
struct ProblemState{U, P, T}
struct ProblemState{U, P, T, H}
u::U
p::P
t::T
h::H
end

ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t)
function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing)
ProblemState(u, p, t, h)
end

state_values(ps::ProblemState) = ps.u
parameter_values(ps::ProblemState) = ps.p
current_time(ps::ProblemState) = ps.t
set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx)
set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx)
get_history_function(ps::ProblemState) = ps.h
5 changes: 4 additions & 1 deletion test/problem_state_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ using SymbolicIndexingInterface
using Test

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5)
prob = ProblemState(;
u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5, h = Returns(ones(3)))

for (i, sym) in enumerate(variable_symbols(sys))
@test getsym(sys, sym)(prob) == prob.u[i]
Expand All @@ -13,3 +14,5 @@ end
@test getsym(sys, :t)(prob) == prob.t

@test getsym(sys, :(x + a + t))(prob) == 1.6

@test get_history_function(prob) !== nothing
3 changes: 3 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ getter = getsym(sys, :(x + y))
@test getter(fs) [3.0i + 2(ts[i] - 0.1) for i in 1:11]
@test getter(fs, 1) 2.8

pstate = ProblemState(; u = u0, p = p, t = ts[1], h = t -> t .* ones(length(u0)))
@test getter(pstate) 2.8

struct TupleObservedWrapper{S}
sys::S
end
Expand Down

0 comments on commit e440ff7

Please sign in to comment.