Skip to content

Commit

Permalink
fixup! wip: better parameter indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 10, 2024
1 parent b8139fa commit c7066ec
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 30 deletions.
33 changes: 29 additions & 4 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,50 @@ end
function (gpi::GetParameterIndex)(::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args))
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob)
parameter_values.(
(prob,), (gpi.idx,),
eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob)
gpi.((ts,), (prob,), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob)
for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
buffer[buf_idx] = gpi(ts, prob, ts_idx)
end
buffer
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(prob, gpi.idx, i)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, ::Colon)
gpi(ts, prob)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, ::Colon)
gpi(buffer, ts, prob)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i::AbstractArray{Bool})
map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx
gpi(ts, prob, idx)
end
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool})
for (buf_idx, ts_idx) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,))))
buffer[buf_idx] = gpi(ts, prob, ts_idx)
end
buffer
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i)
gpi.((ts,), (prob,), i)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i)
for (buf_idx, subidx) in zip(eachindex(buffer), i)
buffer[buf_idx] = gpi(ts, prob, subidx)
end
buffer
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::AbstractArray, ::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return GetParameterIndex(p)
Expand All @@ -84,6 +106,9 @@ as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.p
function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...)
gpti.param_timeseries_idx(ts, prob, args...)
end
function (gpti::GetParameterTimeseriesIndex)(buffer::AbstractArray, ts::Timeseries, prob, args...)
gpti.param_timeseries_idx(buffer, ts, prob, args...)
end
function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob)
gpti.param_idx(ts, prob)
end
Expand Down
106 changes: 80 additions & 26 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,55 +166,109 @@ dval = fs.p[4]
bidx = timeseries_parameter_index(sys, :b)
cidx = timeseries_parameter_index(sys, :c)

for (sym, indexer_trait, timeseries_index, val, check_inference) in [
(:a, IndexerNotTimeseries, 0, aval, true),
(1, IndexerNotTimeseries, 0, aval, true),
([:a, :d], IndexerNotTimeseries, 0, [aval, dval], true),
((:a, :d), IndexerNotTimeseries, 0, (aval, dval), true),
([1, 4], IndexerNotTimeseries, 0, [aval, dval], true),
((1, 4), IndexerNotTimeseries, 0, (aval, dval), true),
([:a, 4], IndexerNotTimeseries, 0, [aval, dval], true),
((:a, 4), IndexerNotTimeseries, 0, (aval, dval), true),
(:b, IndexerBoth, 1, bval, true),
(bidx, IndexerTimeseries, 1, bval, true),
([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], true),
((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), true),
([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], true),
((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), true),
([:b, :b], IndexerBoth, 1, vcat.(bval, bval), true),
((:b, :b), IndexerBoth, 1, tuple.(bval, bval), true),
([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), true),
((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), true),
([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), true),
((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), true),
for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [
(:a, IndexerNotTimeseries, 0, aval, nothing, true),
(1, IndexerNotTimeseries, 0, aval, nothing, true),
([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true),
((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true),
([1, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true),
((1, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true),
([:a, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true),
((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true),
(:b, IndexerBoth, 1, bval, zeros(length(bval)), true),
(bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true),
([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true),
((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true),
([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true),
((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true),
([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true),
((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true),
([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true),
((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval),map(_ -> zeros(2), bval), true),
([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true),
((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true),
]
getter = getp(sys, sym)
@test is_indexer_timeseries(getter) isa indexer_trait
if indexer_trait <: Union{IndexerTimeseries, IndexerBoth}
@test indexer_timeseries_index(getter) == timeseries_index
end
test_inplace = buffer !== nothing
test_non_timeseries = indexer_trait !== IndexerTimeseries
if test_inplace && test_non_timeseries
non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end]
non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : deepcopy(buffer[end])
test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray
end
if check_inference
@inferred getter(fs)
if indexer_trait != IndexerTimeseries
if test_inplace
@inferred getter(deepcopy(buffer), fs)
end
if test_non_timeseries
@inferred getter(parameter_values(fs))
if test_inplace && test_non_timeseries_inplace && test_non_timeseries_inplace
@inferred getter(deepcopy(non_timeseries_buffer), parameter_values(fs))
end
end
end
@test getter(fs) == val

if indexer_trait == IndexerTimeseries
@test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs))
else
if test_inplace
tmp = deepcopy(buffer)
getter(tmp, fs)
if val isa Tuple
target = collect(val)
elseif eltype(val) <: Tuple
target = collect.(val)
else
target = val
end
@test tmp == target
end
if test_non_timeseries
non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end]
@test getter(parameter_values(fs)) == non_timeseries_val
if test_inplace && test_non_timeseries && test_non_timeseries_inplace
getter(non_timeseries_buffer, parameter_values(fs))
if non_timeseries_val isa Tuple
target = collect(non_timeseries_val)
else
target = non_timeseries_val
end
@test non_timeseries_buffer == target
end
else
@test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs))
if test_inplace
@test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter([], parameter_values(fs))
end
end
for subidx in [1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2]
if indexer_trait <: IndexerNotTimeseries
@test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter(fs, subidx)
if test_inplace
@test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter([], fs, subidx)
end
else
if check_inference
@inferred getter(fs, subidx)
if test_inplace && buffer[subidx] isa AbstractArray
@inferred getter(deepcopy(buffer[subidx]), fs, subidx)
end
end
@test getter(fs, subidx) == val[subidx]
if test_inplace && buffer[subidx] isa AbstractArray
tmp = deepcopy(buffer[subidx])
getter(tmp, fs, subidx)
if val[subidx] isa Tuple
target = collect(val[subidx])
elseif eltype(val) <: Tuple
target = collect.(val[subidx])
else
target = val[subidx]
end
@test tmp == target
end
end
end
end
Expand Down

0 comments on commit c7066ec

Please sign in to comment.