diff --git a/docs/src/api.md b/docs/src/api.md index cf767b9..29150c0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -46,7 +46,7 @@ is_markovian If the index provider contains parameters that change during the course of the simulation at discrete time points, it must implement the following methods to ensure correct -functioning of [`getu`](@ref) and [`getp`](@ref) for value providers that save the parameter +functioning of [`getsym`](@ref) and [`getp`](@ref) for value providers that save the parameter timeseries. Note that there can be multiple parameter timeseries, in case different parameters may change at different times. @@ -69,10 +69,13 @@ is_timeseries state_values set_state! current_time -getu -setu +getsym +setsym ``` +!!! note + `getu` and `setu` have been renamed to [`getsym`](@ref) and [`setsym`](@ref) respectively. + #### Historical value providers ```@docs @@ -95,7 +98,7 @@ ParameterIndexingProxy If a solution object saves a timeseries of parameter values that are updated during the simulation (such as by callbacks), it must implement the following methods to ensure -correct functioning of [`getu`](@ref) and [`getp`](@ref). +correct functioning of [`getsym`](@ref) and [`getp`](@ref). Parameter timeseries support requires that the value provider store the different timeseries in a [`ParameterTimeseriesCollection`](@ref). diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index ae55e98..c75cdc8 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -174,9 +174,9 @@ end ``` If a type contains the value of state variables, it can define [`state_values`](@ref) to -enable the usage of [`getu`](@ref) and [`setu`](@ref). These methods retturn getter/ +enable the usage of [`getsym`](@ref) and [`setsym`](@ref). These methods retturn getter/ setter functions to access or update the value of a state variable (or a collection of -them). If the type also supports generating [`observed`](@ref) functions, `getu` also +them). If the type also supports generating [`observed`](@ref) functions, `getsym` also enables returning functions to access the value of arbitrary expressions involving the system's symbols. This also requires that the type implement [`parameter_values`](@ref) and [`current_time`](@ref) (if the system is time-dependent). @@ -202,13 +202,13 @@ Then the following example would work: ```julia sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict()) integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys) -getx = getu(sys, :x) +getx = getsym(sys, :x) getx(integrator) # 1.0 -get_expr = getu(sys, :(x + y + t)) +get_expr = getsym(sys, :(x + y + t)) get_expr(integrator) # 13.0 -setx! = setu(sys, :y) +setx! = setsym(sys, :y) setx!(integrator, 0.0) getx(integrator) # 0.0 ``` @@ -216,7 +216,7 @@ getx(integrator) # 0.0 In case a type stores timeseries data (such as solutions), then it must also implement the [`Timeseries`](@ref) trait. The type would then return a timeseries from [`state_values`](@ref) and [`current_time`](@ref) and the function returned from -[`getu`](@ref) would then return a timeseries as well. For example, consider the +[`getsym`](@ref) would then return a timeseries as well. For example, consider the `ExampleSolution` below: ```julia @@ -242,20 +242,20 @@ Then the following example would work: ```julia # using the same system that the ExampleIntegrator used sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys) -getx = getu(sys, :x) +getx = getsym(sys, :x) getx(sol) # [1.0, 1.5] -get_expr = getu(sys, :(x + y + t)) +get_expr = getsym(sys, :(x + y + t)) get_expr(sol) # [9.0, 11.0] -get_arr = getu(sys, [:y, :(x + a)]) +get_arr = getsym(sys, [:y, :(x + a)]) get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]] -get_tuple = getu(sys, (:z, :(z * t))) +get_tuple = getsym(sys, (:z, :(z * t))) get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)] ``` -Note that `setu` is not designed to work for `Timeseries` objects. +Note that `setsym` is not designed to work for `Timeseries` objects. If a type needs to perform some additional actions when updating the state/parameters or if it is not possible to return a mutable reference to the state/parameter vector @@ -315,7 +315,7 @@ setp(integrator, :b)(integrator, 3.0) # functionally the same as above If a solution object includes modified parameter values (such as through callbacks) during the simulation, it must implement several additional functions for correct functioning of -[`getu`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in +[`getsym`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in implementing parameter timeseries objects. The following mockup gives an example of correct implementation of these functions and the indexing syntax they enable. @@ -457,12 +457,12 @@ sol.ps[:(b + c)] # observed quantities work too ``` ```@example param_timeseries -getu(sol, :b)(sol) # works +getsym(sol, :b)(sol) # works ``` ```@example param_timeseries try - getu(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries + getsym(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries catch e showerror(stdout, e) end diff --git a/docs/src/usage.md b/docs/src/usage.md index 693256d..dfeac57 100644 --- a/docs/src/usage.md +++ b/docs/src/usage.md @@ -123,11 +123,11 @@ sol[allvariables] # equivalent to sol[all_variable_symbols(sol)] ### Evaluating expressions -`getu` also generates functions for expressions if the object passed to it supports +`getsym` also generates functions for expressions if the object passed to it supports [`observed`](@ref). For example: ```@example Usage -getu(prob, x + y + z)(prob) +getsym(prob, x + y + z)(prob) ``` To evaluate this function using values other than the ones contained in `prob`, we need @@ -137,7 +137,7 @@ which has trivial implementations of the above functions. We can thus do: ```@example Usage temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob)) -getu(prob, x + y + z)(temp_state) +getsym(prob, x + y + z)(temp_state) ``` Note that providing all of the state vector, parameter object and time may not be diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7b2f3cc..f444fb4 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -35,7 +35,7 @@ include("parameter_timeseries_collection.jl") export getp, setp, setp_oop include("parameter_indexing.jl") -export getu, setu +export getsym, setsym, getu, setu include("state_indexing.jl") export BatchedInterface, setsym_oop, associated_systems diff --git a/src/batched_interface.jl b/src/batched_interface.jl index c8d02d8..ad8273c 100644 --- a/src/batched_interface.jl +++ b/src/batched_interface.jl @@ -2,7 +2,7 @@ struct BatchedInterface{S <: AbstractVector, I} function BatchedInterface(indp_syms::Tuple...) -A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref). +A struct which stores information for batched calls to [`getsym`](@ref) or [`setsym`](@ref). Given `Tuple`s, where the first element of each tuple is an index provider and the second an array of symbolic variables (either states or parameters) in the index provider, `BatchedInterface` will compute the union of all symbols and associate each symbol with @@ -17,7 +17,7 @@ be retained internally. `BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref), [`variable_index`](@ref) to query the order of symbols in the union. -See [`getu`](@ref) and [`setu`](@ref) for further details. +See [`getsym`](@ref) and [`setsym`](@ref) for further details. See also: [`associated_systems`](@ref). """ @@ -118,7 +118,7 @@ is the index of the index provider associated with the corresponding symbol in associated_systems(bi::BatchedInterface) = bi.associated_systems """ - getu(bi::BatchedInterface) + getsym(bi::BatchedInterface) Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding symbols), return a function which takes `n` corresponding value providers and returns an @@ -127,7 +127,7 @@ an `AbstractArray` of the appropriate `eltype` and size as its first argument, i case the operation will populate the array in-place with the values of the symbols in the union. -Note that all of the value providers passed to the function returned by `getu` must satisfy +Note that all of the value providers passed to the function returned by `getsym` must satisfy `is_timeseries(prob) === NotTimeseries()`. The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is @@ -137,7 +137,7 @@ provider at index `associated_systems(bi)[i]`). See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref), [`NotTimeseries`](@ref). """ -function getu(bi::BatchedInterface) +function getsym(bi::BatchedInterface) numprobs = length(bi.system_to_symbol_subset) probnames = [Symbol(:prob, i) for i in 1:numprobs] @@ -189,13 +189,13 @@ function getu(bi::BatchedInterface) end """ - setu(bi::BatchedInterface) + setsym(bi::BatchedInterface) Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding symbols), return a function which takes `n` corresponding problems and an array of the values, and updates each of the problems with the values of the corresponding symbols. -Note that all of the value providers passed to the function returned by `setu` must satisfy +Note that all of the value providers passed to the function returned by `setsym` must satisfy `is_timeseries(prob) === NotTimeseries()`. Note that if any subset of the `n` index providers share common symbols (among those passed @@ -204,7 +204,7 @@ updated with the values of the common symbols. See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). """ -function setu(bi::BatchedInterface) +function setsym(bi::BatchedInterface) numprobs = length(bi.system_to_symbol_subset) probnames = [Symbol(:prob, i) for i in 1:numprobs] diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 1c61f8a..6333cea 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -726,38 +726,12 @@ function setp_oop(indp, sym) return _setp_oop(indp, symtype, elsymtype, sym) end -struct OOPSetter{I, D} - indp::I - idxs::D -end - -function (os::OOPSetter)(valp, val) - return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,)) -end - -function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray}) - if os.idxs isa Union{Tuple, AbstractArray} - return remake_buffer(os.indp, parameter_values(valp), os.idxs, val) - else - return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,)) - end -end - -function _root_indp(indp) - if hasmethod(symbolic_container, Tuple{typeof(indp)}) && - (sc = symbolic_container(indp)) != indp - return _root_indp(sc) - else - return indp - end -end - function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym) - return OOPSetter(_root_indp(indp), sym) + return OOPSetter(_root_indp(indp), sym, false) end function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) - return OOPSetter(_root_indp(indp), parameter_index(indp, sym)) + return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false) end for (t1, t2) in [ @@ -765,13 +739,13 @@ for (t1, t2) in [ (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2) - return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym)) + return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym), false) end end function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym) if is_parameter(indp, sym) - return OOPSetter(_root_indp(indp), parameter_index(indp, sym)) + return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false) end error("$sym is not a valid parameter") end diff --git a/src/problem_state.jl b/src/problem_state.jl index fa93656..831deae 100644 --- a/src/problem_state.jl +++ b/src/problem_state.jl @@ -3,7 +3,7 @@ function ProblemState(; u = nothing, p = nothing, t = nothing) A value provider struct which can be used as an argument to the function returned by -[`getu`](@ref) or [`setu`](@ref). It stores the state vector, parameter object and +[`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref), [`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained objects. diff --git a/src/remake.jl b/src/remake.jl index c1dbc47..21de62d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -46,7 +46,7 @@ function remake_buffer(sys, oldbuffer::AbstractArray, idxs, vals) else v = elT(v) end - setu(sys, k)(newbuffer, v) + setsym(sys, k)(newbuffer, v) end else mutbuffer = remake_buffer(sys, collect(oldbuffer), idxs, vals) @@ -80,7 +80,7 @@ end function remake_buffer(sys, oldbuffer::Tuple, idxs, vals) wrap = TupleRemakeWrapper(oldbuffer) for (idx, val) in zip(idxs, vals) - setu(sys, idx)(wrap, val) + setsym(sys, idx)(wrap, val) end return wrap.t end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 1945232..b434e4c 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -3,7 +3,7 @@ function set_state!(sys, val, idx) end """ - getu(indp, sym) + getsym(indp, sym) Return a function that takes a value provider and returns the value of the symbolic variable `sym`. If `sym` is not an observed quantity, the returned function can also @@ -25,10 +25,10 @@ If the value provider is a parameter timeseries object, the same rules apply as [`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols, and the values are always returned corresponding to the state timeseries. """ -function getu(sys, sym) +function getsym(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - _getu(sys, symtype, elsymtype, sym) + _getsym(sys, symtype, elsymtype, sym) end struct GetStateIndex{I} <: AbstractStateGetIndexer @@ -47,7 +47,7 @@ function (gsi::GetStateIndex)(::NotTimeseries, prob) state_values(prob, gsi.idx) end -function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) +function _getsym(sys, ::NotSymbolic, ::NotSymbolic, sym) return GetStateIndex(sym) end @@ -142,10 +142,10 @@ function (o::TimeIndependentObservedFunction)(::IsTimeseriesTrait, prob) return o.obsfn(state_values(prob), parameter_values(prob)) end -function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) +function _getsym(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return getu(sys, idx) + return getsym(sys, idx) elseif is_parameter(sys, sym) return getp(sys, sym) elseif is_independent_variable(sys, sym) @@ -168,7 +168,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) return getp(sys, sym) end end - error("Invalid symbol $sym for `getu`") + error("Invalid symbol $sym for `getsym`") end struct MultipleGetters{I, G} <: AbstractStateGetIndexer @@ -247,7 +247,7 @@ for (t1, t2) in [ (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] - @eval function _getu(sys, ::NotSymbolic, elt::$t1, sym::$t2) + @eval function _getsym(sys, ::NotSymbolic, elt::$t1, sym::$t2) if isempty(sym) return MultipleGetters(ContinuousTimeseries(), sym) end @@ -259,7 +259,7 @@ for (t1, t2) in [ end if !is_time_dependent(sys) if num_observed == 0 || num_observed == 1 && sym isa Tuple - return MultipleGetters(nothing, getu.((sys,), sym)) + return MultipleGetters(nothing, getsym.((sys,), sym)) else obs = observed(sys, sym_arr) getter = TimeIndependentObservedFunction(obs) @@ -280,7 +280,7 @@ for (t1, t2) in [ end if num_observed == 0 || num_observed == 1 && sym isa Tuple - getters = getu.((sys,), sym) + getters = getsym.((sys,), sym) return MultipleGetters(ts_idxs, getters) else obs = observed(sys, sym_arr) @@ -297,20 +297,20 @@ for (t1, t2) in [ end end -function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) +function _getsym(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return getu(sys, idx) + return getsym(sys, idx) elseif is_parameter(sys, sym) return getp(sys, sym) end - return getu(sys, collect(sym)) + return getsym(sys, collect(sym)) end -# setu doesn't need the same `let` blocks to be inferred for some reason +# setsym doesn't need the same `let` blocks to be inferred for some reason """ - setu(sys, sym) + setsym(sys, sym) Return a function that takes a value provider and a value, and sets the the state `sym` to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of the @@ -322,10 +322,10 @@ if this is not possible or additional actions need to be performed when updating [`set_state!`](@ref) can be defined. This function does not work on types for which [`is_timeseries`](@ref) is [`Timeseries`](@ref). """ -function setu(sys, sym) +function setsym(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - _setu(sys, symtype, elsymtype, sym) + _setsym(sys, symtype, elsymtype, sym) end struct SetStateIndex{I} <: AbstractSetIndexer @@ -336,18 +336,18 @@ function (ssi::SetStateIndex)(prob, val) set_state!(prob, val, ssi.idx) end -function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) +function _setsym(sys, ::NotSymbolic, ::NotSymbolic, sym) return SetStateIndex(sym) end -function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) +function _setsym(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) return SetStateIndex(idx) elseif is_parameter(sys, sym) return setp(sys, sym) end - error("Invalid symbol $sym for `setu`") + error("Invalid symbol $sym for `setsym`") end for (t1, t2) in [ @@ -355,13 +355,13 @@ for (t1, t2) in [ (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] - @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) - setters = setu.((sys,), sym) + @eval function _setsym(sys, ::NotSymbolic, ::$t1, sym::$t2) + setters = setsym.((sys,), sym) return MultipleSetters(setters) end end -function _setu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) +function _setsym(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) if idx isa AbstractArray @@ -372,5 +372,102 @@ function _setu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) elseif is_parameter(sys, sym) return setp(sys, sym) end - return setu(sys, collect(sym)) + return setsym(sys, collect(sym)) +end + +const getu = getsym +const setu = setsym + +""" + setsym_oop(indp, sym) + +Return a function which takes a value provider `valp` and a value `val`, and returns +`state_values(valp), parameter_values(valp)` with the states/parameters in `sym` set to the +corresponding values in `val`. This allows changing the types of values stored, and leverages +[`remake_buffer`](@ref). Note that `sym` can be an index, a symbolic variable, or an +array/tuple of the aforementioned. All entries `s` in `sym` must satisfy `is_variable(indp, s)` +or `is_parameter(indp, s)`. + +Requires that the value provider implement `state_values`, `parameter_values` and `remake_buffer`. +""" +function setsym_oop(indp, sym) + symtype = symbolic_type(sym) + elsymtype = symbolic_type(eltype(sym)) + return _setsym_oop(indp, symtype, elsymtype, sym) +end + +struct FullSetter{S, P, I, J} + state_setter::S + param_setter::P + state_split::I + param_split::J +end + +FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing) + +function (fs::FullSetter)(valp, val) + return fs.state_setter(valp, val[fs.state_split]), + fs.param_setter(valp, val[fs.param_split]) +end + +function (fs::FullSetter{Nothing})(valp, val) + return state_values(valp), fs.param_setter(valp, val) +end + +function (fs::(FullSetter{S, Nothing} where {S}))(valp, val) + return fs.state_setter(valp, val), parameter_values(valp) +end + +function (fs::(FullSetter{Nothing, Nothing}))(valp, val) + return state_values(valp), parameter_values(valp) +end + +function _setsym_oop(indp, ::NotSymbolic, ::NotSymbolic, sym) + return FullSetter(OOPSetter(_root_indp(indp), sym, true), nothing) +end + +function _setsym_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) + if (idx = variable_index(indp, sym)) !== nothing + return FullSetter(OOPSetter(_root_indp(indp), idx, true), nothing) + elseif (idx = parameter_index(indp, sym)) !== nothing + return FullSetter(nothing, OOPSetter(_root_indp(indp), idx, false)) + end + throw(NotVariableOrParameter("setsym_oop", sym)) +end + +for (t1, t2) in [ + (ScalarSymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}) +] + @eval function _setsym_oop(indp, ::NotSymbolic, ::$t1, sym::$t2) + vars = [] + state_split = eltype(eachindex(sym))[] + pars = [] + param_split = eltype(eachindex(sym))[] + for (i, s) in enumerate(sym) + if (idx = variable_index(indp, s)) !== nothing + push!(vars, idx) + push!(state_split, i) + elseif (idx = parameter_index(indp, s)) !== nothing + push!(pars, idx) + push!(param_split, i) + else + throw(NotVariableOrParameter("setsym_oop", s)) + end + end + indp = _root_indp(indp) + return FullSetter(isempty(vars) ? nothing : OOPSetter(indp, identity.(vars), true), + isempty(pars) ? nothing : OOPSetter(indp, identity.(pars), false), + state_split, param_split) + end +end + +function _setsym_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym) + if (idx = variable_index(indp, sym)) !== nothing + return setsym_oop(indp, idx) + elseif (idx = parameter_index(indp, sym)) !== nothing + return FullSetter( + nothing, OOPSetter(indp, idx isa AbstractArray ? idx : (idx,), false)) + end + return setsym_oop(indp, collect(sym)) end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 1ae30b5..cd52ae1 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -110,7 +110,7 @@ state_values(arr, ::Colon) = state_values(arr) Set the state at index `idx` to `val` for value provider `valp`. This defaults to modifying `state_values(valp)`. If any additional bookkeeping needs to be performed or the default implementation does not work for a particular type, this method needs to be -defined to enable the proper functioning of [`setu`](@ref). +defined to enable the proper functioning of [`setsym`](@ref). See: [`state_values`](@ref) """ @@ -231,6 +231,35 @@ function (fn::Fix1Multiple)(args...) fn.f(fn.arg, args...) end +struct OOPSetter{I, D} + indp::I + idxs::D + is_state::Bool +end + +function (os::OOPSetter)(valp, val) + buffer = os.is_state ? state_values(valp) : parameter_values(valp) + return remake_buffer(os.indp, buffer, (os.idxs,), (val,)) +end + +function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray}) + buffer = os.is_state ? state_values(valp) : parameter_values(valp) + if os.idxs isa Union{Tuple, AbstractArray} + return remake_buffer(os.indp, buffer, os.idxs, val) + else + return remake_buffer(os.indp, buffer, (os.idxs,), (val,)) + end +end + +function _root_indp(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && + (sc = symbolic_container(indp)) != indp + return _root_indp(sc) + else + return indp + end +end + ########### # Errors ########### @@ -296,3 +325,16 @@ function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError) indexes $(err.ts_idxs). """) end + +struct NotVariableOrParameter <: Exception + fn::Any + sym::Any +end + +function Base.showerror(io::IO, err::NotVariableOrParameter) + print( + io, """ + `$(err.fn)` requires that the symbolic variable(s) passed to it satisfy `is_variable` + or `is_parameter`. Got `$(err.sym)` which is neither. + """) +end diff --git a/test/batched_interface_test.jl b/test/batched_interface_test.jl index 13dd415..533ef61 100644 --- a/test/batched_interface_test.jl +++ b/test/batched_interface_test.jl @@ -28,13 +28,13 @@ bi = BatchedInterface(zip(syss, syms)...) Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0] @test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3] -getter = getu(bi) +getter = getsym(bi) @test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] buf = zeros(8) @inferred getter(buf, probs...) @test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] -setter! = setu(bi) +setter! = setsym(bi) buf .*= 100 setter!(probs..., buf) diff --git a/test/downstream/batchedinterface_arrayvars.jl b/test/downstream/batchedinterface_arrayvars.jl index eaba28c..f07edc5 100644 --- a/test/downstream/batchedinterface_arrayvars.jl +++ b/test/downstream/batchedinterface_arrayvars.jl @@ -23,13 +23,13 @@ bi = BatchedInterface(zip(syss, syms)...) @test is_variable.((bi,), [x..., y, z]) == Bool[1, 1, 1, 0] @test associated_systems(bi) == [1, 1, 1] -getter = getu(bi) +getter = getsym(bi) @test (@inferred getter(probs...)) == [1.0, 2.0, 3.0] buf = zeros(3) @inferred getter(buf, probs...) @test buf == [1.0, 2.0, 3.0] -setter! = setu(bi) +setter! = setsym(bi) buf .*= 10 setter!(probs..., buf) @@ -58,7 +58,7 @@ probs = [ bi = BatchedInterface(zip(syss, syms)...) -buf = getu(bi)(probs...) +buf = getsym(bi)(probs...) buf .*= 100 setter = setsym_oop(bi) vals = setter(probs..., buf) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 6c864c4..26eaaf5 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -488,7 +488,7 @@ for (sym, val_is_timeseries, val, check_inference) in [ ([:a, :(2b)], true, vcat.(aval, 2 .* bval), true), ((:a, :(2b)), true, tuple.(aval, 2 .* bval), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) if check_inference @inferred getter(fs) end @@ -524,7 +524,7 @@ for (sym, val, check_inference) in [ ([:(2b), :(3x)], [2_bval, 3_xval], true), ((:(2b), :(3x)), (2_bval, 3_xval), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) @test_throws MixedParameterTimeseriesIndexError getter(fs) for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2] @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) diff --git a/test/problem_state_test.jl b/test/problem_state_test.jl index d060925..906ee73 100644 --- a/test/problem_state_test.jl +++ b/test/problem_state_test.jl @@ -5,11 +5,11 @@ sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5) for (i, sym) in enumerate(variable_symbols(sys)) - @test getu(sys, sym)(prob) == prob.u[i] + @test getsym(sys, sym)(prob) == prob.u[i] end for (i, sym) in enumerate(parameter_symbols(sys)) @test getp(sys, sym)(prob) == prob.p[i] end -@test getu(sys, :t)(prob) == prob.t +@test getsym(sys, :t)(prob) == prob.t -@test getu(sys, :(x + a + t))(prob) == 1.6 +@test getsym(sys, :(x + a + t))(prob) == 1.6 diff --git a/test/simple_adjoints_test.jl b/test/simple_adjoints_test.jl index 329fd10..c3fc2c1 100644 --- a/test/simple_adjoints_test.jl +++ b/test/simple_adjoints_test.jl @@ -4,14 +4,14 @@ using Zygote sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) pstate = ProblemState(; u = rand(3), p = rand(3), t = rand()) -getter = getu(sys, :x) +getter = getsym(sys, :x) @test Zygote.gradient(getter, pstate)[1].u == [1.0, 0.0, 0.0] -getter = getu(sys, [:x, :z]) +getter = getsym(sys, [:x, :z]) @test Zygote.gradient(sum ∘ getter, pstate)[1].u == [1.0, 0.0, 1.0] -getter = getu(sys, :a) +getter = getsym(sys, :a) @test Zygote.gradient(getter, pstate)[1].p == [1.0, 0.0, 0.0] -getter = getu(sys, [:a, :c]) +getter = getsym(sys, [:a, :c]) @test Zygote.gradient(sum ∘ getter, pstate)[1].p == [1.0, 0.0, 1.0] diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 8e59779..622e5b1 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -1,4 +1,5 @@ using SymbolicIndexingInterface +using SymbolicIndexingInterface: NotVariableOrParameter struct FakeIntegrator{S, U, P, T} sys::S @@ -14,8 +15,8 @@ SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -@test_throws ErrorException getu(sys, :q) -@test_throws ErrorException setu(sys, :q) +@test_throws ErrorException getsym(sys, :q) +@test_throws ErrorException setsym(sys, :q) u = [1.0, 2.0, 3.0] p = [11.0, 12.0, 13.0] @@ -50,8 +51,8 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) (4.0, (5.0, 6.0)), true) ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)] - get = getu(sys, sym) - set! = setu(sys, sym) + get = getsym(sys, sym) + set! = setsym(sys, sym) if check_inference @inferred get(fi) end @@ -62,6 +63,9 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) set!(fi, newval) end @test get(fi) == newval + + new_states = copy(state_values(fi)) + set!(fi, val) @test get(fi) == val @@ -77,6 +81,15 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) @test get(u) == newval set!(u, val) @test get(u) == val + + if sym isa Union{Vector, Tuple} && any(x -> x isa Union{AbstractArray, Tuple}, sym) + continue + end + + setter = setsym_oop(sys, sym) + svals, pvals = setter(fi, newval) + @test svals ≈ new_states + @test pvals == parameter_values(fi) end for (sym, val, check_inference) in [ @@ -84,7 +97,7 @@ for (sym, val, check_inference) in [ ([:(x + y), :z], [u[1] + u[2], u[3]], false), ((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false) ] - get = getu(sys, sym) + get = getsym(sys, sym) if check_inference @inferred get(fi) end @@ -92,15 +105,15 @@ for (sym, val, check_inference) in [ end let fi = fi, sys = sys - getter = getu(sys, []) + getter = getsym(sys, []) @test getter(fi) == [] - getter = getu(sys, ()) + getter = getsym(sys, ()) @test getter(fi) == () sc = SymbolCache(nothing, [:a, :b], :t) fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0) - getter = getu(sc, []) + getter = getsym(sc, []) @test getter(fi) == [] - getter = getu(sc, ()) + getter = getsym(sc, ()) @test getter(fi) == () end @@ -111,8 +124,8 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] - get = getu(fi, sym) - set! = setu(fi, sym) + get = getsym(fi, sym) + set! = setsym(fi, sym) if check_inference @inferred get(fi) end @@ -123,8 +136,17 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) set!(fi, newval) end @test get(fi) == newval + + newu = copy(state_values(fi)) + newp = copy(parameter_values(fi)) + set!(fi, oldval) @test get(fi) == oldval + + oop_setter = setsym_oop(sys, sym) + uvals, pvals = oop_setter(fi, newval) + @test uvals ≈ newu + @test pvals ≈ newp end for (sym, val, check_inference) in [ @@ -132,11 +154,13 @@ for (sym, val, check_inference) in [ ([:x, :a, :t], [u[1], p[1], t], false), ((:x, :a, :t), (u[1], p[1], t), false) ] - get = getu(fi, sym) + get = getsym(fi, sym) if check_inference @inferred get(fi) end @test get(fi) == val + + @test_throws NotVariableOrParameter setsym_oop(fi, sym) end struct FakeSolution{S, U, P, T} @@ -202,7 +226,7 @@ for (sym, ans, check_inference) in [(:x, xvals, true) (:t, t, true) ([:x, :a, :t], vcat.(xvals, p[1], t), false) ((:x, :a, :t), tuple.(xvals, p[1], t), true)] - get = getu(sys, sym) + get = getsym(sys, sym) if check_inference @inferred get(sol) end @@ -221,7 +245,7 @@ for (sym, val, check_inference) in [ ([:(x + y), :z], vcat.(xvals .+ yvals, zvals), false), ((:(x + y), :(z + y)), tuple.(xvals .+ yvals, yvals .+ zvals), false) ] - get = getu(sys, sym) + get = getsym(sys, sym) if check_inference @inferred get(sol) end @@ -240,21 +264,21 @@ for (sym, val) in [(:a, p[1]) (:c, p[3]) ([:a, :b], p[1:2]) ((:c, :b), (p[3], p[2]))] - get = getu(sys, sym) + get = getsym(sys, sym) @inferred get(sol) @test get(sol) == val end let sol = sol, sys = sys - getter = getu(sys, []) + getter = getsym(sys, []) @test getter(sol) == [[] for _ in 1:length(sol.t)] - getter = getu(sys, ()) + getter = getsym(sys, ()) @test getter(sol) == [() for _ in 1:length(sol.t)] sc = SymbolCache(nothing, [:a, :b], :t) sol = FakeSolution(sys, [], [1.0, 2.0], []) - getter = getu(sc, []) + getter = getsym(sc, []) @test getter(sol) == [] - getter = getu(sc, ()) + getter = getsym(sc, ()) @test getter(sol) == [] end @@ -285,7 +309,7 @@ for (sym, val, check_inference) in [ ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) if check_inference @inferred getter(fs) end @@ -318,7 +342,7 @@ ts = 0.0:0.1:1.0 fi = FakeIntegrator(sys, u0, p, ts[1]) fs = FakeSolution(sys, u, p, ts) -getter = getu(sys, :(x + y)) +getter = getsym(sys, :(x + y)) @test getter(fi) ≈ 2.8 @test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] @test getter(fs, 1) ≈ 2.8