From e440ff768c85a6db8fd3c5438c1206ed0cae7a3a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 15:55:39 +0530 Subject: [PATCH] feat: add optional history function to `ProblemState` --- src/problem_state.jl | 13 ++++++++++--- test/problem_state_test.jl | 5 ++++- test/state_indexing_test.jl | 3 +++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/problem_state.jl b/src/problem_state.jl index 831deae..2d07a0f 100644 --- a/src/problem_state.jl +++ b/src/problem_state.jl @@ -1,23 +1,30 @@ """ struct ProblemState - function ProblemState(; u = nothing, p = nothing, t = nothing) + function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing) A value provider struct which can be used as an argument to the function returned by [`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref), [`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained objects. + +A history function may be provided using the `h` keyword, which will be returned with +[`get_history_function`](@ref). """ -struct ProblemState{U, P, T} +struct ProblemState{U, P, T, H} u::U p::P t::T + h::H end -ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t) +function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing) + ProblemState(u, p, t, h) +end state_values(ps::ProblemState) = ps.u parameter_values(ps::ProblemState) = ps.p current_time(ps::ProblemState) = ps.t set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx) set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx) +get_history_function(ps::ProblemState) = ps.h diff --git a/test/problem_state_test.jl b/test/problem_state_test.jl index 906ee73..a9245dd 100644 --- a/test/problem_state_test.jl +++ b/test/problem_state_test.jl @@ -2,7 +2,8 @@ using SymbolicIndexingInterface using Test sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) -prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5) +prob = ProblemState(; + u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5, h = Returns(ones(3))) for (i, sym) in enumerate(variable_symbols(sys)) @test getsym(sys, sym)(prob) == prob.u[i] @@ -13,3 +14,5 @@ end @test getsym(sys, :t)(prob) == prob.t @test getsym(sys, :(x + a + t))(prob) == 1.6 + +@test get_history_function(prob) !== nothing diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index ade57bb..4537276 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -347,6 +347,9 @@ getter = getsym(sys, :(x + y)) @test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] @test getter(fs, 1) ≈ 2.8 +pstate = ProblemState(; u = u0, p = p, t = ts[1], h = t -> t .* ones(length(u0))) +@test getter(pstate) ≈ 2.8 + struct TupleObservedWrapper{S} sys::S end