Skip to content

Commit

Permalink
Merge pull request #106 from SciML/as/setsym-oop
Browse files Browse the repository at this point in the history
feat: add `setsym_oop`
  • Loading branch information
ChrisRackauckas authored Oct 29, 2024
2 parents 7b37511 + e8934c4 commit 33cef2f
Show file tree
Hide file tree
Showing 16 changed files with 263 additions and 123 deletions.
11 changes: 7 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ is_markovian

If the index provider contains parameters that change during the course of the simulation
at discrete time points, it must implement the following methods to ensure correct
functioning of [`getu`](@ref) and [`getp`](@ref) for value providers that save the parameter
functioning of [`getsym`](@ref) and [`getp`](@ref) for value providers that save the parameter
timeseries. Note that there can be multiple parameter timeseries, in case different parameters
may change at different times.

Expand All @@ -69,10 +69,13 @@ is_timeseries
state_values
set_state!
current_time
getu
setu
getsym
setsym
```

!!! note
`getu` and `setu` have been renamed to [`getsym`](@ref) and [`setsym`](@ref) respectively.

#### Historical value providers

```@docs
Expand All @@ -95,7 +98,7 @@ ParameterIndexingProxy

If a solution object saves a timeseries of parameter values that are updated during the
simulation (such as by callbacks), it must implement the following methods to ensure
correct functioning of [`getu`](@ref) and [`getp`](@ref).
correct functioning of [`getsym`](@ref) and [`getp`](@ref).

Parameter timeseries support requires that the value provider store the different
timeseries in a [`ParameterTimeseriesCollection`](@ref).
Expand Down
28 changes: 14 additions & 14 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ end
```

If a type contains the value of state variables, it can define [`state_values`](@ref) to
enable the usage of [`getu`](@ref) and [`setu`](@ref). These methods retturn getter/
enable the usage of [`getsym`](@ref) and [`setsym`](@ref). These methods retturn getter/
setter functions to access or update the value of a state variable (or a collection of
them). If the type also supports generating [`observed`](@ref) functions, `getu` also
them). If the type also supports generating [`observed`](@ref) functions, `getsym` also
enables returning functions to access the value of arbitrary expressions involving
the system's symbols. This also requires that the type implement
[`parameter_values`](@ref) and [`current_time`](@ref) (if the system is time-dependent).
Expand All @@ -202,21 +202,21 @@ Then the following example would work:
```julia
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
getx = getu(sys, :x)
getx = getsym(sys, :x)
getx(integrator) # 1.0

get_expr = getu(sys, :(x + y + t))
get_expr = getsym(sys, :(x + y + t))
get_expr(integrator) # 13.0

setx! = setu(sys, :y)
setx! = setsym(sys, :y)
setx!(integrator, 0.0)
getx(integrator) # 0.0
```

In case a type stores timeseries data (such as solutions), then it must also implement
the [`Timeseries`](@ref) trait. The type would then return a timeseries from
[`state_values`](@ref) and [`current_time`](@ref) and the function returned from
[`getu`](@ref) would then return a timeseries as well. For example, consider the
[`getsym`](@ref) would then return a timeseries as well. For example, consider the
`ExampleSolution` below:

```julia
Expand All @@ -242,20 +242,20 @@ Then the following example would work:
```julia
# using the same system that the ExampleIntegrator used
sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys)
getx = getu(sys, :x)
getx = getsym(sys, :x)
getx(sol) # [1.0, 1.5]

get_expr = getu(sys, :(x + y + t))
get_expr = getsym(sys, :(x + y + t))
get_expr(sol) # [9.0, 11.0]

get_arr = getu(sys, [:y, :(x + a)])
get_arr = getsym(sys, [:y, :(x + a)])
get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]]

get_tuple = getu(sys, (:z, :(z * t)))
get_tuple = getsym(sys, (:z, :(z * t)))
get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)]
```

Note that `setu` is not designed to work for `Timeseries` objects.
Note that `setsym` is not designed to work for `Timeseries` objects.

If a type needs to perform some additional actions when updating the state/parameters
or if it is not possible to return a mutable reference to the state/parameter vector
Expand Down Expand Up @@ -315,7 +315,7 @@ setp(integrator, :b)(integrator, 3.0) # functionally the same as above

