Skip to content

Commit

Permalink
Merge pull request #42 from SciML/as/getp-generic
Browse files Browse the repository at this point in the history
refactor: make getp generic of the parameter container
  • Loading branch information
ChrisRackauckas authored Feb 14, 2024
2 parents b6bbb66 + 599575d commit ccbfdc5
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 77 deletions.
4 changes: 2 additions & 2 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ pages = [
"Tutorials" => [
"Using the SciML Symbolic Indexing Interface" => "usage.md",
"Simple Demonstration of a Symbolic System Structure" => "simple_sii_sys.md",
"Implementing the Complete Symbolic Indexing Interface" => "complete_sii.md",
"Implementing the Complete Symbolic Indexing Interface" => "complete_sii.md"
],
"Defining Solution Wrapper Fallbacks" => "solution_wrappers.md",
"API" => "api.md",
"API" => "api.md"
]
11 changes: 6 additions & 5 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getna
include("trait.jl")

export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
observed, is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols,
all_symbols, solvedvariables, allvariables
parameter_symbols, is_independent_variable, independent_variable_symbols,
is_observed,
observed, is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols,
all_symbols, solvedvariables, allvariables
include("interface.jl")

export SymbolCache
Expand All @@ -17,7 +18,7 @@ export parameter_values, set_parameter!, getp, setp
include("parameter_indexing.jl")

export Timeseries,
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export ParameterIndexingProxy
Expand Down
27 changes: 18 additions & 9 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""
parameter_values(p)
parameter_values(p, i)
Return an indexable collection containing the value of each parameter in `p`.
Return an indexable collection containing the value of each parameter in `p`. The two-
argument version of this function returns the parameter value at index `i`. The
two-argument version of this function will default to returning
`parameter_values(p)[i]`.
If this function is called with an `AbstractArray`, it will return the same array.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
set_parameter!(sys, val, idx)
Expand All @@ -19,16 +25,19 @@ defined to enable the proper functioning of [`setp`](@ref).
See: [`parameter_values`](@ref)
"""
function set_parameter!(sys, val, idx)
parameter_values(sys)[idx] = val
function set_parameter! end

function set_parameter!(sys::AbstractArray, val, idx)
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

"""
getp(sys, p)
Return a function that takes an array representing the parameter vector or an integrator
or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a
direct numerical index or a symbolic value, or an array/tuple of the aforementioned.
direct index or a symbolic value, or an array/tuple of the aforementioned.
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
typically does not need to be implemented, and has a default implementation relying on
Expand All @@ -42,21 +51,21 @@ end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return function getter(sol)
return parameter_values(sol)[p]
return parameter_values(sol, p)
end
end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return function getter(sol)
return parameter_values(sol)[idx]
return parameter_values(sol, idx)
end
end

for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
getters = getp.((sys,), p)
Expand All @@ -76,7 +85,7 @@ end
Return a function that takes an array representing the parameter vector or an integrator
or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p`
can be a direct numerical index or a symbolic value.
can be a direct index or a symbolic value.
Requires that the integrator implement [`parameter_values`](@ref) and the returned
collection be a mutable reference to the parameter vector in the integrator. In
Expand Down Expand Up @@ -106,7 +115,7 @@ end
for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
setters = setp.((sys,), p)
Expand Down
8 changes: 4 additions & 4 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end
function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
_getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym)
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
_getter(::NotTimeseries, prob) = state_values(prob, sym)
return let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)
Expand Down Expand Up @@ -186,7 +186,7 @@ end
for (t1, t2) in [
(ScalarSymbolic, Any),
(ArraySymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
num_observed = count(x -> is_observed(sys, x), sym)
Expand Down Expand Up @@ -266,7 +266,7 @@ end
Return a function that takes an array representing the state vector or an integrator or
problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym`
can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
can be a direct index, a symbolic state, or an array/tuple of the aforementioned.
Requires that the integrator implement [`state_values`](@ref) and the
returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to
Expand Down Expand Up @@ -301,7 +301,7 @@ end
for (t1, t2) in [
(ScalarSymbolic, Any),
(ArraySymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
setters = setu.((sys,), sym)
Expand Down
2 changes: 1 addition & 1 deletion src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ array containing a single variable if the system has only one independent variab
struct SymbolCache{
V <: Union{Nothing, AbstractVector},
P <: Union{Nothing, AbstractVector},
I,
I
}
variables::V
parameters::P
Expand Down
2 changes: 1 addition & 1 deletion test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ for (sym, oldval, newval, check_inference) in [
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
Expand Down
129 changes: 74 additions & 55 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,33 @@ t = 0.5
fi = FakeIntegrator(sys, copy(u), copy(p), t)
# checking inference for non-concretely typed arrays will always fail
for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
(:y, u[2], 4.0, true)
(:z, u[3], 4.0, true)
(1, u[1], 4.0, true)
([:x, :y], u[1:2], 4ones(2), true)
([1, 2], u[1:2], 4ones(2), true)
((:z, :y), (u[3], u[2]), (4.0, 5.0), true)
((3, 2), (u[3], u[2]), (4.0, 5.0), true)
([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true)
((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)]
(:y, u[2], 4.0, true)
(:z, u[3], 4.0, true)
(1, u[1], 4.0, true)
([:x, :y], u[1:2], 4ones(2), true)
([1, 2], u[1:2], 4ones(2), true)
((:z, :y), (u[3], u[2]), (4.0, 5.0), true)
((3, 2), (u[3], u[2]), (4.0, 5.0), true)
([:x, [:y, :z]], [u[1], u[2:3]],
[4.0, [5.0, 6.0]], false)
([:x, 2:3], [u[1], u[2:3]],
[4.0, [5.0, 6.0]], false)
([:x, (:y, :z)], [u[1], (u[2], u[3])],
[4.0, (5.0, 6.0)], false)
([:x, Tuple(2:3)], [u[1], (u[2], u[3])],
[4.0, (5.0, 6.0)], false)
([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)],
[4.0, [5.0], (6.0,)], false)
([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)],
[4.0, [5.0], (6.0,)], false)
((:x, [:y, :z]), (u[1], u[2:3]),
(4.0, [5.0, 6.0]), true)
((:x, (:y, :z)), (u[1], (u[2], u[3])),
(4.0, (5.0, 6.0)), true)
((1, (:y, :z)), (u[1], (u[2], u[3])),
(4.0, (5.0, 6.0)), true)
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)),
(4.0, [5.0], (6.0,)), true)]
get = getu(sys, sym)
set! = setu(sys, sym)
if check_inference
Expand Down Expand Up @@ -66,12 +76,12 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
end

