From 599575d40565fff8f853c2549e8edcc1d145939a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 3 Feb 2024 20:40:29 +0530 Subject: [PATCH] refactor: make getu, getp generic of the parameter container --- docs/pages.jl | 4 +- src/SymbolicIndexingInterface.jl | 11 +-- src/parameter_indexing.jl | 27 ++++--- src/state_indexing.jl | 8 +- src/symbol_cache.jl | 2 +- test/parameter_indexing_test.jl | 2 +- test/state_indexing_test.jl | 129 ++++++++++++++++++------------- 7 files changed, 106 insertions(+), 77 deletions(-) diff --git a/docs/pages.jl b/docs/pages.jl index 26e31ae..b27f2c5 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -5,8 +5,8 @@ pages = [ "Tutorials" => [ "Using the SciML Symbolic Indexing Interface" => "usage.md", "Simple Demonstration of a Symbolic System Structure" => "simple_sii_sys.md", - "Implementing the Complete Symbolic Indexing Interface" => "complete_sii.md", + "Implementing the Complete Symbolic Indexing Interface" => "complete_sii.md" ], "Defining Solution Wrapper Fallbacks" => "solution_wrappers.md", - "API" => "api.md", + "API" => "api.md" ] diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7077e92..4edcb21 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -4,10 +4,11 @@ export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getna include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, - parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, - observed, is_time_dependent, constant_structure, symbolic_container, - all_variable_symbols, - all_symbols, solvedvariables, allvariables + parameter_symbols, is_independent_variable, independent_variable_symbols, + is_observed, + observed, is_time_dependent, constant_structure, symbolic_container, + all_variable_symbols, + all_symbols, solvedvariables, allvariables include("interface.jl") export SymbolCache @@ -17,7 +18,7 @@ export parameter_values, set_parameter!, getp, setp include("parameter_indexing.jl") export Timeseries, - NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu + NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu include("state_indexing.jl") export ParameterIndexingProxy diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index e28720b..eaa5f2c 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -1,13 +1,19 @@ """ parameter_values(p) + parameter_values(p, i) -Return an indexable collection containing the value of each parameter in `p`. +Return an indexable collection containing the value of each parameter in `p`. The two- +argument version of this function returns the parameter value at index `i`. The +two-argument version of this function will default to returning +`parameter_values(p)[i]`. If this function is called with an `AbstractArray`, it will return the same array. """ function parameter_values end parameter_values(arr::AbstractArray) = arr +parameter_values(arr::AbstractArray, i) = arr[i] +parameter_values(prob, i) = parameter_values(parameter_values(prob), i) """ set_parameter!(sys, val, idx) @@ -19,16 +25,19 @@ defined to enable the proper functioning of [`setp`](@ref). See: [`parameter_values`](@ref) """ -function set_parameter!(sys, val, idx) - parameter_values(sys)[idx] = val +function set_parameter! end + +function set_parameter!(sys::AbstractArray, val, idx) + sys[idx] = val end +set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) """ getp(sys, p) Return a function that takes an array representing the parameter vector or an integrator or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a -direct numerical index or a symbolic value, or an array/tuple of the aforementioned. +direct index or a symbolic value, or an array/tuple of the aforementioned. Requires that the integrator or solution implement [`parameter_values`](@ref). This function typically does not need to be implemented, and has a default implementation relying on @@ -42,21 +51,21 @@ end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return function getter(sol) - return parameter_values(sol)[p] + return parameter_values(sol, p) end end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return function getter(sol) - return parameter_values(sol)[idx] + return parameter_values(sol, idx) end end for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) getters = getp.((sys,), p) @@ -76,7 +85,7 @@ end Return a function that takes an array representing the parameter vector or an integrator or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p` -can be a direct numerical index or a symbolic value. +can be a direct index or a symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the returned collection be a mutable reference to the parameter vector in the integrator. In @@ -106,7 +115,7 @@ end for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) setters = setp.((sys,), p) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index f2cd7f2..5ebc82f 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -116,7 +116,7 @@ end function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) _getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,)) _getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym) - _getter(::NotTimeseries, prob) = state_values(prob)[sym] + _getter(::NotTimeseries, prob) = state_values(prob, sym) return let _getter = _getter function getter(prob) return _getter(is_timeseries(prob), prob) @@ -186,7 +186,7 @@ end for (t1, t2) in [ (ScalarSymbolic, Any), (ArraySymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) num_observed = count(x -> is_observed(sys, x), sym) @@ -266,7 +266,7 @@ end Return a function that takes an array representing the state vector or an integrator or problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym` -can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. +can be a direct index, a symbolic state, or an array/tuple of the aforementioned. Requires that the integrator implement [`state_values`](@ref) and the returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to @@ -301,7 +301,7 @@ end for (t1, t2) in [ (ScalarSymbolic, Any), (ArraySymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) setters = setu.((sys,), sym) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index a1f566e..5a2a81c 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -14,7 +14,7 @@ array containing a single variable if the system has only one independent variab struct SymbolCache{ V <: Union{Nothing, AbstractVector}, P <: Union{Nothing, AbstractVector}, - I, + I } variables::V parameters::P diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 42d98f9..a2646f8 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -26,7 +26,7 @@ for (sym, oldval, newval, check_inference) in [ ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) ] get = getp(sys, sym) set! = setp(sys, sym) diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index e47a460..980d3d8 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -19,23 +19,33 @@ t = 0.5 fi = FakeIntegrator(sys, copy(u), copy(p), t) # checking inference for non-concretely typed arrays will always fail for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) - (:y, u[2], 4.0, true) - (:z, u[3], 4.0, true) - (1, u[1], 4.0, true) - ([:x, :y], u[1:2], 4ones(2), true) - ([1, 2], u[1:2], 4ones(2), true) - ((:z, :y), (u[3], u[2]), (4.0, 5.0), true) - ((3, 2), (u[3], u[2]), (4.0, 5.0), true) - ([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false) - ([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false) - ([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false) - ([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false) - ([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false) - ([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false) - ((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true) - ((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) - ((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) - ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)] + (:y, u[2], 4.0, true) + (:z, u[3], 4.0, true) + (1, u[1], 4.0, true) + ([:x, :y], u[1:2], 4ones(2), true) + ([1, 2], u[1:2], 4ones(2), true) + ((:z, :y), (u[3], u[2]), (4.0, 5.0), true) + ((3, 2), (u[3], u[2]), (4.0, 5.0), true) + ([:x, [:y, :z]], [u[1], u[2:3]], + [4.0, [5.0, 6.0]], false) + ([:x, 2:3], [u[1], u[2:3]], + [4.0, [5.0, 6.0]], false) + ([:x, (:y, :z)], [u[1], (u[2], u[3])], + [4.0, (5.0, 6.0)], false) + ([:x, Tuple(2:3)], [u[1], (u[2], u[3])], + [4.0, (5.0, 6.0)], false) + ([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], + [4.0, [5.0], (6.0,)], false) + ([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], + [4.0, [5.0], (6.0,)], false) + ((:x, [:y, :z]), (u[1], u[2:3]), + (4.0, [5.0, 6.0]), true) + ((:x, (:y, :z)), (u[1], (u[2], u[3])), + (4.0, (5.0, 6.0)), true) + ((1, (:y, :z)), (u[1], (u[2], u[3])), + (4.0, (5.0, 6.0)), true) + ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), + (4.0, [5.0], (6.0,)), true)] get = getu(sys, sym) set! = setu(sys, sym) if check_inference @@ -66,12 +76,12 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) end for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) - (:b, p[2], 5.0, true) - (:c, p[3], 6.0, true) - ([:a, :b], p[1:2], [4.0, 5.0], true) - ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) - ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) - ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] + (:b, p[2], 5.0, true) + (:c, p[3], 6.0, true) + ([:a, :b], p[1:2], [4.0, 5.0], true) + ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) + ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) + ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] get = getu(fi, sym) set! = setu(fi, sym) if check_inference @@ -91,7 +101,7 @@ end for (sym, val, check_inference) in [ (:t, t, true), ([:x, :a, :t], [u[1], p[1], t], false), - ((:x, :a, :t), (u[1], p[1], t), true), + ((:x, :a, :t), (u[1], p[1], t), true) ] get = getu(fi, sym) if check_inference @@ -123,33 +133,42 @@ yvals = getindex.(sol.u, 2) zvals = getindex.(sol.u, 3) for (sym, ans, check_inference) in [(:x, xvals, true) - (:y, yvals, true) - (:z, zvals, true) - (1, xvals, true) - ([:x, :y], vcat.(xvals, yvals), true) - (1:2, vcat.(xvals, yvals), true) - ([:x, 2], vcat.(xvals, yvals), false) - ((:z, :y), tuple.(zvals, yvals), true) - ((3, 2), tuple.(zvals, yvals), true) - ([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false) - ([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) - ([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) - ([:x, [:y, :z], (:x, :z)], - vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), - false) - ([:x, [:y, 3], (1, :z)], - vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), - false) - ((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true) - ((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true) - ((:x, [:y, :z], (:z, :y)), - tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), - true) - ([:x, :a], vcat.(xvals, p[1]), false) - ((:y, :b), tuple.(yvals, p[2]), true) - (:t, t, true) - ([:x, :a, :t], vcat.(xvals, p[1], t), false) - ((:x, :a, :t), tuple.(xvals, p[1], t), true)] + (:y, yvals, true) + (:z, zvals, true) + (1, xvals, true) + ([:x, :y], vcat.(xvals, yvals), true) + (1:2, vcat.(xvals, yvals), true) + ([:x, 2], vcat.(xvals, yvals), false) + ((:z, :y), tuple.(zvals, yvals), true) + ((3, 2), tuple.(zvals, yvals), true) + ([:x, [:y, :z]], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), + false) + ([:x, (:y, :z)], + vcat.(xvals, tuple.(yvals, zvals)), false) + ([1, (:y, :z)], + vcat.(xvals, tuple.(yvals, zvals)), false) + ([:x, [:y, :z], (:x, :z)], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], + tuple.(xvals, zvals)), + false) + ([:x, [:y, 3], (1, :z)], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], + tuple.(xvals, zvals)), + false) + ((:x, [:y, :z]), + tuple.(xvals, vcat.(yvals, zvals)), true) + ((:x, (:y, :z)), + tuple.(xvals, tuple.(yvals, zvals)), true) + ((:x, [:y, :z], (:z, :y)), + tuple.(xvals, vcat.(yvals, zvals), + tuple.(zvals, yvals)), + true) + ([:x, :a], vcat.(xvals, p[1]), false) + ((:y, :b), tuple.(yvals, p[2]), true) + (:t, t, true) + ([:x, :a, :t], vcat.(xvals, p[1], t), false) + ((:x, :a, :t), tuple.(xvals, p[1], t), true)] get = getu(sys, sym) if check_inference @inferred get(sol) @@ -164,10 +183,10 @@ for (sym, ans, check_inference) in [(:x, xvals, true) end for (sym, val) in [(:a, p[1]) - (:b, p[2]) - (:c, p[3]) - ([:a, :b], p[1:2]) - ((:c, :b), (p[3], p[2]))] + (:b, p[2]) + (:c, p[3]) + ([:a, :b], p[1:2]) + ((:c, :b), (p[3], p[2]))] get = getu(fi, sym) @inferred get(fi) @test get(fi) == val