If a solution object includes modified parameter values (such as through callbacks) during the
simulation, it must implement several additional functions for correct functioning of
[`getu`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in
[`getsym`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in
implementing parameter timeseries objects. The following mockup gives an example of
correct implementation of these functions and the indexing syntax they enable.

Expand Down Expand Up @@ -457,12 +457,12 @@ sol.ps[:(b + c)] # observed quantities work too
```

```@example param_timeseries
getu(sol, :b)(sol) # works
getsym(sol, :b)(sol) # works
```

```@example param_timeseries
try
getu(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries
getsym(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries
catch e
showerror(stdout, e)
end
Expand Down
6 changes: 3 additions & 3 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]

### Evaluating expressions

`getu` also generates functions for expressions if the object passed to it supports
`getsym` also generates functions for expressions if the object passed to it supports
[`observed`](@ref). For example:

```@example Usage
getu(prob, x + y + z)(prob)
getsym(prob, x + y + z)(prob)
```

To evaluate this function using values other than the ones contained in `prob`, we need
Expand All @@ -137,7 +137,7 @@ which has trivial implementations of the above functions. We can thus do:

```@example Usage
temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob))
getu(prob, x + y + z)(temp_state)
getsym(prob, x + y + z)(temp_state)
```

Note that providing all of the state vector, parameter object and time may not be
Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ include("parameter_timeseries_collection.jl")
export getp, setp, setp_oop
include("parameter_indexing.jl")

export getu, setu
export getsym, setsym, getu, setu
include("state_indexing.jl")

export BatchedInterface, setsym_oop, associated_systems
Expand Down
16 changes: 8 additions & 8 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
struct BatchedInterface{S <: AbstractVector, I}
function BatchedInterface(indp_syms::Tuple...)
A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref).
A struct which stores information for batched calls to [`getsym`](@ref) or [`setsym`](@ref).
Given `Tuple`s, where the first element of each tuple is an index provider and the second
an array of symbolic variables (either states or parameters) in the index provider,
`BatchedInterface` will compute the union of all symbols and associate each symbol with
Expand All @@ -17,7 +17,7 @@ be retained internally.
`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref),
[`variable_index`](@ref) to query the order of symbols in the union.
See [`getu`](@ref) and [`setu`](@ref) for further details.
See [`getsym`](@ref) and [`setsym`](@ref) for further details.
See also: [`associated_systems`](@ref).
"""
Expand Down Expand Up @@ -118,7 +118,7 @@ is the index of the index provider associated with the corresponding symbol in
associated_systems(bi::BatchedInterface) = bi.associated_systems

"""
getu(bi::BatchedInterface)
getsym(bi::BatchedInterface)
Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding
symbols), return a function which takes `n` corresponding value providers and returns an
Expand All @@ -127,7 +127,7 @@ an `AbstractArray` of the appropriate `eltype` and size as its first argument, i
case the operation will populate the array in-place with the values of the symbols in the
union.
Note that all of the value providers passed to the function returned by `getu` must satisfy
Note that all of the value providers passed to the function returned by `getsym` must satisfy
`is_timeseries(prob) === NotTimeseries()`.
The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is
Expand All @@ -137,7 +137,7 @@ provider at index `associated_systems(bi)[i]`).
See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref),
[`NotTimeseries`](@ref).
"""
function getu(bi::BatchedInterface)
function getsym(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

Expand Down Expand Up @@ -189,13 +189,13 @@ function getu(bi::BatchedInterface)
end

"""
setu(bi::BatchedInterface)
setsym(bi::BatchedInterface)
Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding
symbols), return a function which takes `n` corresponding problems and an array of the
values, and updates each of the problems with the values of the corresponding symbols.
Note that all of the value providers passed to the function returned by `setu` must satisfy
Note that all of the value providers passed to the function returned by `setsym` must satisfy
`is_timeseries(prob) === NotTimeseries()`.
Note that if any subset of the `n` index providers share common symbols (among those passed
Expand All @@ -204,7 +204,7 @@ updated with the values of the common symbols.
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
"""
function setu(bi::BatchedInterface)
function setsym(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

Expand Down
34 changes: 4 additions & 30 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -726,52 +726,26 @@ function setp_oop(indp, sym)
return _setp_oop(indp, symtype, elsymtype, sym)
end

struct OOPSetter{I, D}
indp::I
idxs::D
end

function (os::OOPSetter)(valp, val)
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
end

function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
if os.idxs isa Union{Tuple, AbstractArray}
return remake_buffer(os.indp, parameter_values(valp), os.idxs, val)
else
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
end
end

function _root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) != indp
return _root_indp(sc)
else
return indp
end
end

function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
return OOPSetter(_root_indp(indp), sym)
return OOPSetter(_root_indp(indp), sym, false)
end

function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
end

for (t1, t2) in [
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym))
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym), false)
end
end

function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
if is_parameter(indp, sym)
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
end
error("$sym is not a valid parameter")
end
2 changes: 1 addition & 1 deletion src/problem_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
function ProblemState(; u = nothing, p = nothing, t = nothing)
A value provider struct which can be used as an argument to the function returned by
[`getu`](@ref) or [`setu`](@ref). It stores the state vector, parameter object and
[`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and
current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
objects.
Expand Down
4 changes: 2 additions & 2 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function remake_buffer(sys, oldbuffer::AbstractArray, idxs, vals)
else
v = elT(v)
end
setu(sys, k)(newbuffer, v)
setsym(sys, k)(newbuffer, v)
end
else
mutbuffer = remake_buffer(sys, collect(oldbuffer), idxs, vals)
Expand Down Expand Up @@ -80,7 +80,7 @@ end
function remake_buffer(sys, oldbuffer::Tuple, idxs, vals)
wrap = TupleRemakeWrapper(oldbuffer)
for (idx, val) in zip(idxs, vals)
setu(sys, idx)(wrap, val)
setsym(sys, idx)(wrap, val)
end
return wrap.t
end
Expand Down
Loading

0 comments on commit 33cef2f

Please sign in to comment.