for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
([:a, :b], p[1:2], [4.0, 5.0], true)
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)]
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
([:a, :b], p[1:2], [4.0, 5.0], true)
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)]
get = getu(fi, sym)
set! = setu(fi, sym)
if check_inference
Expand All @@ -91,7 +101,7 @@ end
for (sym, val, check_inference) in [
(:t, t, true),
([:x, :a, :t], [u[1], p[1], t], false),
((:x, :a, :t), (u[1], p[1], t), true),
((:x, :a, :t), (u[1], p[1], t), true)
]
get = getu(fi, sym)
if check_inference
Expand Down Expand Up @@ -123,33 +133,42 @@ yvals = getindex.(sol.u, 2)
zvals = getindex.(sol.u, 3)

for (sym, ans, check_inference) in [(:x, xvals, true)
(:y, yvals, true)
(:z, zvals, true)
(1, xvals, true)
([:x, :y], vcat.(xvals, yvals), true)
(1:2, vcat.(xvals, yvals), true)
([:x, 2], vcat.(xvals, yvals), false)
((:z, :y), tuple.(zvals, yvals), true)
((3, 2), tuple.(zvals, yvals), true)
([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false)
([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
([:x, [:y, :z], (:x, :z)],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
false)
([:x, [:y, 3], (1, :z)],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
false)
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
((:x, [:y, :z], (:z, :y)),
tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)),
true)
([:x, :a], vcat.(xvals, p[1]), false)
((:y, :b), tuple.(yvals, p[2]), true)
(:t, t, true)
([:x, :a, :t], vcat.(xvals, p[1], t), false)
((:x, :a, :t), tuple.(xvals, p[1], t), true)]
(:y, yvals, true)
(:z, zvals, true)
(1, xvals, true)
([:x, :y], vcat.(xvals, yvals), true)
(1:2, vcat.(xvals, yvals), true)
([:x, 2], vcat.(xvals, yvals), false)
((:z, :y), tuple.(zvals, yvals), true)
((3, 2), tuple.(zvals, yvals), true)
([:x, [:y, :z]],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]),
false)
([:x, (:y, :z)],
vcat.(xvals, tuple.(yvals, zvals)), false)
([1, (:y, :z)],
vcat.(xvals, tuple.(yvals, zvals)), false)
([:x, [:y, :z], (:x, :z)],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)],
tuple.(xvals, zvals)),
false)
([:x, [:y, 3], (1, :z)],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)],
tuple.(xvals, zvals)),
false)
((:x, [:y, :z]),
tuple.(xvals, vcat.(yvals, zvals)), true)
((:x, (:y, :z)),
tuple.(xvals, tuple.(yvals, zvals)), true)
((:x, [:y, :z], (:z, :y)),
tuple.(xvals, vcat.(yvals, zvals),
tuple.(zvals, yvals)),
true)
([:x, :a], vcat.(xvals, p[1]), false)
((:y, :b), tuple.(yvals, p[2]), true)
(:t, t, true)
([:x, :a, :t], vcat.(xvals, p[1], t), false)
((:x, :a, :t), tuple.(xvals, p[1], t), true)]
get = getu(sys, sym)
if check_inference
@inferred get(sol)
Expand All @@ -164,10 +183,10 @@ for (sym, ans, check_inference) in [(:x, xvals, true)
end

for (sym, val) in [(:a, p[1])
(:b, p[2])
(:c, p[3])
([:a, :b], p[1:2])
((:c, :b), (p[3], p[2]))]
(:b, p[2])
(:c, p[3])
([:a, :b], p[1:2])
((:c, :b), (p[3], p[2]))]
get = getu(fi, sym)
@inferred get(fi)
@test get(fi) == val
Expand Down

0 comments on commit ccbfdc5

Please sign in to comment.