Skip to content

Commit

Permalink
feat: improve getu/setu/getp/setp handling of nested variables
Browse files Browse the repository at this point in the history
- also addresses type-stability of the closures returned from the above functions
  • Loading branch information
AayushSabharwal committed Jan 9, 2024
1 parent a7e0efb commit 55eca79
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 84 deletions.
49 changes: 24 additions & 25 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,43 @@ end
getp(sys, p)
Return a function that takes an integrator or solution of `sys`, and returns the value of
the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value.
the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value, or
an array/tuple of the aforementioned.
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
typically does not need to be implemented, and has a default implementation relying on
[`parameter_values`](@ref).
"""
function getp(sys, p)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
if symtype != NotSymbolic()
return _getp(sys, symtype, p)
else
return _getp(sys, elsymtype, p)
end
_getp(sys, symtype, elsymtype, p)
end

function _getp(sys, ::NotSymbolic, p)
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)

Check warning on line 39 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L39

Added line #L39 was not covered by tests
return function getter(sol)
return parameter_values(sol)[p]
end
end

function _getp(sys, ::ScalarSymbolic, p)
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return function getter(sol)
return parameter_values(sol)[idx]
end
end

function _getp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
idxs = parameter_index.((sys,), p)
return function getter(sol)
return getindex.((parameter_values(sol),), idxs)
for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
getters = getp.((sys,), p)

return function getter(sol)
map(g -> g(sol), getters)
end
end
end

function _getp(sys, ::ArraySymbolic, p)
function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p)

Check warning on line 62 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L62

Added line #L62 was not covered by tests
return getp(sys, collect(p))
end

Expand All @@ -76,33 +77,31 @@ implemented.
function setp(sys, p)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
if symtype != NotSymbolic()
return _setp(sys, symtype, p)
else
return _setp(sys, elsymtype, p)
end
_setp(sys, symtype, elsymtype, p)
end

function _setp(sys, ::NotSymbolic, p)
function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)

Check warning on line 83 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L83

Added line #L83 was not covered by tests
return function setter!(sol, val)
set_parameter!(sol, val, p)
end
end

function _setp(sys, ::ScalarSymbolic, p)
function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return function setter!(sol, val)
set_parameter!(sol, val, idx)
end
end

function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
idxs = parameter_index.((sys,), p)
return function setter!(sol, val)
set_parameter!.((sol,), val, idxs)
for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
setters = setp.((sys,), p)
return function setter!(sol, val)
map((s!, v) -> s!(sol, v), setters, val)

Check warning on line 100 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L97-L100

Added lines #L97 - L100 were not covered by tests
end
end
end

function _setp(sys, ::ArraySymbolic, p)
function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)

Check warning on line 105 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L105

Added line #L105 was not covered by tests
return setp(sys, collect(p))
end
89 changes: 50 additions & 39 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,20 @@ relying on the above functions.
function getu(sys, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
_getu(sys, symtype, sym)
else
_getu(sys, elsymtype, sym)
end
_getu(sys, symtype, elsymtype, sym)

Check warning on line 96 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L96

Added line #L96 was not covered by tests
end

function _getu(sys, ::NotSymbolic, sym)
function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)

Check warning on line 99 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L99

Added line #L99 was not covered by tests
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
return function getter(prob)
return _getter(is_timeseries(prob), prob)
return let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)

Check warning on line 104 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L102-L104

Added lines #L102 - L104 were not covered by tests
end
end
end

function _getu(sys, ::ScalarSymbolic, sym)
function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)

Check warning on line 109 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L109

Added line #L109 was not covered by tests
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return getu(sys, idx)
Expand All @@ -125,8 +122,10 @@ function _getu(sys, ::ScalarSymbolic, sym)
return fn(state_values(prob), parameter_values(prob), current_time(prob))
end

return function getter2(prob)
return _getter2(is_timeseries(prob), prob)
return let _getter2 = _getter2
function getter2(prob)
return _getter2(is_timeseries(prob), prob)

Check warning on line 127 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L125-L127

Added lines #L125 - L127 were not covered by tests
end
end
else
function _getter3(::Timeseries, prob)
Expand All @@ -136,8 +135,10 @@ function _getu(sys, ::ScalarSymbolic, sym)
return fn(state_values(prob), parameter_values(prob))
end

return function getter3(prob)
return _getter3(is_timeseries(prob), prob)
return let _getter3 = _getter3
function getter3(prob)
return _getter3(is_timeseries(prob), prob)

Check warning on line 140 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L138-L140

Added lines #L138 - L140 were not covered by tests
end
end
end
end
Expand All @@ -153,24 +154,38 @@ state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx]
parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries)
current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx]

function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray})
getters = getu.((sys,), sym)
_call(getter, prob) = getter(prob)
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
getters = getu.((sys,), sym)
_call(getter, prob) = getter(prob)

Check warning on line 160 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L158-L160

Added lines #L158 - L160 were not covered by tests

function _getter(::Timeseries, prob)
tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob)))
return [_getter(NotTimeseries(), tiw) for tiw in tiws]
end
_getter(::NotTimeseries, prob) = _call.(getters, (prob,))
return function getter(prob)
return _getter(is_timeseries(prob), prob)
return let getters = getters, _call = _call
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
function _getter(::Timeseries, prob)
tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob)))

Check warning on line 165 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L162-L165

Added lines #L162 - L165 were not covered by tests
# Ideally this should recursively call `_getter` but that leads to type-instability
# since the reference to itself is boxed
# Turning this broadcasted `_call` into a map also makes this type-unstable

return map(tiw -> _call.(getters, (tiw,)), tiws)

Check warning on line 170 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L170

Added line #L170 was not covered by tests
end

# Need another scope for this to not box `_getter`
let _getter = _getter
function getter(prob)
return _getter(is_timeseries(prob), prob)

Check warning on line 176 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L174-L176

Added lines #L174 - L176 were not covered by tests
end
end
end
end
end

function _getu(sys, ::ArraySymbolic, sym)
function _getu(sys, ::ArraySymbolic, ::NotSymbolic, sym)

Check warning on line 183 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L183

Added line #L183 was not covered by tests
return getu(sys, collect(sym))
end

# setu doesn't need the same `let` blocks to be inferred for some reason

"""
setu(sys, sym)
Expand All @@ -186,36 +201,32 @@ This function does not work on types for which [`is_timeseries`](@ref) is
function setu(sys, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
_setu(sys, symtype, sym)
else
_setu(sys, elsymtype, sym)
end
_setu(sys, symtype, elsymtype, sym)

Check warning on line 204 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L204

Added line #L204 was not covered by tests
end

function _setu(sys, ::NotSymbolic, sym)
function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym)

