Skip to content

Commit

Permalink
feat: make getp type-stable, make .ps type-stable, add tests for new …
Browse files Browse the repository at this point in the history
…indexing
  • Loading branch information
AayushSabharwal committed Mar 7, 2024
1 parent cc5da90 commit 8d626dc
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 34 deletions.
80 changes: 48 additions & 32 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,30 +105,35 @@ function getp(sys, p)
end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
_getter = let p = p
return let p = p
function _getter(::NotTimeseries, prob)
parameter_values(prob, p)

Check warning on line 110 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L108-L110

Added lines #L108 - L110 were not covered by tests
end
function _getter(::Timeseries, prob)
parameter_values(prob, p)

Check warning on line 113 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L112-L113

Added lines #L112 - L113 were not covered by tests
end
function _getter(::Timeseries, prob, i)
parameter_values(parameter_values_at_time(prob, i), p)
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(parameter_values_at_time(prob, only(to_indices(parameter_timeseries(prob), (i,)))), p)

Check warning on line 116 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L115-L116

Added lines #L115 - L116 were not covered by tests
end
function _getter(::Timeseries, prob, ::Colon)
parameter_values.((parameter_values_at_time(prob, i) for i in eachindex(parameter_timeseries(prob))), (p,))
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(parameter_values_at_time.((prob,), (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), p)

Check warning on line 119 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
end
end
return let _getter = _getter
function getter(prob, args...)
return _getter(is_timeseries(prob), prob, args...)
function _getter(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (p,))

Check warning on line 122 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L121-L122

Added lines #L121 - L122 were not covered by tests
end
getter = let _getter = _getter
function getter(prob, args...)
return _getter(is_timeseries(prob), prob, args...)

Check warning on line 126 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L124-L126

Added lines #L124 - L126 were not covered by tests
end
end
getter

Check warning on line 129 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L129

Added line #L129 was not covered by tests
end
end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return getp(sys, idx)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, sys, NotSymbolic(), NotSymbolic(), idx)
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)

Check warning on line 136 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
end

for (t1, t2) in [
Expand All @@ -139,43 +144,54 @@ for (t1, t2) in [
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
getters = getp.((sys,), p)

_getter = return let getters = getters
return let getters = getters
function _getter(::NotTimeseries, prob)
map(g -> g(prob), getters)

Check warning on line 149 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L148-L149

Added lines #L148 - L149 were not covered by tests
end
function _getter(::Timeseries, prob)
map(g -> g(prob), getters)

Check warning on line 152 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L151-L152

Added lines #L151 - L152 were not covered by tests
end
function _getter(::Timeseries, prob, i)
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
map(g -> g(prob, i), getters)

Check warning on line 155 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L154-L155

Added lines #L154 - L155 were not covered by tests
end
function _getter(::Timeseries, prob, ::Colon)
[map(g -> g(prob, i), getters) for i in eachindex(parameter_timeseries(prob))]
end
function _getter(buffer, ::NotTimeseries, prob)
map!(g -> g(prob), buffer, getters)
function _getter(::Timeseries, prob, i)
[map(g -> g(prob, j), getters) for j in only(to_indices(parameter_timeseries(prob), (i,)))]

Check warning on line 158 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
end
function _getter(buffer, ::Timeseries, prob)
map!(g -> g(prob), buffer, getters)
function _getter!(buffer, ::NotTimeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer

Check warning on line 164 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L160-L164

Added lines #L160 - L164 were not covered by tests
end
function _getter(buffer, ::Timeseries, prob, i)
map!(g -> g(prob, i), buffer, getters)
function _getter!(buffer, ::Timeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer

Check warning on line 170 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L166-L170

Added lines #L166 - L170 were not covered by tests
end
function _getter(buffer, ::Timeseries, prob, ::Colon)
for (bufi, tsi) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob)))
map!(g -> g(prob, tsi), buffer[bufi], getters)
function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob, i)
end
buffer

Check warning on line 176 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L172-L176

Added lines #L172 - L176 were not covered by tests
end
_getter
end

return let _getter = _getter
function getter(prob, i...)
return _getter(is_timeseries(prob), prob, i...)
function _getter!(buffer, ::Timeseries, prob, i)
for (bufi, tsi) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end

