From 20719c5d26404e4325f43b32d758d5e2039f8304 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 3 Feb 2024 20:40:29 +0530 Subject: [PATCH] refactor: make getu, getp generic of the parameter container --- src/parameter_indexing.jl | 16 +++++++++++----- src/state_indexing.jl | 4 ++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index e28720b..5ff2982 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -1,13 +1,19 @@ """ parameter_values(p) + parameter_values(p, i) -Return an indexable collection containing the value of each parameter in `p`. +Return an indexable collection containing the value of each parameter in `p`. The two- +argument version of this function returns the parameter value at index `i`. The +two-argument version of this function will default to returning +`parameter_values(p)[i]`. If this function is called with an `AbstractArray`, it will return the same array. """ function parameter_values end parameter_values(arr::AbstractArray) = arr +parameter_values(arr::AbstractArray, i) = arr[i] +parameter_values(prob, i) = parameter_values(parameter_values(prob), i) """ set_parameter!(sys, val, idx) @@ -28,7 +34,7 @@ end Return a function that takes an array representing the parameter vector or an integrator or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a -direct numerical index or a symbolic value, or an array/tuple of the aforementioned. +direct index or a symbolic value, or an array/tuple of the aforementioned. Requires that the integrator or solution implement [`parameter_values`](@ref). This function typically does not need to be implemented, and has a default implementation relying on @@ -42,14 +48,14 @@ end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return function getter(sol) - return parameter_values(sol)[p] + return parameter_values(sol, p) end end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return function getter(sol) - return parameter_values(sol)[idx] + return parameter_values(sol, idx) end end @@ -76,7 +82,7 @@ end Return a function that takes an array representing the parameter vector or an integrator or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p` -can be a direct numerical index or a symbolic value. +can be a direct index or a symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the returned collection be a mutable reference to the parameter vector in the integrator. In diff --git a/src/state_indexing.jl b/src/state_indexing.jl index f2cd7f2..4638fe1 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -116,7 +116,7 @@ end function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) _getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,)) _getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym) - _getter(::NotTimeseries, prob) = state_values(prob)[sym] + _getter(::NotTimeseries, prob) = state_values(prob, sym) return let _getter = _getter function getter(prob) return _getter(is_timeseries(prob), prob) @@ -266,7 +266,7 @@ end Return a function that takes an array representing the state vector or an integrator or problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym` -can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. +can be a direct index, a symbolic state, or an array/tuple of the aforementioned. Requires that the integrator implement [`state_values`](@ref) and the returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to