Check warning on line 207 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L207

Added line #L207 was not covered by tests
return function setter!(prob, val)
set_state!(prob, val, sym)
end
end

function _setu(sys, ::ScalarSymbolic, sym)
function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)

Check warning on line 213 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L213

Added line #L213 was not covered by tests
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
idx = variable_index(sys, sym)
return function setter!(prob, val)
set_state!(prob, val, idx)
end
end

function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray})
setters = setu.((sys,), sym)
_call!(setter!, prob, val) = setter!(prob, val)
return function setter!(prob, val)
_call!.(setters, (prob,), val)
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
setters = setu.((sys,), sym)
return function setter!(prob, val)
map((s!, v) -> s!(prob, v), setters, val)

Check warning on line 225 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L222-L225

Added lines #L222 - L225 were not covered by tests
end
end
end

function _setu(sys, ::ArraySymbolic, sym)
function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym)

Check warning on line 230 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L230

Added line #L230 was not covered by tests
return setu(sys, collect(sym))
end
38 changes: 30 additions & 8 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,37 @@ end
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p

sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
p = [1.0, 2.0]
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p))
for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))]
new_p = [4.0, 5.0, 6.0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
]
get = getp(sys, sym)
set! = setp(sys, sym)
true_value = i isa Tuple ? getindex.((p,), i) : p[i]
@test get(fi) == ParameterIndexingProxy(fi)[sym] == true_value
set!(fi, 0.5 .* i)
@test get(fi) == ParameterIndexingProxy(fi)[sym] == 0.5 .* i
set!(fi, true_value)
if check_inference
@inferred get(fi)
end
@test get(fi) == oldval
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
end
73 changes: 61 additions & 12 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,44 @@ SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
u = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(u))
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))]
# checking inference for non-concretely typed arrays will always fail
for (sym, val, newval, check_inference) in [
(:x, u[1], 4.0, true)
(:y, u[2], 4.0, true)
(:z, u[3], 4.0, true)
(1, u[1], 4.0, true)
([:x, :y], u[1:2], 4ones(2), true)
([1, 2], u[1:2], 4ones(2), true)
((:z, :y), (u[3], u[2]), (4.0, 5.0), true)
((3, 2), (u[3], u[2]), (4.0, 5.0), true)
([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true)
((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)
]
get = getu(sys, sym)
set! = setu(sys, sym)
true_value = i isa Tuple ? getindex.((u,), i) : u[i]
@test get(fi) == true_value
set!(fi, 0.5 .* i)
@test get(fi) == 0.5 .* i
set!(fi, true_value)
if check_inference
@inferred get(fi)
end
@test get(fi) == val
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test get(fi) == newval
set!(fi, val)
@test get(fi) == val
end


struct FakeSolution{S, U}
sys::S
u::U
Expand All @@ -33,12 +61,33 @@ SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
sol = FakeSolution(sys, u)
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))]

xvals = getindex.(sol.u, 1)
yvals = getindex.(sol.u, 2)
zvals = getindex.(sol.u, 3)

for (sym, ans, check_inference) in [
(:x, xvals, true)
(:y, yvals, true)
(:z, zvals, true)
(1, xvals, true)
([:x, :y], vcat.(xvals, yvals), true)
(1:2, vcat.(xvals, yvals), true)
([:x, 2], vcat.(xvals, yvals), false)
((:z, :y), tuple.(zvals, yvals), true)
((3, 2), tuple.(zvals, yvals), true)
([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false)
([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
([:x, [:y, :z], (:x, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
([:x, [:y, 3], (1, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true)
]
get = getu(sys, sym)
true_value = if i isa Tuple
[getindex.((v,), i) for v in u]
else
getindex.(u, (i,))
if check_inference
@inferred get(sol)
end
@test get(sol) == true_value
@test get(sol) == ans
end

0 comments on commit 55eca79

Please sign in to comment.