diff --git a/Project.toml b/Project.toml index 0ccd484..df1aacb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicIndexingInterface" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" authors = ["Aayush Sabharwal and contributors"] -version = "0.3.19" +version = "0.3.20" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/docs/pages.jl b/docs/pages.jl index b27f2c5..ecb5c77 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -2,6 +2,7 @@ pages = [ "Home" => "index.md", + "Terminology" => "terminology.md", "Tutorials" => [ "Using the SciML Symbolic Indexing Interface" => "usage.md", "Simple Demonstration of a Symbolic System Structure" => "simple_sii_sys.md", diff --git a/docs/src/api.md b/docs/src/api.md index d6b3cab..dc1c754 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,6 +1,8 @@ # Interface Functions -## Mandatory methods +## Index provider interface + +### Mandatory methods ```@docs symbolic_container @@ -22,14 +24,16 @@ solvedvariables allvariables ``` -## Optional Methods +### Optional Methods -### Observed equation handling +#### Observed equation handling ```@docs observed ``` +## Value provider interface + ### Parameter indexing ```@docs @@ -41,6 +45,19 @@ setp ParameterIndexingProxy ``` +#### Parameter timeseries + +If a solution object saves a timeseries of parameter values that are updated during the +simulation (such as by callbacks), it must implement the following methods to ensure +correct functioning of [`getu`](@ref) and [`getp`](@ref). + +```@docs +parameter_timeseries +parameter_values_at_time +parameter_values_at_state_time +``` + + ### State indexing ```@docs @@ -54,22 +71,17 @@ getu setu ``` -## Container objects +### Batched Queries and Updates ```@docs -remake_buffer +BatchedInterface +associated_systems ``` -### Parameter timeseries - -If a solution object saves a timeseries of parameter values that are updated during the -simulation (such as by callbacks), it must implement the following methods to ensure -correct functioning of [`getu`](@ref) and [`getp`](@ref). +## Container objects ```@docs -parameter_timeseries -parameter_values_at_time -parameter_values_at_state_time +remake_buffer ``` # Symbolic Trait @@ -90,10 +102,3 @@ symbolic_evaluate SymbolCache ProblemState ``` - -### Batched Queries and Updates - -```@docs -BatchedInterface -associated_systems -``` diff --git a/docs/src/terminology.md b/docs/src/terminology.md new file mode 100644 index 0000000..46c5173 --- /dev/null +++ b/docs/src/terminology.md @@ -0,0 +1,61 @@ +# Terminology + +SymbolicIndexingInterface.jl uses various library-specific terminology throughout its +documentation. This page attempts to provide comprehensive explanations of the terms +used. + +## Indexes + +An index is an object that defines how to extract specific data from a data structure. +Indexes may be anything from integer indexes into arrays, to custom types that contain +information specific to a particular data structure. + +In code samples, an index is typically denoted with the name `idx` or `i`. + +## Symbolic variables + +Symbolic variables are objects that represent quantities (states, parameters, time, etc.) +or collections of quantities used in a numerical simulation in a more human-accessible +manner. Typically the values of these quantities are stored in custom data structures +and referred to using indexes that do not convey any semantic meaning to users. Symbolic +variables cannot directly be used to access values from these data structures and need +to be translated into indexes. + +In code samples, a symbolic variable is typically denoted with the name `sym`. + +Symbolic variables are also sometimes referred to as "symbolic indices". + +## Index providers + +Index providers translate symbolic variables into indexes. In general, an "index" can +be anything from integer indexes into an array, or custom types that define how to +index into complicated data structures. `nothing` is reserved to denote the absence of +an index, in cases where the index provider is unaware of a particular symbolic variable. +It cannot be used as an index. ModelingToolkit.jl systems are examples of index providers. + +In code samples, an index provider is typically denoted with the name `indp`. + +## Value providers + +Value providers store values of symbolic variables. Given an appropriate index from an +index provider, value providers return the value (or values) stored at that index. The +problem, integrator and solution types in SciML are all examples of value providers. +Each value provider is (directly or indirectly) associated with an index provider that +defines the set of valid symbolic variables for that value provider, and the corresponding +indexes. + +A value provider may not store values for all symbolic variables in the corresponding index +provider. For example, a parameter object (even a plain `Array` storing parameter values) +is a value provider specifically for the symbolic variables referring to parameters. + +In code samples, a value provider is typically denoted with the name `valp`. + +!!! note + It is important to note that an object may be both a value- and index- provider. For + example, SciML's problem, integrator and solution types are both value- and index- + providers. This allows for several syntactic improvements. The [`symbolic_container`](@ref) + function is useful in defining such objects. + +!!! note "Timeseries objects" + The documentation uses "Timeseries objects" to refer to value providers which implement + the [`Timeseries`](@ref) variant of the [`is_timeseries`](@ref) trait. diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7312843..3b90d0f 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -17,7 +17,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate -include("interface.jl") +include("index_provider_interface.jl") export SymbolCache include("symbol_cache.jl") diff --git a/src/batched_interface.jl b/src/batched_interface.jl index 5c3617b..25a06b0 100644 --- a/src/batched_interface.jl +++ b/src/batched_interface.jl @@ -1,14 +1,14 @@ """ struct BatchedInterface{S <: AbstractVector, I} - function BatchedInterface(syssyms::Tuple...) + function BatchedInterface(indp_syms::Tuple...) A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref). -Given `Tuple`s, where the first element of each tuple is a system and the second an -array of symbols (either variables or parameters) in the system, `BatchedInterface` will -compute the union of all symbols and associate each symbol with the first system with -which it occurs. +Given `Tuple`s, where the first element of each tuple is an index provider and the second +an array of symbolic variables (either states or parameters) in the index provider, +`BatchedInterface` will compute the union of all symbols and associate each symbol with +the first index provider with which it occurs. -For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and +For example, given two index providers `s1 = SymbolCache([:x, :y, :z])` and `s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also @@ -24,17 +24,17 @@ See also: [`associated_systems`](@ref). struct BatchedInterface{S <: AbstractVector, I, T} "Order of symbols in the union." symbol_order::S - "Index of the system each symbol in the union is associated with." + "Index of the index provider each symbol in the union is associated with." associated_systems::Vector{Int} - "Index of symbol in the system it is associated with." + "Index of symbol in the index provider it is associated with." associated_indexes::I - "Whether the symbol is a state in the system it is associated with." + "Whether the symbol is a state in the index provider it is associated with." isstate::BitVector - "Map from system to indexes of its symbols in the union." + "Map from index provider to indexes of its symbols in the union." system_to_symbol_subset::Vector{Vector{Int}} - "Map from system to indexes of its symbols in the system." + "Map from index provider to indexes of its symbols in the index provider." system_to_symbol_indexes::Vector{Vector{T}} - "Map from system to whether each of its symbols is a state in the system." + "Map from index provider to whether each of its symbols is a state in the index provider." system_to_isstate::Vector{BitVector} end @@ -102,7 +102,7 @@ is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing associated_systems(bi::BatchedInterface) Return an array of integers of the same length as `variable_symbols(bi)` where each value -is the index of the system associated with the corresponding symbol in +is the index of the index provider associated with the corresponding symbol in `variable_symbols(bi)`. """ associated_systems(bi::BatchedInterface) = bi.associated_systems @@ -110,18 +110,19 @@ associated_systems(bi::BatchedInterface) = bi.associated_systems """ getu(bi::BatchedInterface) -Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), -return a function which takes `n` corresponding problems and returns an array of the values -of the symbols in the union. The returned function can also be passed an `AbstractArray` of -the appropriate `eltype` and size as its first argument, in which case the operation will -populate the array in-place with the values of the symbols in the union. +Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding +symbols), return a function which takes `n` corresponding value providers and returns an +array of the values of the symbols in the union. The returned function can also be passed +an `AbstractArray` of the appropriate `eltype` and size as its first argument, in which +case the operation will populate the array in-place with the values of the symbols in the +union. -Note that all of the problems passed to the function returned by `getu` must satisfy +Note that all of the value providers passed to the function returned by `getu` must satisfy `is_timeseries(prob) === NotTimeseries()`. The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is -obtained from the problem corresponding to the associated system (i.e. the problem at -index `associated_systems(bi)[i]`). +obtained from the problem corresponding to the associated index provider (i.e. the value +provider at index `associated_systems(bi)[i]`). See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref), [`NotTimeseries`](@ref). @@ -180,16 +181,16 @@ end """ setu(bi::BatchedInterface) -Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), -return a function which takes `n` corresponding problems and an array of the values, and -updates each of the problems with the values of the corresponding symbols. +Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding +symbols), return a function which takes `n` corresponding problems and an array of the +values, and updates each of the problems with the values of the corresponding symbols. -Note that all of the problems passed to the function returned by `setu` must satisfy +Note that all of the value providers passed to the function returned by `setu` must satisfy `is_timeseries(prob) === NotTimeseries()`. -Note that if any subset of the `n` systems share common symbols (among those passed to -`BatchedInterface`) then all of the corresponding problems in the subset will be updated -with the values of the common symbols. +Note that if any subset of the `n` index providers share common symbols (among those passed +to `BatchedInterface`) then all of the corresponding value providers in the subset will be +updated with the values of the common symbols. See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). """ diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl new file mode 100644 index 0000000..512bded --- /dev/null +++ b/src/index_provider_interface.jl @@ -0,0 +1,175 @@ +""" + symbolic_container(indp) + +Using `indp`, return an object that implements the index provider interface. In case `indp` +itself implements the interface, `indp` can be returned as-is. All index provider interface +methods fall back to calling the same method on `symbolic_container(indp)`, so this may be +used for trivial implementations of the interface that forward all calls to another object. +""" +function symbolic_container end + +""" + is_variable(indp, sym) + +Check whether the given `sym` is a variable in `indp`. +""" +is_variable(indp, sym) = is_variable(symbolic_container(indp), sym) + +""" + variable_index(indp, sym, [i]) + +Return the index of the given variable `sym` in `indp`, or `nothing` otherwise. If +[`constant_structure`](@ref) is `false`, this accepts the current time index as an +additional parameter `i`. +""" +variable_index(indp, sym) = variable_index(symbolic_container(indp), sym) +variable_index(indp, sym, i) = variable_index(symbolic_container(indp), sym, i) + +""" + variable_symbols(indp, [i]) + +Return a vector of the symbolic variables being solved for in the index provider `indp`. +If `constant_structure(sys) == false` this accepts an additional parameter indicating +the current time index. The returned vector should not be mutated. + +For types that implement `Base.getindex` with symbolic indices using this interface, +the shorthand `valp[solvedvariables]` can be used as shorthand for +`valp[variable_symbols(sys)]`. See: [`solvedvariables`](@ref). +""" +variable_symbols(indp) = variable_symbols(symbolic_container(indp)) +variable_symbols(indp, i) = variable_symbols(symbolic_container(indp), i) + +""" + is_parameter(indp, sym) + +Check whether the given `sym` is a parameter in `indp`. +""" +is_parameter(indp, sym) = is_parameter(symbolic_container(indp), sym) + +""" + parameter_index(indp, sym) + +Return the index of the given parameter `sym` in `indp`, or `nothing` otherwise. +""" +parameter_index(indp, sym) = parameter_index(symbolic_container(indp), sym) + +""" + parameter_symbols(indp) + +Return a vector of the symbolic parameters of the given index provider `indp`. The returned +vector should not be mutated. +""" +parameter_symbols(indp) = parameter_symbols(symbolic_container(indp)) + +""" + is_independent_variable(indp, sym) + +Check whether the given `sym` is an independent variable in `indp`. The returned vector +should not be mutated. +""" +is_independent_variable(indp, sym) = is_independent_variable(symbolic_container(indp), sym) + +""" + independent_variable_symbols(indp) + +Return a vector of the symbolic independent variables of the given index provider `indp`. +""" +independent_variable_symbols(indp) = independent_variable_symbols(symbolic_container(indp)) + +""" + is_observed(indp, sym) + +Check whether the given `sym` is an observed value in `indp`. +""" +is_observed(indp, sym) = is_observed(symbolic_container(indp), sym) + +""" + observed(indp, sym, [states]) + +Return the observed function of the given `sym` in `indp`. The returned function should +have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and +parameter vector, respectively. If `istimedependent(indp) == true`, the function should +accept the current time `t` as its third parameter. If `constant_structure(indp) == false`, +`observed` accepts a third parameter, which can either be a vector of symbols indicating +the order of states or a time index, which identifies the order of states. This function +does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus, +it is mandatory to always check `is_observed` before using this function. + +See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) +""" +observed(indp, sym) = observed(symbolic_container(indp), sym) +observed(indp, sym, states) = observed(symbolic_container(indp), sym, states) + +""" + is_time_dependent(indp) + +Check if `indp` has time as (one of) its independent variables. +""" +is_time_dependent(indp) = is_time_dependent(symbolic_container(indp)) + +""" + constant_structure(indp) + +Check if `indp` has a constant structure. Constant structure index providers do not change +the number of variables or parameters over time. +""" +constant_structure(indp) = constant_structure(symbolic_container(indp)) + +""" + all_variable_symbols(indp) + +Return a vector of variable symbols in the system, including observed quantities. + +For types that implement `Base.getindex` with symbolic indices using this interface, +The shorthand `sys[allvariables]` can be used as shorthand for +`valp[all_variable_symbols(indp)]`. + +See: [`allvariables`](@ref). +""" +all_variable_symbols(indp) = all_variable_symbols(symbolic_container(indp)) + +""" + all_symbols(indp) + +Return an array of all symbols in the index provider. This includes parameters and +independent variables. +""" +all_symbols(indp) = all_symbols(symbolic_container(indp)) + +""" + default_values(indp) + +Return a dictionary mapping symbols in the index provider to their default value, if any. +This includes parameter symbols. The dictionary must be mutable. +""" +function default_values(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + default_values(symbolic_container(indp)) + else + Dict() + end +end + +struct SolvedVariables end + +""" + const solvedvariables = SolvedVariables() + +This singleton is used as a shortcut to allow indexing of all solution variables +(excluding observed quantities). It has a [`symbolic_type`](@ref) of +[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref). +""" +const solvedvariables = SolvedVariables() +symbolic_type(::Type{SolvedVariables}) = ScalarSymbolic() + +struct AllVariables end + +""" + const allvariables = AllVariables() + +This singleton is used as a shortcut to allow indexing of all solution variables +(including observed quantities). It has a [`symbolic_type`](@ref) of +[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref). +""" +const allvariables = AllVariables() +symbolic_type(::Type{AllVariables}) = ScalarSymbolic() diff --git a/src/interface.jl b/src/interface.jl deleted file mode 100644 index 4820783..0000000 --- a/src/interface.jl +++ /dev/null @@ -1,173 +0,0 @@ -""" - symbolic_container(p) - -Using `p`, return an object that implements the symbolic indexing interface. In case `p` -itself implements the interface, `p` can be returned as-is. All symbolic indexing interface -methods fall back to calling the same method on `symbolic_container(p)`, so this may be -used for trivial implementations of the interface that forward all calls to another object. -""" -function symbolic_container end - -""" - is_variable(sys, sym) - -Check whether the given `sym` is a variable in `sys`. -""" -is_variable(sys, sym) = is_variable(symbolic_container(sys), sym) - -""" - variable_index(sys, sym, [i]) - -Return the index of the given variable `sym` in `sys`, or `nothing` otherwise. If -[`constant_structure`](@ref) is `false`, this accepts the current time index as an -additional parameter `i`. -""" -variable_index(sys, sym) = variable_index(symbolic_container(sys), sym) -variable_index(sys, sym, i) = variable_index(symbolic_container(sys), sym, i) - -""" - variable_symbols(sys, [i]) - -Return a vector of the symbolic variables being solved for in the system `sys`. If -`constant_structure(sys) == false` this accepts an additional parameter indicating -the current time index. The returned vector should not be mutated. - -For types that implement `Base.getindex` with symbolic indices using this interface, -the shorthand `sys[solvedvariables]` can be used as shorthand for -`sys[variable_symbols(sys)]`. See: [`solvedvariables`](@ref). -""" -variable_symbols(sys) = variable_symbols(symbolic_container(sys)) -variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i) - -""" - is_parameter(sys, sym) - -Check whether the given `sym` is a parameter in `sys`. -""" -is_parameter(sys, sym) = is_parameter(symbolic_container(sys), sym) - -""" - parameter_index(sys, sym) - -Return the index of the given parameter `sym` in `sys`, or `nothing` otherwise. -""" -parameter_index(sys, sym) = parameter_index(symbolic_container(sys), sym) - -""" - parameter_symbols(sys) - -Return a vector of the symbolic parameters of the given system `sys`. The returned -vector should not be mutated. -""" -parameter_symbols(sys) = parameter_symbols(symbolic_container(sys)) - -""" - is_independent_variable(sys, sym) - -Check whether the given `sym` is an independent variable in `sys`. The returned vector -should not be mutated. -""" -is_independent_variable(sys, sym) = is_independent_variable(symbolic_container(sys), sym) - -""" - independent_variable_symbols(sys) - -Return a vector of the symbolic independent variables of the given system `sys`. -""" -independent_variable_symbols(sys) = independent_variable_symbols(symbolic_container(sys)) - -""" - is_observed(sys, sym) - -Check whether the given `sym` is an observed value in `sys`. -""" -is_observed(sys, sym) = is_observed(symbolic_container(sys), sym) - -""" - observed(sys, sym, [states]) - -Return the observed function of the given `sym` in `sys`. The returned function should -have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and -parameter vector, respectively. If `istimedependent(sys) == true`, the function should accept -the current time `t` as its third parameter. If `constant_structure(sys) == false`, -accept a third parameter, which can either be a vector of symbols indicating the order -of states or a time index, which identifies the order of states. This function -does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus, -it is mandatory to always check `is_observed` before using this function. - -See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) -""" -observed(sys, sym) = observed(symbolic_container(sys), sym) -observed(sys, sym, states) = observed(symbolic_container(sys), sym, states) - -""" - is_time_dependent(sys) - -Check if `sys` has time as (one of) its independent variables. -""" -is_time_dependent(sys) = is_time_dependent(symbolic_container(sys)) - -""" - constant_structure(sys) - -Check if `sys` has a constant structure. Constant structure systems do not change the -number of variables or parameters over time. -""" -constant_structure(sys) = constant_structure(symbolic_container(sys)) - -""" - all_variable_symbols(sys) - -Return a vector of variable symbols in the system, including observed quantities. - -For types that implement `Base.getindex` with symbolic indices using this interface, -The shorthand `sys[allvariables]` can be used as shorthand for -`sys[all_variable_symbols(sys)]`. See: [`allvariables`](@ref). -""" -all_variable_symbols(sys) = all_variable_symbols(symbolic_container(sys)) - -""" - all_symbols(sys) - -Return an array of all symbols in the system. This includes parameters and independent -variables. -""" -all_symbols(sys) = all_symbols(symbolic_container(sys)) - -""" - default_values(sys) - -Return a dictionary mapping symbols in the system to their default value, if any. This -includes parameter symbols. The dictionary must be mutable. -""" -function default_values(sys) - if hasmethod(symbolic_container, Tuple{typeof(sys)}) - default_values(symbolic_container(sys)) - else - Dict() - end -end - -struct SolvedVariables end - -""" - const solvedvariables = SolvedVariables() - -This singleton is used as a shortcut to allow indexing of all solution variables -(excluding observed quantities). It has a [`symbolic_type`](@ref) of -[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref). -""" -const solvedvariables = SolvedVariables() -symbolic_type(::Type{SolvedVariables}) = ScalarSymbolic() - -struct AllVariables end - -""" - const allvariables = AllVariables() - -This singleton is used as a shortcut to allow indexing of all solution variables -(including observed quantities). It has a [`symbolic_type`](@ref) of -[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref). -""" -const allvariables = AllVariables() -symbolic_type(::Type{AllVariables}) = ScalarSymbolic() diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index faa4551..d8b3122 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -17,20 +17,21 @@ end set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) """ - getp(sys, p) + getp(indp, sym) -Return a function that takes an array representing the parameter object or an integrator -or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a -direct index or a symbolic value, or an array/tuple of the aforementioned. +Return a function that takes an value provider, and returns the value of the +parameter `sym`. The value provider has to at least store the values of parameters +in the corresponding index provider. Note that `sym` can be an index, symbolic variable, +or an array/tuple of the aforementioned. -If `p` is an array/tuple of parameters, then the returned function can also be used -as an in-place getter function. The first argument is the buffer to which the parameter -values should be written, and the second argument is the parameter object/integrator/ -solution from which the values are obtained. +If `sym` is an array/tuple of parameters, then the returned function can also be used +as an in-place getter function. The first argument is the buffer (must be an +`AbstractArray`) to which the parameter values should be written, and the second argument +is the value provider. -Requires that the integrator or solution implement [`parameter_values`](@ref). This function -typically does not need to be implemented, and has a default implementation relying on -[`parameter_values`](@ref). +Requires that the value provider implement [`parameter_values`](@ref). This function +may not always need to be implemented, and has a default implementation for collections +that implement `getindex`. If the returned function is used on a timeseries object which saves parameter timeseries, it can be used to index said timeseries. The timeseries object must implement @@ -45,7 +46,7 @@ function getp(sys, p) _getp(sys, symtype, elsymtype, p) end -struct GetParameterIndex{I} <: AbstractIndexer +struct GetParameterIndex{I} <: AbstractGetIndexer idx::I end @@ -78,7 +79,7 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) sys, NotSymbolic(), NotSymbolic(), idx) end -struct MultipleParameterGetters{G} +struct MultipleParameterGetters{G} <: AbstractGetIndexer getters::G end @@ -148,50 +149,62 @@ function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p) return getp(sys, collect(p)) end +struct ParameterHookWrapper{S, O} <: AbstractSetIndexer + setter::S + original_index::O +end + +function (phw::ParameterHookWrapper)(prob, args...) + res = phw.setter(prob, args...) + finalize_parameters_hook!(prob, phw.original_index) + res +end + """ - setp(sys, p) + setp(indp, sym) -Return a function that takes an array representing the parameter vector or an integrator -or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p` -can be a direct index or a symbolic value. +Return a function that takes an index provider and a value, and sets the parameter `sym` +to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of +the aforementioned. -Requires that the integrator implement [`parameter_values`](@ref) and the returned -collection be a mutable reference to the parameter vector in the integrator. In -case `parameter_values` cannot return such a mutable reference, or additional actions -need to be performed when updating parameters, [`set_parameter!`](@ref) must be -implemented. +Requires that the value provider implement [`parameter_values`](@ref) and the returned +collection be a mutable reference to the parameter object. In case `parameter_values` +cannot return such a mutable reference, or additional actions need to be performed when +updating parameters, [`set_parameter!`](@ref) must be implemented. """ function setp(sys, p; run_hook = true) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) return if run_hook - let _setter! = _setp(sys, symtype, elsymtype, p), p = p - function setter!(prob, args...) - res = _setter!(prob, args...) - finalize_parameters_hook!(prob, p) - res - end - end + return ParameterHookWrapper(_setp(sys, symtype, elsymtype, p), p) else _setp(sys, symtype, elsymtype, p) end end +struct SetParameterIndex{I} <: AbstractSetIndexer + idx::I +end + +function (spi::SetParameterIndex)(prob, val) + set_parameter!(prob, val, spi.idx) +end + function _setp(sys, ::NotSymbolic, ::NotSymbolic, p) - return let p = p - function setter!(sol, val) - set_parameter!(sol, val, p) - end - end + return SetParameterIndex(p) end function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return let idx = idx - function setter!(sol, val) - set_parameter!(sol, val, idx) - end - end + return SetParameterIndex(idx) +end + +struct MultipleSetters{S} <: AbstractSetIndexer + setters::S +end + +function (ms::MultipleSetters)(prob, val) + map((s!, v) -> s!(prob, v), ms.setters, val) end for (t1, t2) in [ @@ -201,11 +214,7 @@ for (t1, t2) in [ ] @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) setters = setp.((sys,), p; run_hook = false) - return let setters = setters - function setter!(sol, val) - map((s!, v) -> s!(sol, v), setters, val) - end - end + return MultipleSetters(setters) end end diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index ab365ed..cc0925d 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -1,9 +1,9 @@ """ struct ParameterIndexingProxy -This struct wraps any struct implementing the symbolic indexing interface. It allows -`getindex` and `setindex!` operations to get/set parameter values. Requires that the -wrapped type support [`getp`](@ref) and [`setp`](@ref) for getting and setting +This struct wraps any struct implementing the value provider and index provider interfaces. +It allows `getindex` and `setindex!` operations to get/set parameter values. Requires that +the wrapped type support [`getp`](@ref) and [`setp`](@ref) for getting and setting parameter values respectively. """ struct ParameterIndexingProxy{T} diff --git a/src/problem_state.jl b/src/problem_state.jl index 4a2312c..fa93656 100644 --- a/src/problem_state.jl +++ b/src/problem_state.jl @@ -2,9 +2,9 @@ struct ProblemState function ProblemState(; u = nothing, p = nothing, t = nothing) -A struct which can be used as an argument to the function returned by [`getu`](@ref) or -[`setu`](@ref). It stores the state vector, parameter object and current time, and -forwards calls to [`state_values`](@ref), [`parameter_values`](@ref), +A value provider struct which can be used as an argument to the function returned by +[`getu`](@ref) or [`setu`](@ref). It stores the state vector, parameter object and +current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref), [`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained objects. """ diff --git a/src/remake.jl b/src/remake.jl index 0eb06c3..9f2029e 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -1,8 +1,8 @@ """ - remake_buffer(sys, oldbuffer, vals::Dict) + remake_buffer(indp, oldbuffer, vals::Dict) Return a copy of the buffer `oldbuffer` with values from `vals`. The keys of `vals` -are symbolic variables whose index in the buffer is determined using `sys`. The types of +are symbolic variables whose index in the buffer is determined using `indp`. The types of values in `vals` may not match the types of values stored at the corresponding indexes in the buffer, in which case the type of the buffer should be promoted accordingly. In general, this method should attempt to preserve the types of values stored in `vals` as @@ -10,7 +10,7 @@ much as possible. Types can be promoted for type-stability, to maintain performa returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`. This method is already implemented for -`remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays +`remake_buffer(indp, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays as well. It is also implemented for `oldbuffer::Tuple`. """ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index e854b29..272c841 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -8,21 +8,20 @@ end current_time(p, i) = current_time(p)[i] """ - getu(sys, sym) - -Return a function that takes an integrator, problem or solution of `sys`, and returns -the value of the symbolic `sym`. If `sym` is not an observed quantity, the returned -function can also directly be called with an array of values representing the state -vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic -expression involving symbolic quantities in the system `sys`, a parameter symbol, or the -independent variable symbol, or an array/tuple of the aforementioned. If the returned -function is called with a timeseries object, it can also be given a second argument -representing the index at which to find the value of `sym`. - -At minimum, this requires that the integrator, problem or solution implement -[`state_values`](@ref). To support symbolic expressions, the integrator or problem -must implement [`observed`](@ref), [`parameter_values`](@ref) and -[`current_time`](@ref). + getu(indp, sym) + +Return a function that takes a value provider and returns the value of the symbolic +variable `sym`. If `sym` is not an observed quantity, the returned function can also +directly be called with an array of values representing the state vector. `sym` can be an +index into the state vector, a symbolic variable, a symbolic expression involving symbolic +variables in the index provider `indp`, a parameter symbol, the independent variable +symbol, or an array/tuple of the aforementioned. If the returned function is called with +a timeseries object, it can also be given a second argument representing the index at +which to return the value of `sym`. + +At minimum, this requires that the value provider implement [`state_values`](@ref). To +support symbolic expressions, the value provider must implement [`observed`](@ref), +[`parameter_values`](@ref) and [`current_time`](@ref). This function typically does not need to be implemented, and has a default implementation relying on the above functions. @@ -33,7 +32,7 @@ function getu(sys, sym) _getu(sys, symtype, elsymtype, sym) end -struct GetStateIndex{I} <: AbstractIndexer +struct GetStateIndex{I} <: AbstractGetIndexer idx::I end function (gsi::GetStateIndex)(::Timeseries, prob) @@ -50,7 +49,7 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) return GetStateIndex(sym) end -struct GetpAtStateTime{G} <: AbstractIndexer +struct GetpAtStateTime{G} <: AbstractGetIndexer getter::G end @@ -65,12 +64,12 @@ function (g::GetpAtStateTime)(::NotTimeseries, prob) g.getter(prob) end -struct GetIndepvar <: AbstractIndexer end +struct GetIndepvar <: AbstractGetIndexer end (::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob) (::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i) -struct TimeDependentObservedFunction{F} <: AbstractIndexer +struct TimeDependentObservedFunction{F} <: AbstractGetIndexer obsfn::F end @@ -89,7 +88,7 @@ function (o::TimeDependentObservedFunction)(::NotTimeseries, prob) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end -struct TimeIndependentObservedFunction{F} <: AbstractIndexer +struct TimeIndependentObservedFunction{F} <: AbstractGetIndexer obsfn::F end @@ -116,7 +115,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) error("Invalid symbol $sym for `getu`") end -struct MultipleGetters{G} <: AbstractIndexer +struct MultipleGetters{G} <: AbstractGetIndexer getters::G end @@ -131,12 +130,18 @@ function (mg::MultipleGetters)(::NotTimeseries, prob) return map(g -> g(prob), mg.getters) end -struct AsTupleWrapper{G} <: AbstractIndexer +struct AsTupleWrapper{G} <: AbstractGetIndexer getter::G end -function (atw::AsTupleWrapper)(::IsTimeseriesTrait, args...) - return Tuple(atw.getter(args...)) +function (atw::AsTupleWrapper)(::Timeseries, prob) + return Tuple.(atw.getter(prob)) +end +function (atw::AsTupleWrapper)(::Timeseries, prob, i) + return Tuple(atw.getter(prob, i)) +end +function (atw::AsTupleWrapper)(::NotTimeseries, prob) + return Tuple(atw.getter(prob)) end for (t1, t2) in [ @@ -151,7 +156,7 @@ for (t1, t2) in [ return MultipleGetters(getters) else obs = observed(sys, sym isa Tuple ? collect(sym) : sym) - getter = if is_timeseries(sys) + getter = if is_time_dependent(sys) TimeDependentObservedFunction(obs) else TimeIndependentObservedFunction(obs) @@ -179,15 +184,15 @@ end """ setu(sys, sym) -Return a function that takes an array representing the state vector or an integrator or -problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym` -can be a direct index, a symbolic state, or an array/tuple of the aforementioned. +Return a function that takes a value provider and a value, and sets the the state `sym` to +that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of the +aforementioned. -Requires that the integrator implement [`state_values`](@ref) and the -returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to -be performed when updating state, [`set_state!`](@ref) can be defined. -This function does not work on types for which [`is_timeseries`](@ref) is -[`Timeseries`](@ref). +Requires that the value provider implement [`state_values`](@ref) and the returned +collection be a mutable reference to the state vector in the value provider. Alternatively, +if this is not possible or additional actions need to be performed when updating state, +[`set_state!`](@ref) can be defined. This function does not work on types for which +[`is_timeseries`](@ref) is [`Timeseries`](@ref). """ function setu(sys, sym) symtype = symbolic_type(sym) @@ -195,18 +200,22 @@ function setu(sys, sym) _setu(sys, symtype, elsymtype, sym) end +struct SetStateIndex{I} <: AbstractSetIndexer + idx::I +end + +function (ssi::SetStateIndex)(prob, val) + set_state!(prob, val, ssi.idx) +end + function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) - return function setter!(prob, val) - set_state!(prob, val, sym) - end + return SetStateIndex(sym) end function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return function setter!(prob, val) - set_state!(prob, val, idx) - end + return SetStateIndex(idx) elseif is_parameter(sys, sym) return setp(sys, sym) end @@ -220,16 +229,18 @@ for (t1, t2) in [ ] @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) setters = setu.((sys,), sym) - return function setter!(prob, val) - map((s!, v) -> s!(prob, v), setters, val) - end + return MultipleSetters(setters) end end function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return setu(sys, idx) + if idx isa AbstractArray + return MultipleSetters(SetStateIndex.(idx)) + else + return SetStateIndex(idx) + end elseif is_parameter(sys, sym) return setp(sys, sym) end diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 481925f..ad61409 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -2,7 +2,7 @@ struct SymbolCache{V,P,I} function SymbolCache(vars, [params, [indepvars]]) -A struct implementing the symbolic indexing interface for the trivial case of having a +A struct implementing the index provider interface for the trivial case of having a vector of variables, parameters, and independent variables. It is considered time dependent if it contains at least one independent variable. It returns `true` for `is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 5214841..1f455b8 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -3,13 +3,13 @@ ########### """ - parameter_values(p) - parameter_values(p, i) + parameter_values(valp) + parameter_values(valp, i) -Return an indexable collection containing the value of each parameter in `p`. The two- +Return an indexable collection containing the value of each parameter in `valp`. The two- argument version of this function returns the parameter value at index `i`. The two-argument version of this function will default to returning -`parameter_values(p)[i]`. +`parameter_values(valp)[i]`. If this function is called with an `AbstractArray` or `Tuple`, it will return the same array/tuple. @@ -17,15 +17,14 @@ array/tuple. function parameter_values end """ - parameter_values_at_time(p, i) + parameter_values_at_time(valp, i) -Return an indexable collection containing the value of all parameters in `p` at time index -`i`. This is useful when parameter values change during the simulation -(such as through callbacks) and their values are saved. `i` is the time index in the -timeseries formed by these changing parameter values, obtained using -[`parameter_timeseries`](@ref). +Return an indexable collection containing the value of all parameters in `valp` at time +index `i`. This is useful when parameter values change during the simulation (such as +through callbacks) and their values are saved. `i` is the time index in the timeserie + formed by these changing parameter values, obtained using [`parameter_timeseries`](@ref). -By default, this function returns `parameter_values(p)` regardless of `i`, and only needs +By default, this function returns `parameter_values(valp)` regardless of `i`, and only needs to be specialized for timeseries objects where parameter values are not constant at all times. The resultant object should be indexable using [`parameter_values`](@ref). @@ -35,16 +34,16 @@ implemented for [`getu`](@ref) to work correctly. function parameter_values_at_time end """ - parameter_values_at_state_time(p, i) + parameter_values_at_state_time(valp, i) -Return an indexable collection containing the value of all parameters in `p` at time +Return an indexable collection containing the value of all parameters in `valp` at time index `i`. This is useful when parameter values change during the simulation (such as through callbacks) and their values are saved. `i` is the time index in the timeseries formed by dependent variables (as opposed to the timeseries of the parameters, as in [`parameter_values_at_time`](@ref)). -By default, this function returns `parameter_values(p)` regardless of `i`, and only needs -to be specialized for timeseries objects where parameter values are not constant at +By default, this function returns `parameter_values(valp)` regardless of `i`, and only +needs to be specialized for timeseries objects where parameter values are not constant at all times. The resultant object should be indexable using [`parameter_values`](@ref). If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for @@ -53,10 +52,10 @@ If this function is implemented, [`parameter_values_at_time`](@ref) must be impl function parameter_values_at_state_time end """ - parameter_timeseries(p) + parameter_timeseries(valp) Return an iterable of time steps at which the parameter values are saved. This is only -required for objects where `is_timeseries(p) === Timeseries()` and the parameter values +required for objects where `is_timeseries(valp) === Timeseries()` and the parameter values change during the simulation (such as through callbacks). By default, this returns `[0]`. See also: [`parameter_values_at_time`](@ref). @@ -64,41 +63,41 @@ See also: [`parameter_values_at_time`](@ref). function parameter_timeseries end """ - set_parameter!(sys, val, idx) + set_parameter!(valp, val, idx) -Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying -`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the -default implementation does not work for a particular type, this method needs to be -defined to enable the proper functioning of [`setp`](@ref). +Set the parameter at index `idx` to `val` for value provider `valp`. This defaults to +modifying `parameter_values(valp)`. If any additional bookkeeping needs to be performed +or the default implementation does not work for a particular type, this method needs to +be defined to enable the proper functioning of [`setp`](@ref). See: [`parameter_values`](@ref) """ function set_parameter! end """ - finalize_parameters_hook!(prob, p) + finalize_parameters_hook!(valp, sym) This is a callback run one for each call to the function returned by [`setp`](@ref) which can be used to update internal data structures when parameters are modified. This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter that is updated. """ -finalize_parameters_hook!(prob, p) = nothing +finalize_parameters_hook!(valp, sym) = nothing ########### # State Indexing ########### """ - state_values(p) - state_values(p, i) + state_values(valp) + state_values(valp, i) -Return an indexable collection containing the values of all states in the integrator or -problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays, +Return an indexable collection containing the values of all states in the value provider +`p`. If `is_timeseries(valp)` is [`Timeseries`](@ref), return a vector of arrays, each of which contain the state values at the corresponding timestep. In this case, the two-argument version of the function can also be implemented to efficiently return the state values at timestep `i`. By default, the two-argument method calls -`state_values(p)[i]` +`state_values(valp)[i]` If this function is called with an `AbstractArray`, it will return the same array. @@ -107,10 +106,10 @@ See: [`is_timeseries`](@ref) function state_values end """ - set_state!(sys, val, idx) + set_state!(valp, val, idx) -Set the state at index `idx` to `val` for system `sys`. This defaults to modifying -`state_values(sys)`. If any additional bookkeeping needs to be performed or the +Set the state at index `idx` to `val` for value provider `valp`. This defaults to modifying +`state_values(valp)`. If any additional bookkeeping needs to be performed or the default implementation does not work for a particular type, this method needs to be defined to enable the proper functioning of [`setu`](@ref). @@ -119,11 +118,11 @@ See: [`state_values`](@ref) function set_state! end """ - current_time(p) - current_time(p, i) + current_time(valp) + current_time(valp, i) -Return the current time in the integrator or problem `p`. If -`is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which +Return the current time in the value provider `valp`. If +`is_timeseries(valp)` is [`Timeseries`](@ref), return the vector of timesteps at which the state value is saved. In this case, the two-argument version of the function can also be implemented to efficiently return the time at timestep `i`. By default, the two- argument method calls `current_time(p)[i]` @@ -139,5 +138,8 @@ function current_time end abstract type AbstractIndexer end -(ai::AbstractIndexer)(prob) = ai(is_timeseries(prob), prob) -(ai::AbstractIndexer)(prob, i) = ai(is_timeseries(prob), prob, i) +abstract type AbstractGetIndexer <: AbstractIndexer end +abstract type AbstractSetIndexer <: AbstractIndexer end + +(ai::AbstractGetIndexer)(prob) = ai(is_timeseries(prob), prob) +(ai::AbstractGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i) diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 980d3d8..16c8a71 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -75,6 +75,18 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) @test get(u) == val end +for (sym, val, check_inference) in [ + (:(x + y), u[1] + u[2], true), + ([:(x + y), :z], [u[1] + u[2], u[3]], false), + ((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false) +] + get = getu(sys, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == val +end + for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) (:b, p[2], 5.0, true) (:c, p[3], 6.0, true) @@ -101,7 +113,7 @@ end for (sym, val, check_inference) in [ (:t, t, true), ([:x, :a, :t], [u[1], p[1], t], false), - ((:x, :a, :t), (u[1], p[1], t), true) + ((:x, :a, :t), (u[1], p[1], t), false) ] get = getu(fi, sym) if check_inference @@ -182,6 +194,18 @@ for (sym, ans, check_inference) in [(:x, xvals, true) end end +for (sym, val, check_inference) in [ + (:(x + y), xvals .+ yvals, true), + ([:(x + y), :z], vcat.(xvals .+ yvals, zvals), false), + ((:(x + y), :(z + y)), tuple.(xvals .+ yvals, yvals .+ zvals), false) +] + get = getu(sys, sym) + if check_inference + @inferred get(sol) + end + @test get(sol) == val +end + for (sym, val) in [(:a, p[1]) (:b, p[2]) (:c, p[3])