Skip to content

Commit

Permalink
refactor: format
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 17, 2024
1 parent 5ab2679 commit 48461a4
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 25 deletions.
12 changes: 10 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
end
end

for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
]
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
getters = getp.((sys,), p)

Expand Down Expand Up @@ -99,7 +103,11 @@ function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
end
end

for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
]
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
setters = setp.((sys,), p)
return function setter!(sol, val)
Expand Down
19 changes: 15 additions & 4 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
current_time(prob))
end
function _getter2(::Timeseries, prob, i)
return fn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
return fn(state_values(prob, i),

Check warning on line 157 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L156-L157

Added lines #L156 - L157 were not covered by tests
parameter_values(prob),
current_time(prob, i))
end
function _getter2(::NotTimeseries, prob)
return fn(state_values(prob), parameter_values(prob), current_time(prob))
Expand All @@ -181,14 +183,19 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
error("Invalid symbol $sym for `getu`")
end

for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
for (t1, t2) in [
(ScalarSymbolic, Any),
(ArraySymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
getters = getu.((sys,), sym)
_call(getter, args...) = getter(args...)
return let getters = getters, _call = _call
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
function _getter(::Timeseries, prob)
return map(i -> _call.(getters, (prob,), (i,)), eachindex(state_values(prob)))
return map(i -> _call.(getters, (prob,), (i,)),

Check warning on line 197 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L191-L197

Added lines #L191 - L197 were not covered by tests
eachindex(state_values(prob)))
end
function _getter(::Timeseries, prob, i)
return _call.(getters, (prob,), (i,))

Check warning on line 201 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L200-L201

Added lines #L200 - L201 were not covered by tests
Expand Down Expand Up @@ -251,7 +258,11 @@ function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
error("Invalid symbol $sym for `setu`")

Check warning on line 258 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L258

Added line #L258 was not covered by tests
end

for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
for (t1, t2) in [
(ScalarSymbolic, Any),
(ArraySymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
]
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
setters = setu.((sys,), sym)
return function setter!(prob, val)
Expand Down
36 changes: 17 additions & 19 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ p = [11.0, 12.0, 13.0]
t = 0.5
fi = FakeIntegrator(sys, copy(u), copy(p), t)
# checking inference for non-concretely typed arrays will always fail
for (sym, val, newval, check_inference) in [
(:x, u[1], 4.0, true)
for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
(:y, u[2], 4.0, true)
(:z, u[3], 4.0, true)
(1, u[1], 4.0, true)
Expand All @@ -36,8 +35,7 @@ for (sym, val, newval, check_inference) in [
((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true)
((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)
]
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)]
get = getu(sys, sym)
set! = setu(sys, sym)
if check_inference
Expand Down Expand Up @@ -67,15 +65,13 @@ for (sym, val, newval, check_inference) in [
@test get(u) == val
end

for (sym, oldval, newval, check_inference) in [
(:a, p[1], 4.0, true)
for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
([:a, :b], p[1:2], [4.0, 5.0], true)
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)
]
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)]
get = getu(fi, sym)
set! = setu(fi, sym)
if check_inference
Expand Down Expand Up @@ -126,8 +122,7 @@ xvals = getindex.(sol.u, 1)
yvals = getindex.(sol.u, 2)
zvals = getindex.(sol.u, 3)

for (sym, ans, check_inference) in [
(:x, xvals, true)
for (sym, ans, check_inference) in [(:x, xvals, true)
(:y, yvals, true)
(:z, zvals, true)
(1, xvals, true)
Expand All @@ -139,17 +134,22 @@ for (sym, ans, check_inference) in [
([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false)
([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
([:x, [:y, :z], (:x, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
([:x, [:y, 3], (1, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
([:x, [:y, :z], (:x, :z)],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
false)
([:x, [:y, 3], (1, :z)],
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
false)
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true)
((:x, [:y, :z], (:z, :y)),
tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)),
true)
([:x, :a], vcat.(xvals, p[1]), false)
((:y, :b), tuple.(yvals, p[2]), true)
(:t, t, true)
([:x, :a, :t], vcat.(xvals, p[1], t), false)
((:x, :a, :t), tuple.(xvals, p[1], t), true)
]
((:x, :a, :t), tuple.(xvals, p[1], t), true)]
get = getu(sys, sym)
if check_inference
@inferred get(sol)
Expand All @@ -163,13 +163,11 @@ for (sym, ans, check_inference) in [
end
end

for (sym, val) in [
(:a, p[1])
for (sym, val) in [(:a, p[1])
(:b, p[2])
(:c, p[3])
([:a, :b], p[1:2])
((:c, :b), (p[3], p[2]))
]
((:c, :b), (p[3], p[2]))]
get = getu(fi, sym)
@inferred get(fi)
@test get(fi) == val
Expand Down

0 comments on commit 48461a4

Please sign in to comment.