From c7066ec0d8c3cf5973d9dd56e2c133c505e556e2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 May 2024 19:00:34 +0530 Subject: [PATCH] fixup! wip: better parameter indexing --- src/parameter_indexing.jl | 33 ++++++++-- test/parameter_indexing_test.jl | 106 ++++++++++++++++++++++++-------- 2 files changed, 109 insertions(+), 30 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index f8d2299..a12735c 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -44,10 +44,14 @@ end function (gpi::GetParameterIndex)(::Timeseries, prob, args) throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args)) end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob) - parameter_values.( - (prob,), (gpi.idx,), - eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob) + gpi.((ts,), (prob,), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob) + for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) + buffer[buf_idx] = gpi(ts, prob, ts_idx) + end + buffer end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob, i::Union{Int, CartesianIndex}) parameter_values(prob, gpi.idx, i) @@ -55,17 +59,35 @@ end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, ::Colon) gpi(ts, prob) end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, ::Colon) + gpi(buffer, ts, prob) +end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i::AbstractArray{Bool}) map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx gpi(ts, prob, idx) end end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + for (buf_idx, ts_idx) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) + buffer[buf_idx] = gpi(ts, prob, ts_idx) + end + buffer +end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i) gpi.((ts,), (prob,), i) end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i) + for (buf_idx, subidx) in zip(eachindex(buffer), i) + buffer[buf_idx] = gpi(ts, prob, subidx) + end + buffer +end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi)) end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::AbstractArray, ::NotTimeseries, prob) + throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi)) +end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return GetParameterIndex(p) @@ -84,6 +106,9 @@ as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.p function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...) gpti.param_timeseries_idx(ts, prob, args...) end +function (gpti::GetParameterTimeseriesIndex)(buffer::AbstractArray, ts::Timeseries, prob, args...) + gpti.param_timeseries_idx(buffer, ts, prob, args...) +end function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob) gpti.param_idx(ts, prob) end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index a21736d..077853b 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -166,55 +166,109 @@ dval = fs.p[4] bidx = timeseries_parameter_index(sys, :b) cidx = timeseries_parameter_index(sys, :c) -for (sym, indexer_trait, timeseries_index, val, check_inference) in [ - (:a, IndexerNotTimeseries, 0, aval, true), - (1, IndexerNotTimeseries, 0, aval, true), - ([:a, :d], IndexerNotTimeseries, 0, [aval, dval], true), - ((:a, :d), IndexerNotTimeseries, 0, (aval, dval), true), - ([1, 4], IndexerNotTimeseries, 0, [aval, dval], true), - ((1, 4), IndexerNotTimeseries, 0, (aval, dval), true), - ([:a, 4], IndexerNotTimeseries, 0, [aval, dval], true), - ((:a, 4), IndexerNotTimeseries, 0, (aval, dval), true), - (:b, IndexerBoth, 1, bval, true), - (bidx, IndexerTimeseries, 1, bval, true), - ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], true), - ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), true), - ([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], true), - ((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), true), - ([:b, :b], IndexerBoth, 1, vcat.(bval, bval), true), - ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), true), - ([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), true), - ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), true), - ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), true), - ((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), true), +for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ + (:a, IndexerNotTimeseries, 0, aval, nothing, true), + (1, IndexerNotTimeseries, 0, aval, nothing, true), + ([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + ([1, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((1, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + ([:a, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + (:b, IndexerBoth, 1, bval, zeros(length(bval)), true), + (bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true), + ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval),map(_ -> zeros(2), bval), true), + ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), ] getter = getp(sys, sym) @test is_indexer_timeseries(getter) isa indexer_trait if indexer_trait <: Union{IndexerTimeseries, IndexerBoth} @test indexer_timeseries_index(getter) == timeseries_index end + test_inplace = buffer !== nothing + test_non_timeseries = indexer_trait !== IndexerTimeseries + if test_inplace && test_non_timeseries + non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] + non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : deepcopy(buffer[end]) + test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray + end if check_inference @inferred getter(fs) - if indexer_trait != IndexerTimeseries + if test_inplace + @inferred getter(deepcopy(buffer), fs) + end + if test_non_timeseries @inferred getter(parameter_values(fs)) + if test_inplace && test_non_timeseries_inplace && test_non_timeseries_inplace + @inferred getter(deepcopy(non_timeseries_buffer), parameter_values(fs)) + end end end @test getter(fs) == val - - if indexer_trait == IndexerTimeseries - @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) - else + if test_inplace + tmp = deepcopy(buffer) + getter(tmp, fs) + if val isa Tuple + target = collect(val) + elseif eltype(val) <: Tuple + target = collect.(val) + else + target = val + end + @test tmp == target + end + if test_non_timeseries non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] @test getter(parameter_values(fs)) == non_timeseries_val + if test_inplace && test_non_timeseries && test_non_timeseries_inplace + getter(non_timeseries_buffer, parameter_values(fs)) + if non_timeseries_val isa Tuple + target = collect(non_timeseries_val) + else + target = non_timeseries_val + end + @test non_timeseries_buffer == target + end + else + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) + if test_inplace + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter([], parameter_values(fs)) + end end for subidx in [1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] if indexer_trait <: IndexerNotTimeseries @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter(fs, subidx) + if test_inplace + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter([], fs, subidx) + end else if check_inference @inferred getter(fs, subidx) + if test_inplace && buffer[subidx] isa AbstractArray + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) + end end @test getter(fs, subidx) == val[subidx] + if test_inplace && buffer[subidx] isa AbstractArray + tmp = deepcopy(buffer[subidx]) + getter(tmp, fs, subidx) + if val[subidx] isa Tuple + target = collect(val[subidx]) + elseif eltype(val) <: Tuple + target = collect.(val[subidx]) + else + target = val[subidx] + end + @test tmp == target + end end end end