Skip to content

Commit

Permalink
refactor: make getu, getp generic of the parameter container
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Feb 13, 2024
1 parent b6bbb66 commit 1c5ebf7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
23 changes: 16 additions & 7 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""
parameter_values(p)
parameter_values(p, i)
Return an indexable collection containing the value of each parameter in `p`.
Return an indexable collection containing the value of each parameter in `p`. The two-
argument version of this function returns the parameter value at index `i`. The
two-argument version of this function will default to returning
`parameter_values(p)[i]`.
If this function is called with an `AbstractArray`, it will return the same array.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
set_parameter!(sys, val, idx)
Expand All @@ -19,16 +25,19 @@ defined to enable the proper functioning of [`setp`](@ref).
See: [`parameter_values`](@ref)
"""
function set_parameter!(sys, val, idx)
parameter_values(sys)[idx] = val
function set_parameter! end

function set_parameter!(sys::AbstractArray, val, idx)
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

"""
getp(sys, p)
Return a function that takes an array representing the parameter vector or an integrator
or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a
direct numerical index or a symbolic value, or an array/tuple of the aforementioned.
direct index or a symbolic value, or an array/tuple of the aforementioned.
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
typically does not need to be implemented, and has a default implementation relying on
Expand All @@ -42,14 +51,14 @@ end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return function getter(sol)
return parameter_values(sol)[p]
return parameter_values(sol, p)

Check warning on line 54 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L54

Added line #L54 was not covered by tests
end
end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return function getter(sol)
return parameter_values(sol)[idx]
return parameter_values(sol, idx)
end
end

Expand All @@ -76,7 +85,7 @@ end
Return a function that takes an array representing the parameter vector or an integrator
or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p`
can be a direct numerical index or a symbolic value.
can be a direct index or a symbolic value.
Requires that the integrator implement [`parameter_values`](@ref) and the returned
collection be a mutable reference to the parameter vector in the integrator. In
Expand Down
4 changes: 2 additions & 2 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end
function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
_getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym)
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
_getter(::NotTimeseries, prob) = state_values(prob, sym)

Check warning on line 119 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L119

Added line #L119 was not covered by tests
return let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)
Expand Down Expand Up @@ -266,7 +266,7 @@ end
Return a function that takes an array representing the state vector or an integrator or
problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym`
can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
can be a direct index, a symbolic state, or an array/tuple of the aforementioned.
Requires that the integrator implement [`state_values`](@ref) and the
returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to
Expand Down

0 comments on commit 1c5ebf7

Please sign in to comment.