Skip to content

Commit

Permalink
refactor: format
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 7, 2024
1 parent 8d626dc commit de0ac6e
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ include("interface.jl")
export SymbolCache
include("symbol_cache.jl")

export parameter_values, set_parameter!, parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp, setp
export parameter_values, set_parameter!, parameter_values_at_time,
parameter_values_at_state_time, parameter_timeseries, getp, setp
include("parameter_indexing.jl")

export state_values, set_state!, current_time, getu, setu
Expand Down
19 changes: 14 additions & 5 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,16 @@ function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
parameter_values(prob, p)

Check warning on line 113 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L112-L113

Added lines #L112 - L113 were not covered by tests
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(parameter_values_at_time(prob, only(to_indices(parameter_timeseries(prob), (i,)))), p)
parameter_values(

Check warning on line 116 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L115-L116

Added lines #L115 - L116 were not covered by tests
parameter_values_at_time(
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
p)
end
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(parameter_values_at_time.((prob,), (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), p)
parameter_values.(

Check warning on line 122 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L121-L122

Added lines #L121 - L122 were not covered by tests
parameter_values_at_time.((prob,),
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
p)
end
function _getter(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (p,))

Check warning on line 128 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L127-L128

Added lines #L127 - L128 were not covered by tests
Expand All @@ -132,7 +138,8 @@ end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, sys, NotSymbolic(), NotSymbolic(), idx)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},

Check warning on line 141 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L141

Added line #L141 was not covered by tests
sys, NotSymbolic(), NotSymbolic(), idx)
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)

Check warning on line 143 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L143

Added line #L143 was not covered by tests
end

Expand All @@ -155,7 +162,8 @@ for (t1, t2) in [
map(g -> g(prob, i), getters)

Check warning on line 162 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L161-L162

Added lines #L161 - L162 were not covered by tests
end
function _getter(::Timeseries, prob, i)
[map(g -> g(prob, j), getters) for j in only(to_indices(parameter_timeseries(prob), (i,)))]
[map(g -> g(prob, j), getters)

Check warning on line 165 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L164-L165

Added lines #L164 - L165 were not covered by tests
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
end
function _getter!(buffer, ::NotTimeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
Expand All @@ -176,7 +184,8 @@ for (t1, t2) in [
buffer

Check warning on line 184 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L180-L184

Added lines #L180 - L184 were not covered by tests
end
function _getter!(buffer, ::Timeseries, prob, i)
for (bufi, tsi) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (bufi, tsi) in zip(

Check warning on line 187 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L186-L187

Added lines #L186 - L187 were not covered by tests
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end

Check warning on line 191 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L189-L191

Added lines #L189 - L191 were not covered by tests
Expand Down
4 changes: 3 additions & 1 deletion src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ for (t1, t2) in [
end
function _getter2a(::Timeseries, prob)
curtime = current_time(prob)
obs.(state_values(prob), (parameter_values_at_state_time(prob, i) for i in eachindex(curtime)),
obs.(state_values(prob),

Check warning on line 188 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L187-L188

Added lines #L187 - L188 were not covered by tests
(parameter_values_at_state_time(prob, i)
for i in eachindex(curtime)),
curtime)
end
function _getter2a(::Timeseries, prob, i)
Expand Down
1 change: 0 additions & 1 deletion src/trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,3 @@ function is_timeseries end

is_timeseries(x) = is_timeseries(typeof(x))
is_timeseries(::Type) = NotTimeseries()

Check warning on line 103 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L102-L103

Added lines #L102 - L103 were not covered by tests

11 changes: 7 additions & 4 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ fs = FakeSolution(
[i * ones(3) for i in 1:5],
[0.2i for i in 1:5],
[2i * ones(3) for i in 1:10],
[0.1i for i in 1:10],
[0.1i for i in 1:10]
)
ps = fs.p
p = fs.p[end]
Expand All @@ -122,11 +122,13 @@ for (sym, val, arrval, check_inference) in [
([:a, :b], p[1:2], vcat.(avals, bvals), true),
(1:2, p[1:2], vcat.(avals, bvals), true),
((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true),
([:a, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false),
([:a, [:b, :c]], [p[1], p[2:3]],
[[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false),
((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true),
([1, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false),
([1, [:b, :c]], [p[1], p[2:3]],
[[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false),
((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true)
Expand All @@ -138,7 +140,8 @@ for (sym, val, arrval, check_inference) in [
@test get(fs) == fs.ps[sym]
@test get(fs) == val

for sub_inds in [:, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))]
for sub_inds in [
:, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))]
if check_inference
@inferred get(fs, sub_inds)
end
Expand Down

0 comments on commit de0ac6e

Please sign in to comment.