From dd67abbc3dbc4d083ce3887931c3da6079fb419d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 17 Jun 2024 11:06:56 +0530 Subject: [PATCH] fix: fix and test `getu` and `getp` with empty arrays --- src/parameter_indexing.jl | 3 +++ src/state_indexing.jl | 2 +- test/parameter_indexing_test.jl | 15 +++++++++++++++ test/state_indexing_test.jl | 18 ++++++++++++++++++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index de2d6caf..ed793762 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -323,6 +323,9 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: end function MultipleParametersGetter(getters) + if isempty(getters) + return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}(getters, nothing) + end has_timeseries_indexers = any(getters) do g is_indexer_timeseries(g) == IndexerTimeseries() end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 25d18666..b17e417e 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -236,7 +236,7 @@ for (t1, t2) in [ @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) num_observed = count(x -> is_observed(sys, x), sym) if num_observed == 0 || num_observed == 1 && sym isa Tuple - if all(Base.Fix1(is_parameter, sys), sym) && + if !isempty(sym) && all(Base.Fix1(is_parameter, sys), sym) && all(!Base.Fix1(is_timeseries_parameter, sys), sym) GetpAtStateTime(getp(sys, sym)) else diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 5b392ab1..c32b4e37 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -123,9 +123,19 @@ for sys in [ @test buffer == collect(val) end end + + getter = getp(sys, []) + @test getter(fi) == [] end end +let + sc = SymbolCache(nothing, nothing, :t) + fi = FakeIntegrator(sc, nothing, 0.0, Ref(0)) + getter = getp(sc, []) + @test getter(fi) == [] +end + struct MyDiffEqArray t::Vector{Float64} u::Vector{Vector{Float64}} @@ -387,6 +397,11 @@ end @test_throws ErrorException getp(sys, :not_a_param) +let fs = fs, sys = sys + getter = getp(sys, []) + @test getter(fs) == [] +end + struct FakeNoTimeSolution sys::SymbolCache u::Vector{Float64} diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index d8392636..d4a3f6ac 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -91,6 +91,15 @@ for (sym, val, check_inference) in [ @test get(fi) == val end +let fi = fi, sys = sys + getter = getu(sys, []) + @test getter(fi) == [] + sc = SymbolCache(nothing, [:a, :b], :t) + fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0) + getter = getu(sc, []) + @test getter(fi) == [] +end + for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) (:b, p[2], 5.0, true) (:c, p[3], 6.0, true) @@ -232,6 +241,15 @@ for (sym, val) in [(:a, p[1]) @test get(sol) == val end +let sol = sol, sys = sys + getter = getu(sys, []) + @test getter(sol) == [[] for _ in 1:length(sol.t)] + sc = SymbolCache(nothing, [:a, :b], :t) + sol = FakeSolution(sys, nothing, [1.0, 2.0], [0.0]) + getter = getu(sc, []) + @test getter(sol) == [[]] +end + sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) u = [1.0, 2.0, 3.0] p = [10.0, 20.0, 30.0]