Check warning on line 182 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L178-L182

Added lines #L178 - L182 were not covered by tests
end
buffer
end
function getter(buffer::AbstractArray, prob, i...)
return _getter(buffer, is_timeseries(prob), prob, i...)
_getter, _getter!
getter = let _getter = _getter, _getter! = _getter!
function getter(prob, i...)
return _getter(is_timeseries(prob), prob, i...)

Check warning on line 189 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L186-L189

Added lines #L186 - L189 were not covered by tests
end
function getter(buffer::AbstractArray, prob, i...)
return _getter!(buffer, is_timeseries(prob), prob, i...)

Check warning on line 192 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L191-L192

Added lines #L191 - L192 were not covered by tests
end
getter

Check warning on line 194 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L194

Added line #L194 was not covered by tests
end
getter

Check warning on line 196 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L196

Added line #L196 was not covered by tests
end
Expand Down
4 changes: 2 additions & 2 deletions src/parameter_indexing_proxy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ struct ParameterIndexingProxy{T}
wrapped::T
end

function Base.getindex(p::ParameterIndexingProxy, idx)
return getp(p.wrapped, idx)(p.wrapped)
function Base.getindex(p::ParameterIndexingProxy, idx, args...)
getp(p.wrapped, idx)(p.wrapped, args...)

Check warning on line 14 in src/parameter_indexing_proxy.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing_proxy.jl#L13-L14

Added lines #L13 - L14 were not covered by tests
end

function Base.setindex!(p::ParameterIndexingProxy, val, idx)
Expand Down
78 changes: 78 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ struct FakeIntegrator{S, P}
p::P
end

function Base.getproperty(fi::FakeIntegrator, s::Symbol)
s === :ps ? ParameterIndexingProxy(fi) : getfield(fi, s)
end
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
Expand All @@ -33,6 +37,7 @@ for (sym, oldval, newval, check_inference) in [
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval
if check_inference
@inferred set!(fi, newval)
Expand All @@ -43,6 +48,11 @@ for (sym, oldval, newval, check_inference) in [
set!(fi, oldval)
@test get(fi) == oldval

fi.ps[sym] = newval
@test get(fi) == newval
fi.ps[sym] = oldval
@test get(fi) == oldval

if check_inference
@inferred get(p)
end
Expand All @@ -68,3 +78,71 @@ for (sym, val) in [
@inferred get(buffer, fi)
@test buffer == val
end

struct FakeSolution
sys::SymbolCache
u::Vector{Vector{Float64}}
t::Vector{Float64}
p::Vector{Vector{Float64}}
pt::Vector{Float64}
end

function Base.getproperty(fs::FakeSolution, s::Symbol)
s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s)
end
SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys
SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end]
SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i]
function SymbolicIndexingInterface.parameter_values_at_time(fs::FakeSolution, t)
fs.p[t]
end
function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSolution, t)
ptind = searchsortedfirst(fs.pt, fs.t[t])
fs.p[ptind - 1]
end
SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution) = fs.pt
SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries()
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
fs = FakeSolution(
sys,
[i * ones(3) for i in 1:5],
[0.2i for i in 1:5],
[2i * ones(3) for i in 1:10],
[0.1i for i in 1:10],
)
ps = fs.p
p = fs.p[end]
avals = getindex.(ps, 1)
bvals = getindex.(ps, 2)
cvals = getindex.(ps, 3)
@test parameter_timeseries(fs) == fs.pt
for (sym, val, arrval, check_inference) in [
(:a, p[1], avals, true),
(1, p[1], avals, true),
([:a, :b], p[1:2], vcat.(avals, bvals), true),
(1:2, p[1:2], vcat.(avals, bvals), true),
((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true),
([:a, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false),
((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true),
([1, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false),
((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true)
]
get = getp(sys, sym)
if check_inference
@inferred get(fs)
end
@test get(fs) == fs.ps[sym]
@test get(fs) == val

for sub_inds in [:, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))]
if check_inference
@inferred get(fs, sub_inds)
end
@test get(fs, sub_inds) == fs.ps[sym, sub_inds]
@test get(fs, sub_inds) == arrval[sub_inds]
end
end

0 comments on commit 8d626dc

Please sign in to comment.