Skip to content

Commit

Permalink
Merge pull request #71 from SciML/as/getu-struct
Browse files Browse the repository at this point in the history
refactor: return functors from getu and getp
  • Loading branch information
ChrisRackauckas authored Apr 30, 2024
2 parents 632032f + f08f203 commit a148e4f
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 336 deletions.
9 changes: 6 additions & 3 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ export SymbolCache
include("symbol_cache.jl")

export parameter_values, set_parameter!, finalize_parameters_hook!,
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp,
setp
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries,
state_values, set_state!, current_time
include("value_provider_interface.jl")

export getp, setp
include("parameter_indexing.jl")

export state_values, set_state!, current_time, getu, setu
export getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
Expand Down
239 changes: 75 additions & 164 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,101 +1,21 @@
"""
parameter_values(p)
parameter_values(p, i)
Return an indexable collection containing the value of each parameter in `p`. The two-
argument version of this function returns the parameter value at index `i`. The
two-argument version of this function will default to returning
`parameter_values(p)[i]`.
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
array/tuple.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::Tuple) = arr
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(arr::Tuple, i) = arr[i]
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
parameter_values_at_time(p, i)
Return an indexable collection containing the value of all parameters in `p` at time index
`i`. This is useful when parameter values change during the simulation
(such as through callbacks) and their values are saved. `i` is the time index in the
timeseries formed by these changing parameter values, obtained using
[`parameter_timeseries`](@ref).
By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
to be specialized for timeseries objects where parameter values are not constant at all
times. The resultant object should be indexable using [`parameter_values`](@ref).
If this function is implemented, [`parameter_values_at_state_time`](@ref) must be
implemented for [`getu`](@ref) to work correctly.
"""
function parameter_values_at_time end
parameter_values_at_time(p, i) = parameter_values(p)

"""
parameter_values_at_state_time(p, i)
Return an indexable collection containing the value of all parameters in `p` at time
index `i`. This is useful when parameter values change during the simulation (such as
through callbacks) and their values are saved. `i` is the time index in the timeseries
formed by dependent variables (as opposed to the timeseries of the parameters, as in
[`parameter_values_at_time`](@ref)).
By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
to be specialized for timeseries objects where parameter values are not constant at
all times. The resultant object should be indexable using [`parameter_values`](@ref).
If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for
[`getp`](@ref) to work correctly.
"""
function parameter_values_at_state_time end
parameter_values_at_state_time(p, i) = parameter_values(p)

"""
parameter_timeseries(p)
Return an iterable of time steps at which the parameter values are saved. This is only
required for objects where `is_timeseries(p) === Timeseries()` and the parameter values
change during the simulation (such as through callbacks). By default, this returns `[0]`.
See also: [`parameter_values_at_time`](@ref).
"""
function parameter_timeseries end
parameter_timeseries(_) = [0]

"""
set_parameter!(sys, val, idx)
Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
default implementation does not work for a particular type, this method needs to be
defined to enable the proper functioning of [`setp`](@ref).
See: [`parameter_values`](@ref)
"""
function set_parameter! end

# Tuple only included for the error message
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

"""
finalize_parameters_hook!(prob, p)
This is a callback run one for each call to the function returned by [`setp`](@ref)
which can be used to update internal data structures when parameters are modified.
This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter
that is updated.
"""
finalize_parameters_hook!(prob, p) = nothing

"""
getp(sys, p)
Expand Down Expand Up @@ -125,43 +45,87 @@ function getp(sys, p)
_getp(sys, symtype, elsymtype, p)
end

struct GetParameterIndex{I} <: AbstractIndexer
idx::I
end

function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob)
parameter_values(prob, gpi.idx)
end
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(
parameter_values_at_time(
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
gpi.idx)
end
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(
parameter_values_at_time.((prob,),
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
(gpi.idx,))
end
function (gpi::GetParameterIndex)(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (gpi.idx,))
end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return let p = p
function _getter(::NotTimeseries, prob)
parameter_values(prob, p)
end
function _getter(::Timeseries, prob)
parameter_values(prob, p)
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(
parameter_values_at_time(
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
p)
end
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(
parameter_values_at_time.((prob,),
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
(p,))
end
function _getter(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (p,))
end
getter = let _getter = _getter
function getter(prob, args...)
return _getter(is_timeseries(prob), prob, args...)
end
end
getter
end
return GetParameterIndex(p)
end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
sys, NotSymbolic(), NotSymbolic(), idx)
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)
end

struct MultipleParameterGetters{G}
getters::G
end

function (mpg::MultipleParameterGetters)(::IsTimeseriesTrait, prob)
map(g -> g(prob), mpg.getters)
end
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex})
map(g -> g(prob, i), mpg.getters)
end
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i)
[map(g -> g(prob, j), mpg.getters)
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
end
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob)
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end
function (mpg::MultipleParameterGetters)(
buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
buffer[bufi] = g(prob, i)
end
buffer
end
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob, i)
for (bufi, tsi) in zip(
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(mpg.getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end
end
buffer
end
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::NotTimeseries, prob)
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end

function (mpg::MultipleParameterGetters)(buffer::AbstractArray, prob, i...)
mpg(buffer, is_timeseries(prob), prob, i...)
end
function (mpg::MultipleParameterGetters)(prob, i...)
mpg(is_timeseries(prob), prob, i...)
end

for (t1, t2) in [
Expand All @@ -171,60 +135,7 @@ for (t1, t2) in [
]
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
getters = getp.((sys,), p)

return let getters = getters
function _getter(::NotTimeseries, prob)
map(g -> g(prob), getters)
end
function _getter(::Timeseries, prob)
map(g -> g(prob), getters)
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
map(g -> g(prob, i), getters)
end
function _getter(::Timeseries, prob, i)
[map(g -> g(prob, j), getters)
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
end
function _getter!(buffer, ::NotTimeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end
function _getter!(buffer, ::Timeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end
function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob, i)
end
buffer
end
function _getter!(buffer, ::Timeseries, prob, i)
for (bufi, tsi) in zip(
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end
end
buffer
end
_getter, _getter!
getter = let _getter = _getter, _getter! = _getter!
function getter(prob, i...)
return _getter(is_timeseries(prob), prob, i...)
end
function getter(buffer::AbstractArray, prob, i...)
return _getter!(buffer, is_timeseries(prob), prob, i...)
end
getter
end
getter
end
return MultipleParameterGetters(getters)
end
end

Expand Down
Loading

0 comments on commit a148e4f

Please sign in to comment.