From 5ab2679935e39d63912b0637841c06f71cf7f4a1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Jan 2024 12:56:16 +0530 Subject: [PATCH] feat: add support for parameter and indepvar symbols in getu --- src/state_indexing.jl | 32 ++++++++++++---- test/state_indexing_test.jl | 76 ++++++++++++++++++++++++++++++++++--- 2 files changed, 95 insertions(+), 13 deletions(-) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 42578c63..008f79b9 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -94,9 +94,10 @@ Return a function that takes an integrator, problem or solution of `sys`, and re the value of the symbolic `sym`. If `sym` is not an observed quantity, the returned function can also directly be called with an array of values representing the state vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic -expression involving symbolic quantities in the system `sys`, or an array/tuple of the -aforementioned. If the returned function is called with a timeseries object, it can also -be given a second argument representing the index at which to find the value of `sym`. +expression involving symbolic quantities in the system `sys`, a parameter symbol, or the +independent variable symbol, or an array/tuple of the aforementioned. If the returned +function is called with a timeseries object, it can also be given a second argument +representing the index at which to find the value of `sym`. At minimum, this requires that the integrator, problem or solution implement [`state_values`](@ref). To support symbolic expressions, the integrator or problem @@ -131,6 +132,19 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) return getu(sys, idx) + elseif is_parameter(sys, sym) + return let fn = getp(sys, sym) + getter(prob, args...) = fn(prob) + getter + end + elseif is_independent_variable(sys, sym) + _getter(::IsTimeseriesTrait, prob) = current_time(prob) + _getter(::Timeseries, prob, i) = current_time(prob, i) + return let _getter = _getter + getter(prob) = _getter(is_timeseries(prob), prob) + getter(prob, i) = _getter(is_timeseries(prob), prob, i) + getter + end elseif is_observed(sys, sym) fn = observed(sys, sym) if is_time_dependent(sys) @@ -226,11 +240,15 @@ function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) end function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) - is_variable(sys, sym) || error("Invalid symbol $sym for `setu`") - idx = variable_index(sys, sym) - return function setter!(prob, val) - set_state!(prob, val, idx) + if is_variable(sys, sym) + idx = variable_index(sys, sym) + return function setter!(prob, val) + set_state!(prob, val, idx) + end + elseif is_parameter(sys, sym) + return setp(sys, sym) end + error("Invalid symbol $sym for `setu`") end for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index f305f6ea..6d8ee6a0 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -1,16 +1,22 @@ using SymbolicIndexingInterface -struct FakeIntegrator{S, U} +struct FakeIntegrator{S, U, P, T} sys::S u::U + p::P + t::T end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p +SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t -sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) u = [1.0, 2.0, 3.0] -fi = FakeIntegrator(sys, copy(u)) +p = [11.0, 12.0, 13.0] +t = 0.5 +fi = FakeIntegrator(sys, copy(u), copy(p), t) # checking inference for non-concretely typed arrays will always fail for (sym, val, newval, check_inference) in [ (:x, u[1], 4.0, true) @@ -61,19 +67,60 @@ for (sym, val, newval, check_inference) in [ @test get(u) == val end +for (sym, oldval, newval, check_inference) in [ + (:a, p[1], 4.0, true) + (:b, p[2], 5.0, true) + (:c, p[3], 6.0, true) + ([:a, :b], p[1:2], [4.0, 5.0], true) + ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) + ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) + ((:y, :b), (u[2], p[2]), (5.0, 6.0), true) +] + get = getu(fi, sym) + set! = setu(fi, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == oldval + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval +end -struct FakeSolution{S, U} +for (sym, val, check_inference) in [ + (:t, t, true), + ([:x, :a, :t], [u[1], p[1], t], false), + ((:x, :a, :t), (u[1], p[1], t), true), +] + get = getu(fi, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == val +end + +struct FakeSolution{S, U, P, T} sys::S u::U + p::P + t::T end SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u +SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p +SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t -sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -sol = FakeSolution(sys, u) +t = [1.5, 2.0] +sol = FakeSolution(sys, u, p, t) xvals = getindex.(sol.u, 1) yvals = getindex.(sol.u, 2) @@ -97,6 +144,11 @@ for (sym, ans, check_inference) in [ ((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true) ((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true) ((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true) + ([:x, :a], vcat.(xvals, p[1]), false) + ((:y, :b), tuple.(yvals, p[2]), true) + (:t, t, true) + ([:x, :a, :t], vcat.(xvals, p[1], t), false) + ((:x, :a, :t), tuple.(xvals, p[1], t), true) ] get = getu(sys, sym) if check_inference @@ -110,3 +162,15 @@ for (sym, ans, check_inference) in [ @test get(sol, i) == ans[i] end end + +for (sym, val) in [ + (:a, p[1]) + (:b, p[2]) + (:c, p[3]) + ([:a, :b], p[1:2]) + ((:c, :b), (p[3], p[2])) +] + get = getu(fi, sym) + @inferred get(fi) + @test get(fi) == val +end