From 55eca79d2e9c1f85d337e5a53bb0cbb62edfb704 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 9 Jan 2024 14:55:24 +0530 Subject: [PATCH] feat: improve getu/setu/getp/setp handling of nested variables - also addresses type-stability of the closures returned from the above functions --- src/parameter_indexing.jl | 49 +++++++++--------- src/state_indexing.jl | 89 ++++++++++++++++++--------------- test/parameter_indexing_test.jl | 38 +++++++++++--- test/state_indexing_test.jl | 73 ++++++++++++++++++++++----- 4 files changed, 165 insertions(+), 84 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 52430728..0b14eea0 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -23,7 +23,9 @@ end getp(sys, p) Return a function that takes an integrator or solution of `sys`, and returns the value of -the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value. +the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value, or +an array/tuple of the aforementioned. + Requires that the integrator or solution implement [`parameter_values`](@ref). This function typically does not need to be implemented, and has a default implementation relying on [`parameter_values`](@ref). @@ -31,34 +33,33 @@ typically does not need to be implemented, and has a default implementation rely function getp(sys, p) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - if symtype != NotSymbolic() - return _getp(sys, symtype, p) - else - return _getp(sys, elsymtype, p) - end + _getp(sys, symtype, elsymtype, p) end -function _getp(sys, ::NotSymbolic, p) +function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return function getter(sol) return parameter_values(sol)[p] end end -function _getp(sys, ::ScalarSymbolic, p) +function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return function getter(sol) return parameter_values(sol)[idx] end end -function _getp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) - idxs = parameter_index.((sys,), p) - return function getter(sol) - return getindex.((parameter_values(sol),), idxs) +for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] + @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) + getters = getp.((sys,), p) + + return function getter(sol) + map(g -> g(sol), getters) + end end end -function _getp(sys, ::ArraySymbolic, p) +function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p) return getp(sys, collect(p)) end @@ -76,33 +77,31 @@ implemented. function setp(sys, p) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - if symtype != NotSymbolic() - return _setp(sys, symtype, p) - else - return _setp(sys, elsymtype, p) - end + _setp(sys, symtype, elsymtype, p) end -function _setp(sys, ::NotSymbolic, p) +function _setp(sys, ::NotSymbolic, ::NotSymbolic, p) return function setter!(sol, val) set_parameter!(sol, val, p) end end -function _setp(sys, ::ScalarSymbolic, p) +function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return function setter!(sol, val) set_parameter!(sol, val, idx) end end -function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) - idxs = parameter_index.((sys,), p) - return function setter!(sol, val) - set_parameter!.((sol,), val, idxs) +for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] + @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) + setters = setp.((sys,), p) + return function setter!(sol, val) + map((s!, v) -> s!(sol, v), setters, val) + end end end -function _setp(sys, ::ArraySymbolic, p) +function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p) return setp(sys, collect(p)) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 94a02a6c..b57a80ee 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -93,23 +93,20 @@ relying on the above functions. function getu(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - - if symtype != NotSymbolic() - _getu(sys, symtype, sym) - else - _getu(sys, elsymtype, sym) - end + _getu(sys, symtype, elsymtype, sym) end -function _getu(sys, ::NotSymbolic, sym) +function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) _getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,)) _getter(::NotTimeseries, prob) = state_values(prob)[sym] - return function getter(prob) - return _getter(is_timeseries(prob), prob) + return let _getter = _getter + function getter(prob) + return _getter(is_timeseries(prob), prob) + end end end -function _getu(sys, ::ScalarSymbolic, sym) +function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) return getu(sys, idx) @@ -125,8 +122,10 @@ function _getu(sys, ::ScalarSymbolic, sym) return fn(state_values(prob), parameter_values(prob), current_time(prob)) end - return function getter2(prob) - return _getter2(is_timeseries(prob), prob) + return let _getter2 = _getter2 + function getter2(prob) + return _getter2(is_timeseries(prob), prob) + end end else function _getter3(::Timeseries, prob) @@ -136,8 +135,10 @@ function _getu(sys, ::ScalarSymbolic, sym) return fn(state_values(prob), parameter_values(prob)) end - return function getter3(prob) - return _getter3(is_timeseries(prob), prob) + return let _getter3 = _getter3 + function getter3(prob) + return _getter3(is_timeseries(prob), prob) + end end end end @@ -153,24 +154,38 @@ state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx] parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries) current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx] -function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) - getters = getu.((sys,), sym) - _call(getter, prob) = getter(prob) +for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] + @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) + getters = getu.((sys,), sym) + _call(getter, prob) = getter(prob) - function _getter(::Timeseries, prob) - tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob))) - return [_getter(NotTimeseries(), tiw) for tiw in tiws] - end - _getter(::NotTimeseries, prob) = _call.(getters, (prob,)) - return function getter(prob) - return _getter(is_timeseries(prob), prob) + return let getters = getters, _call = _call + _getter(::NotTimeseries, prob) = map(g -> g(prob), getters) + function _getter(::Timeseries, prob) + tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob))) + # Ideally this should recursively call `_getter` but that leads to type-instability + # since the reference to itself is boxed + # Turning this broadcasted `_call` into a map also makes this type-unstable + + return map(tiw -> _call.(getters, (tiw,)), tiws) + end + + # Need another scope for this to not box `_getter` + let _getter = _getter + function getter(prob) + return _getter(is_timeseries(prob), prob) + end + end + end end end -function _getu(sys, ::ArraySymbolic, sym) +function _getu(sys, ::ArraySymbolic, ::NotSymbolic, sym) return getu(sys, collect(sym)) end +# setu doesn't need the same `let` blocks to be inferred for some reason + """ setu(sys, sym) @@ -186,21 +201,16 @@ This function does not work on types for which [`is_timeseries`](@ref) is function setu(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - - if symtype != NotSymbolic() - _setu(sys, symtype, sym) - else - _setu(sys, elsymtype, sym) - end + _setu(sys, symtype, elsymtype, sym) end -function _setu(sys, ::NotSymbolic, sym) +function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) return function setter!(prob, val) set_state!(prob, val, sym) end end -function _setu(sys, ::ScalarSymbolic, sym) +function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) is_variable(sys, sym) || error("Invalid symbol $sym for `setu`") idx = variable_index(sys, sym) return function setter!(prob, val) @@ -208,14 +218,15 @@ function _setu(sys, ::ScalarSymbolic, sym) end end -function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) - setters = setu.((sys,), sym) - _call!(setter!, prob, val) = setter!(prob, val) - return function setter!(prob, val) - _call!.(setters, (prob,), val) +for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] + @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) + setters = setu.((sys,), sym) + return function setter!(prob, val) + map((s!, v) -> s!(prob, v), setters, val) + end end end -function _setu(sys, ::ArraySymbolic, sym) +function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym) return setu(sys, collect(sym)) end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 8f858e68..a4650b60 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -9,15 +9,37 @@ end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p -sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) -p = [1.0, 2.0] +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) +p = [1.0, 2.0, 3.0] fi = FakeIntegrator(sys, copy(p)) -for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))] +new_p = [4.0, 5.0, 6.0] +for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], true), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), +] get = getp(sys, sym) set! = setp(sys, sym) - true_value = i isa Tuple ? getindex.((p,), i) : p[i] - @test get(fi) == ParameterIndexingProxy(fi)[sym] == true_value - set!(fi, 0.5 .* i) - @test get(fi) == ParameterIndexingProxy(fi)[sym] == 0.5 .* i - set!(fi, true_value) + if check_inference + @inferred get(fi) + end + @test get(fi) == oldval + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval end diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 7fc3a61c..60612b82 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -11,16 +11,44 @@ SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) u = [1.0, 2.0, 3.0] fi = FakeIntegrator(sys, copy(u)) -for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] +# checking inference for non-concretely typed arrays will always fail +for (sym, val, newval, check_inference) in [ + (:x, u[1], 4.0, true) + (:y, u[2], 4.0, true) + (:z, u[3], 4.0, true) + (1, u[1], 4.0, true) + ([:x, :y], u[1:2], 4ones(2), true) + ([1, 2], u[1:2], 4ones(2), true) + ((:z, :y), (u[3], u[2]), (4.0, 5.0), true) + ((3, 2), (u[3], u[2]), (4.0, 5.0), true) + ([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false) + ([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false) + ([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false) + ([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false) + ([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false) + ([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false) + ((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true) + ((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) + ((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) + ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true) +] get = getu(sys, sym) set! = setu(sys, sym) - true_value = i isa Tuple ? getindex.((u,), i) : u[i] - @test get(fi) == true_value - set!(fi, 0.5 .* i) - @test get(fi) == 0.5 .* i - set!(fi, true_value) + if check_inference + @inferred get(fi) + end + @test get(fi) == val + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval + set!(fi, val) + @test get(fi) == val end + struct FakeSolution{S, U} sys::S u::U @@ -33,12 +61,33 @@ SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] sol = FakeSolution(sys, u) -for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] + +xvals = getindex.(sol.u, 1) +yvals = getindex.(sol.u, 2) +zvals = getindex.(sol.u, 3) + +for (sym, ans, check_inference) in [ + (:x, xvals, true) + (:y, yvals, true) + (:z, zvals, true) + (1, xvals, true) + ([:x, :y], vcat.(xvals, yvals), true) + (1:2, vcat.(xvals, yvals), true) + ([:x, 2], vcat.(xvals, yvals), false) + ((:z, :y), tuple.(zvals, yvals), true) + ((3, 2), tuple.(zvals, yvals), true) + ([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false) + ([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) + ([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) + ([:x, [:y, :z], (:x, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false) + ([:x, [:y, 3], (1, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false) + ((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true) + ((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true) + ((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true) +] get = getu(sys, sym) - true_value = if i isa Tuple - [getindex.((v,), i) for v in u] - else - getindex.(u, (i,)) + if check_inference + @inferred get(sol) end - @test get(sol) == true_value + @test get(sol) == ans end