diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index db0abf4a..e28720b6 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -53,7 +53,11 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) end end -for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] +for (t1, t2) in [ + (ArraySymbolic, Any), + (ScalarSymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) getters = getp.((sys,), p) @@ -99,7 +103,11 @@ function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) end end -for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] +for (t1, t2) in [ + (ArraySymbolic, Any), + (ScalarSymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) setters = setp.((sys,), p) return function setter!(sol, val) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 7de8deeb..d84e869b 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -154,7 +154,9 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) current_time(prob)) end function _getter2(::Timeseries, prob, i) - return fn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) + return fn(state_values(prob, i), + parameter_values(prob), + current_time(prob, i)) end function _getter2(::NotTimeseries, prob) return fn(state_values(prob), parameter_values(prob), current_time(prob)) @@ -181,7 +183,11 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) error("Invalid symbol $sym for `getu`") end -for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] +for (t1, t2) in [ + (ScalarSymbolic, Any), + (ArraySymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) getters = getu.((sys,), sym) _call(getter, args...) = getter(args...) @@ -252,7 +258,11 @@ function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) error("Invalid symbol $sym for `setu`") end -for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})] +for (t1, t2) in [ + (ScalarSymbolic, Any), + (ArraySymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) setters = setu.((sys,), sym) return function setter!(prob, val) diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 6d8ee6a0..e47a4601 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -18,8 +18,7 @@ p = [11.0, 12.0, 13.0] t = 0.5 fi = FakeIntegrator(sys, copy(u), copy(p), t) # checking inference for non-concretely typed arrays will always fail -for (sym, val, newval, check_inference) in [ - (:x, u[1], 4.0, true) +for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) (:y, u[2], 4.0, true) (:z, u[3], 4.0, true) (1, u[1], 4.0, true) @@ -36,8 +35,7 @@ for (sym, val, newval, check_inference) in [ ((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true) ((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) ((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) - ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true) -] + ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)] get = getu(sys, sym) set! = setu(sys, sym) if check_inference @@ -67,15 +65,13 @@ for (sym, val, newval, check_inference) in [ @test get(u) == val end -for (sym, oldval, newval, check_inference) in [ - (:a, p[1], 4.0, true) +for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) (:b, p[2], 5.0, true) (:c, p[3], 6.0, true) ([:a, :b], p[1:2], [4.0, 5.0], true) ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) - ((:y, :b), (u[2], p[2]), (5.0, 6.0), true) -] + ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] get = getu(fi, sym) set! = setu(fi, sym) if check_inference @@ -126,8 +122,7 @@ xvals = getindex.(sol.u, 1) yvals = getindex.(sol.u, 2) zvals = getindex.(sol.u, 3) -for (sym, ans, check_inference) in [ - (:x, xvals, true) +for (sym, ans, check_inference) in [(:x, xvals, true) (:y, yvals, true) (:z, zvals, true) (1, xvals, true) @@ -139,17 +134,22 @@ for (sym, ans, check_inference) in [ ([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false) ([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) ([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) - ([:x, [:y, :z], (:x, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false) - ([:x, [:y, 3], (1, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false) + ([:x, [:y, :z], (:x, :z)], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), + false) + ([:x, [:y, 3], (1, :z)], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), + false) ((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true) ((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true) - ((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true) + ((:x, [:y, :z], (:z, :y)), + tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), + true) ([:x, :a], vcat.(xvals, p[1]), false) ((:y, :b), tuple.(yvals, p[2]), true) (:t, t, true) ([:x, :a, :t], vcat.(xvals, p[1], t), false) - ((:x, :a, :t), tuple.(xvals, p[1], t), true) -] + ((:x, :a, :t), tuple.(xvals, p[1], t), true)] get = getu(sys, sym) if check_inference @inferred get(sol) @@ -163,13 +163,11 @@ for (sym, ans, check_inference) in [ end end -for (sym, val) in [ - (:a, p[1]) +for (sym, val) in [(:a, p[1]) (:b, p[2]) (:c, p[3]) ([:a, :b], p[1:2]) - ((:c, :b), (p[3], p[2])) -] + ((:c, :b), (p[3], p[2]))] get = getu(fi, sym) @inferred get(fi) @test get(fi) == val