From 3344f9f1c393bc2477a403026853975b291ee77c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 29 Oct 2024 16:43:09 +0530 Subject: [PATCH] refactor: rename `getu`/`setu` to `getsym`/`setsym` --- docs/src/api.md | 11 ++-- docs/src/complete_sii.md | 28 +++++----- docs/src/usage.md | 6 +-- src/SymbolicIndexingInterface.jl | 2 +- src/batched_interface.jl | 16 +++--- src/problem_state.jl | 2 +- src/remake.jl | 4 +- src/state_indexing.jl | 51 ++++++++++--------- src/value_provider_interface.jl | 2 +- test/batched_interface_test.jl | 4 +- test/downstream/batchedinterface_arrayvars.jl | 6 +-- test/parameter_indexing_test.jl | 4 +- test/problem_state_test.jl | 6 +-- test/simple_adjoints_test.jl | 8 +-- test/state_indexing_test.jl | 42 +++++++-------- 15 files changed, 99 insertions(+), 93 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index cf767b9a..29150c0a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -46,7 +46,7 @@ is_markovian If the index provider contains parameters that change during the course of the simulation at discrete time points, it must implement the following methods to ensure correct -functioning of [`getu`](@ref) and [`getp`](@ref) for value providers that save the parameter +functioning of [`getsym`](@ref) and [`getp`](@ref) for value providers that save the parameter timeseries. Note that there can be multiple parameter timeseries, in case different parameters may change at different times. @@ -69,10 +69,13 @@ is_timeseries state_values set_state! current_time -getu -setu +getsym +setsym ``` +!!! note + `getu` and `setu` have been renamed to [`getsym`](@ref) and [`setsym`](@ref) respectively. + #### Historical value providers ```@docs @@ -95,7 +98,7 @@ ParameterIndexingProxy If a solution object saves a timeseries of parameter values that are updated during the simulation (such as by callbacks), it must implement the following methods to ensure -correct functioning of [`getu`](@ref) and [`getp`](@ref). +correct functioning of [`getsym`](@ref) and [`getp`](@ref). Parameter timeseries support requires that the value provider store the different timeseries in a [`ParameterTimeseriesCollection`](@ref). diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index ae55e982..c75cdc80 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -174,9 +174,9 @@ end ``` If a type contains the value of state variables, it can define [`state_values`](@ref) to -enable the usage of [`getu`](@ref) and [`setu`](@ref). These methods retturn getter/ +enable the usage of [`getsym`](@ref) and [`setsym`](@ref). These methods retturn getter/ setter functions to access or update the value of a state variable (or a collection of -them). If the type also supports generating [`observed`](@ref) functions, `getu` also +them). If the type also supports generating [`observed`](@ref) functions, `getsym` also enables returning functions to access the value of arbitrary expressions involving the system's symbols. This also requires that the type implement [`parameter_values`](@ref) and [`current_time`](@ref) (if the system is time-dependent). @@ -202,13 +202,13 @@ Then the following example would work: ```julia sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict()) integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys) -getx = getu(sys, :x) +getx = getsym(sys, :x) getx(integrator) # 1.0 -get_expr = getu(sys, :(x + y + t)) +get_expr = getsym(sys, :(x + y + t)) get_expr(integrator) # 13.0 -setx! = setu(sys, :y) +setx! = setsym(sys, :y) setx!(integrator, 0.0) getx(integrator) # 0.0 ``` @@ -216,7 +216,7 @@ getx(integrator) # 0.0 In case a type stores timeseries data (such as solutions), then it must also implement the [`Timeseries`](@ref) trait. The type would then return a timeseries from [`state_values`](@ref) and [`current_time`](@ref) and the function returned from -[`getu`](@ref) would then return a timeseries as well. For example, consider the +[`getsym`](@ref) would then return a timeseries as well. For example, consider the `ExampleSolution` below: ```julia @@ -242,20 +242,20 @@ Then the following example would work: ```julia # using the same system that the ExampleIntegrator used sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys) -getx = getu(sys, :x) +getx = getsym(sys, :x) getx(sol) # [1.0, 1.5] -get_expr = getu(sys, :(x + y + t)) +get_expr = getsym(sys, :(x + y + t)) get_expr(sol) # [9.0, 11.0] -get_arr = getu(sys, [:y, :(x + a)]) +get_arr = getsym(sys, [:y, :(x + a)]) get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]] -get_tuple = getu(sys, (:z, :(z * t))) +get_tuple = getsym(sys, (:z, :(z * t))) get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)] ``` -Note that `setu` is not designed to work for `Timeseries` objects. +Note that `setsym` is not designed to work for `Timeseries` objects. If a type needs to perform some additional actions when updating the state/parameters or if it is not possible to return a mutable reference to the state/parameter vector @@ -315,7 +315,7 @@ setp(integrator, :b)(integrator, 3.0) # functionally the same as above If a solution object includes modified parameter values (such as through callbacks) during the simulation, it must implement several additional functions for correct functioning of -[`getu`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in +[`getsym`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in implementing parameter timeseries objects. The following mockup gives an example of correct implementation of these functions and the indexing syntax they enable. @@ -457,12 +457,12 @@ sol.ps[:(b + c)] # observed quantities work too ``` ```@example param_timeseries -getu(sol, :b)(sol) # works +getsym(sol, :b)(sol) # works ``` ```@example param_timeseries try - getu(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries + getsym(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries catch e showerror(stdout, e) end diff --git a/docs/src/usage.md b/docs/src/usage.md index 693256d4..dfeac575 100644 --- a/docs/src/usage.md +++ b/docs/src/usage.md @@ -123,11 +123,11 @@ sol[allvariables] # equivalent to sol[all_variable_symbols(sol)] ### Evaluating expressions -`getu` also generates functions for expressions if the object passed to it supports +`getsym` also generates functions for expressions if the object passed to it supports [`observed`](@ref). For example: ```@example Usage -getu(prob, x + y + z)(prob) +getsym(prob, x + y + z)(prob) ``` To evaluate this function using values other than the ones contained in `prob`, we need @@ -137,7 +137,7 @@ which has trivial implementations of the above functions. We can thus do: ```@example Usage temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob)) -getu(prob, x + y + z)(temp_state) +getsym(prob, x + y + z)(temp_state) ``` Note that providing all of the state vector, parameter object and time may not be diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7b2f3ccd..f444fb47 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -35,7 +35,7 @@ include("parameter_timeseries_collection.jl") export getp, setp, setp_oop include("parameter_indexing.jl") -export getu, setu +export getsym, setsym, getu, setu include("state_indexing.jl") export BatchedInterface, setsym_oop, associated_systems diff --git a/src/batched_interface.jl b/src/batched_interface.jl index c8d02d8e..ad8273cd 100644 --- a/src/batched_interface.jl +++ b/src/batched_interface.jl @@ -2,7 +2,7 @@ struct BatchedInterface{S <: AbstractVector, I} function BatchedInterface(indp_syms::Tuple...) -A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref). +A struct which stores information for batched calls to [`getsym`](@ref) or [`setsym`](@ref). Given `Tuple`s, where the first element of each tuple is an index provider and the second an array of symbolic variables (either states or parameters) in the index provider, `BatchedInterface` will compute the union of all symbols and associate each symbol with @@ -17,7 +17,7 @@ be retained internally. `BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref), [`variable_index`](@ref) to query the order of symbols in the union. -See [`getu`](@ref) and [`setu`](@ref) for further details. +See [`getsym`](@ref) and [`setsym`](@ref) for further details. See also: [`associated_systems`](@ref). """ @@ -118,7 +118,7 @@ is the index of the index provider associated with the corresponding symbol in associated_systems(bi::BatchedInterface) = bi.associated_systems """ - getu(bi::BatchedInterface) + getsym(bi::BatchedInterface) Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding symbols), return a function which takes `n` corresponding value providers and returns an @@ -127,7 +127,7 @@ an `AbstractArray` of the appropriate `eltype` and size as its first argument, i case the operation will populate the array in-place with the values of the symbols in the union. -Note that all of the value providers passed to the function returned by `getu` must satisfy +Note that all of the value providers passed to the function returned by `getsym` must satisfy `is_timeseries(prob) === NotTimeseries()`. The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is @@ -137,7 +137,7 @@ provider at index `associated_systems(bi)[i]`). See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref), [`NotTimeseries`](@ref). """ -function getu(bi::BatchedInterface) +function getsym(bi::BatchedInterface) numprobs = length(bi.system_to_symbol_subset) probnames = [Symbol(:prob, i) for i in 1:numprobs] @@ -189,13 +189,13 @@ function getu(bi::BatchedInterface) end """ - setu(bi::BatchedInterface) + setsym(bi::BatchedInterface) Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding symbols), return a function which takes `n` corresponding problems and an array of the values, and updates each of the problems with the values of the corresponding symbols. -Note that all of the value providers passed to the function returned by `setu` must satisfy +Note that all of the value providers passed to the function returned by `setsym` must satisfy `is_timeseries(prob) === NotTimeseries()`. Note that if any subset of the `n` index providers share common symbols (among those passed @@ -204,7 +204,7 @@ updated with the values of the common symbols. See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). """ -function setu(bi::BatchedInterface) +function setsym(bi::BatchedInterface) numprobs = length(bi.system_to_symbol_subset) probnames = [Symbol(:prob, i) for i in 1:numprobs] diff --git a/src/problem_state.jl b/src/problem_state.jl index fa936564..831deae2 100644 --- a/src/problem_state.jl +++ b/src/problem_state.jl @@ -3,7 +3,7 @@ function ProblemState(; u = nothing, p = nothing, t = nothing) A value provider struct which can be used as an argument to the function returned by -[`getu`](@ref) or [`setu`](@ref). It stores the state vector, parameter object and +[`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref), [`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained objects. diff --git a/src/remake.jl b/src/remake.jl index c1dbc477..21de62da 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -46,7 +46,7 @@ function remake_buffer(sys, oldbuffer::AbstractArray, idxs, vals) else v = elT(v) end - setu(sys, k)(newbuffer, v) + setsym(sys, k)(newbuffer, v) end else mutbuffer = remake_buffer(sys, collect(oldbuffer), idxs, vals) @@ -80,7 +80,7 @@ end function remake_buffer(sys, oldbuffer::Tuple, idxs, vals) wrap = TupleRemakeWrapper(oldbuffer) for (idx, val) in zip(idxs, vals) - setu(sys, idx)(wrap, val) + setsym(sys, idx)(wrap, val) end return wrap.t end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 19452321..6745b522 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -3,7 +3,7 @@ function set_state!(sys, val, idx) end """ - getu(indp, sym) + getsym(indp, sym) Return a function that takes a value provider and returns the value of the symbolic variable `sym`. If `sym` is not an observed quantity, the returned function can also @@ -25,10 +25,10 @@ If the value provider is a parameter timeseries object, the same rules apply as [`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols, and the values are always returned corresponding to the state timeseries. """ -function getu(sys, sym) +function getsym(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - _getu(sys, symtype, elsymtype, sym) + _getsym(sys, symtype, elsymtype, sym) end struct GetStateIndex{I} <: AbstractStateGetIndexer @@ -47,7 +47,7 @@ function (gsi::GetStateIndex)(::NotTimeseries, prob) state_values(prob, gsi.idx) end -function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) +function _getsym(sys, ::NotSymbolic, ::NotSymbolic, sym) return GetStateIndex(sym) end @@ -142,10 +142,10 @@ function (o::TimeIndependentObservedFunction)(::IsTimeseriesTrait, prob) return o.obsfn(state_values(prob), parameter_values(prob)) end -function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) +function _getsym(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return getu(sys, idx) + return getsym(sys, idx) elseif is_parameter(sys, sym) return getp(sys, sym) elseif is_independent_variable(sys, sym) @@ -168,7 +168,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) return getp(sys, sym) end end - error("Invalid symbol $sym for `getu`") + error("Invalid symbol $sym for `getsym`") end struct MultipleGetters{I, G} <: AbstractStateGetIndexer @@ -247,7 +247,7 @@ for (t1, t2) in [ (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] - @eval function _getu(sys, ::NotSymbolic, elt::$t1, sym::$t2) + @eval function _getsym(sys, ::NotSymbolic, elt::$t1, sym::$t2) if isempty(sym) return MultipleGetters(ContinuousTimeseries(), sym) end @@ -259,7 +259,7 @@ for (t1, t2) in [ end if !is_time_dependent(sys) if num_observed == 0 || num_observed == 1 && sym isa Tuple - return MultipleGetters(nothing, getu.((sys,), sym)) + return MultipleGetters(nothing, getsym.((sys,), sym)) else obs = observed(sys, sym_arr) getter = TimeIndependentObservedFunction(obs) @@ -280,7 +280,7 @@ for (t1, t2) in [ end if num_observed == 0 || num_observed == 1 && sym isa Tuple - getters = getu.((sys,), sym) + getters = getsym.((sys,), sym) return MultipleGetters(ts_idxs, getters) else obs = observed(sys, sym_arr) @@ -297,20 +297,20 @@ for (t1, t2) in [ end end -function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) +function _getsym(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return getu(sys, idx) + return getsym(sys, idx) elseif is_parameter(sys, sym) return getp(sys, sym) end - return getu(sys, collect(sym)) + return getsym(sys, collect(sym)) end -# setu doesn't need the same `let` blocks to be inferred for some reason +# setsym doesn't need the same `let` blocks to be inferred for some reason """ - setu(sys, sym) + setsym(sys, sym) Return a function that takes a value provider and a value, and sets the the state `sym` to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of the @@ -322,10 +322,10 @@ if this is not possible or additional actions need to be performed when updating [`set_state!`](@ref) can be defined. This function does not work on types for which [`is_timeseries`](@ref) is [`Timeseries`](@ref). """ -function setu(sys, sym) +function setsym(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - _setu(sys, symtype, elsymtype, sym) + _setsym(sys, symtype, elsymtype, sym) end struct SetStateIndex{I} <: AbstractSetIndexer @@ -336,18 +336,18 @@ function (ssi::SetStateIndex)(prob, val) set_state!(prob, val, ssi.idx) end -function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) +function _setsym(sys, ::NotSymbolic, ::NotSymbolic, sym) return SetStateIndex(sym) end -function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) +function _setsym(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) return SetStateIndex(idx) elseif is_parameter(sys, sym) return setp(sys, sym) end - error("Invalid symbol $sym for `setu`") + error("Invalid symbol $sym for `setsym`") end for (t1, t2) in [ @@ -355,13 +355,13 @@ for (t1, t2) in [ (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] - @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) - setters = setu.((sys,), sym) + @eval function _setsym(sys, ::NotSymbolic, ::$t1, sym::$t2) + setters = setsym.((sys,), sym) return MultipleSetters(setters) end end -function _setu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) +function _setsym(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) if idx isa AbstractArray @@ -372,5 +372,8 @@ function _setu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) elseif is_parameter(sys, sym) return setp(sys, sym) end - return setu(sys, collect(sym)) + return setsym(sys, collect(sym)) end + +const getu = getsym +const setu = setsym diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 1ae30b50..b49700d6 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -110,7 +110,7 @@ state_values(arr, ::Colon) = state_values(arr) Set the state at index `idx` to `val` for value provider `valp`. This defaults to modifying `state_values(valp)`. If any additional bookkeeping needs to be performed or the default implementation does not work for a particular type, this method needs to be -defined to enable the proper functioning of [`setu`](@ref). +defined to enable the proper functioning of [`setsym`](@ref). See: [`state_values`](@ref) """ diff --git a/test/batched_interface_test.jl b/test/batched_interface_test.jl index 13dd415c..533ef617 100644 --- a/test/batched_interface_test.jl +++ b/test/batched_interface_test.jl @@ -28,13 +28,13 @@ bi = BatchedInterface(zip(syss, syms)...) Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0] @test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3] -getter = getu(bi) +getter = getsym(bi) @test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] buf = zeros(8) @inferred getter(buf, probs...) @test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] -setter! = setu(bi) +setter! = setsym(bi) buf .*= 100 setter!(probs..., buf) diff --git a/test/downstream/batchedinterface_arrayvars.jl b/test/downstream/batchedinterface_arrayvars.jl index eaba28c1..f07edc51 100644 --- a/test/downstream/batchedinterface_arrayvars.jl +++ b/test/downstream/batchedinterface_arrayvars.jl @@ -23,13 +23,13 @@ bi = BatchedInterface(zip(syss, syms)...) @test is_variable.((bi,), [x..., y, z]) == Bool[1, 1, 1, 0] @test associated_systems(bi) == [1, 1, 1] -getter = getu(bi) +getter = getsym(bi) @test (@inferred getter(probs...)) == [1.0, 2.0, 3.0] buf = zeros(3) @inferred getter(buf, probs...) @test buf == [1.0, 2.0, 3.0] -setter! = setu(bi) +setter! = setsym(bi) buf .*= 10 setter!(probs..., buf) @@ -58,7 +58,7 @@ probs = [ bi = BatchedInterface(zip(syss, syms)...) -buf = getu(bi)(probs...) +buf = getsym(bi)(probs...) buf .*= 100 setter = setsym_oop(bi) vals = setter(probs..., buf) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 6c864c44..26eaaf51 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -488,7 +488,7 @@ for (sym, val_is_timeseries, val, check_inference) in [ ([:a, :(2b)], true, vcat.(aval, 2 .* bval), true), ((:a, :(2b)), true, tuple.(aval, 2 .* bval), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) if check_inference @inferred getter(fs) end @@ -524,7 +524,7 @@ for (sym, val, check_inference) in [ ([:(2b), :(3x)], [2_bval, 3_xval], true), ((:(2b), :(3x)), (2_bval, 3_xval), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) @test_throws MixedParameterTimeseriesIndexError getter(fs) for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2] @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) diff --git a/test/problem_state_test.jl b/test/problem_state_test.jl index d0609251..906ee735 100644 --- a/test/problem_state_test.jl +++ b/test/problem_state_test.jl @@ -5,11 +5,11 @@ sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5) for (i, sym) in enumerate(variable_symbols(sys)) - @test getu(sys, sym)(prob) == prob.u[i] + @test getsym(sys, sym)(prob) == prob.u[i] end for (i, sym) in enumerate(parameter_symbols(sys)) @test getp(sys, sym)(prob) == prob.p[i] end -@test getu(sys, :t)(prob) == prob.t +@test getsym(sys, :t)(prob) == prob.t -@test getu(sys, :(x + a + t))(prob) == 1.6 +@test getsym(sys, :(x + a + t))(prob) == 1.6 diff --git a/test/simple_adjoints_test.jl b/test/simple_adjoints_test.jl index 329fd104..c3fc2c1b 100644 --- a/test/simple_adjoints_test.jl +++ b/test/simple_adjoints_test.jl @@ -4,14 +4,14 @@ using Zygote sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) pstate = ProblemState(; u = rand(3), p = rand(3), t = rand()) -getter = getu(sys, :x) +getter = getsym(sys, :x) @test Zygote.gradient(getter, pstate)[1].u == [1.0, 0.0, 0.0] -getter = getu(sys, [:x, :z]) +getter = getsym(sys, [:x, :z]) @test Zygote.gradient(sum ∘ getter, pstate)[1].u == [1.0, 0.0, 1.0] -getter = getu(sys, :a) +getter = getsym(sys, :a) @test Zygote.gradient(getter, pstate)[1].p == [1.0, 0.0, 0.0] -getter = getu(sys, [:a, :c]) +getter = getsym(sys, [:a, :c]) @test Zygote.gradient(sum ∘ getter, pstate)[1].p == [1.0, 0.0, 1.0] diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 8e59779a..df655408 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -14,8 +14,8 @@ SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -@test_throws ErrorException getu(sys, :q) -@test_throws ErrorException setu(sys, :q) +@test_throws ErrorException getsym(sys, :q) +@test_throws ErrorException setsym(sys, :q) u = [1.0, 2.0, 3.0] p = [11.0, 12.0, 13.0] @@ -50,8 +50,8 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) (4.0, (5.0, 6.0)), true) ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)] - get = getu(sys, sym) - set! = setu(sys, sym) + get = getsym(sys, sym) + set! = setsym(sys, sym) if check_inference @inferred get(fi) end @@ -84,7 +84,7 @@ for (sym, val, check_inference) in [ ([:(x + y), :z], [u[1] + u[2], u[3]], false), ((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false) ] - get = getu(sys, sym) + get = getsym(sys, sym) if check_inference @inferred get(fi) end @@ -92,15 +92,15 @@ for (sym, val, check_inference) in [ end let fi = fi, sys = sys - getter = getu(sys, []) + getter = getsym(sys, []) @test getter(fi) == [] - getter = getu(sys, ()) + getter = getsym(sys, ()) @test getter(fi) == () sc = SymbolCache(nothing, [:a, :b], :t) fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0) - getter = getu(sc, []) + getter = getsym(sc, []) @test getter(fi) == [] - getter = getu(sc, ()) + getter = getsym(sc, ()) @test getter(fi) == () end @@ -111,8 +111,8 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] - get = getu(fi, sym) - set! = setu(fi, sym) + get = getsym(fi, sym) + set! = setsym(fi, sym) if check_inference @inferred get(fi) end @@ -132,7 +132,7 @@ for (sym, val, check_inference) in [ ([:x, :a, :t], [u[1], p[1], t], false), ((:x, :a, :t), (u[1], p[1], t), false) ] - get = getu(fi, sym) + get = getsym(fi, sym) if check_inference @inferred get(fi) end @@ -202,7 +202,7 @@ for (sym, ans, check_inference) in [(:x, xvals, true) (:t, t, true) ([:x, :a, :t], vcat.(xvals, p[1], t), false) ((:x, :a, :t), tuple.(xvals, p[1], t), true)] - get = getu(sys, sym) + get = getsym(sys, sym) if check_inference @inferred get(sol) end @@ -221,7 +221,7 @@ for (sym, val, check_inference) in [ ([:(x + y), :z], vcat.(xvals .+ yvals, zvals), false), ((:(x + y), :(z + y)), tuple.(xvals .+ yvals, yvals .+ zvals), false) ] - get = getu(sys, sym) + get = getsym(sys, sym) if check_inference @inferred get(sol) end @@ -240,21 +240,21 @@ for (sym, val) in [(:a, p[1]) (:c, p[3]) ([:a, :b], p[1:2]) ((:c, :b), (p[3], p[2]))] - get = getu(sys, sym) + get = getsym(sys, sym) @inferred get(sol) @test get(sol) == val end let sol = sol, sys = sys - getter = getu(sys, []) + getter = getsym(sys, []) @test getter(sol) == [[] for _ in 1:length(sol.t)] - getter = getu(sys, ()) + getter = getsym(sys, ()) @test getter(sol) == [() for _ in 1:length(sol.t)] sc = SymbolCache(nothing, [:a, :b], :t) sol = FakeSolution(sys, [], [1.0, 2.0], []) - getter = getu(sc, []) + getter = getsym(sc, []) @test getter(sol) == [] - getter = getu(sc, ()) + getter = getsym(sc, ()) @test getter(sol) == [] end @@ -285,7 +285,7 @@ for (sym, val, check_inference) in [ ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) if check_inference @inferred getter(fs) end @@ -318,7 +318,7 @@ ts = 0.0:0.1:1.0 fi = FakeIntegrator(sys, u0, p, ts[1]) fs = FakeSolution(sys, u, p, ts) -getter = getu(sys, :(x + y)) +getter = getsym(sys, :(x + y)) @test getter(fi) ≈ 2.8 @test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] @test getter(fs, 1) ≈ 2.8