Skip to content

Commit

Permalink
fix: fix and test getu and getp with empty arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 17, 2024
1 parent e7dd822 commit dd67abb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <:
end

function MultipleParametersGetter(getters)
if isempty(getters)
return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}(getters, nothing)
end
has_timeseries_indexers = any(getters) do g
is_indexer_timeseries(g) == IndexerTimeseries()
end
Expand Down
2 changes: 1 addition & 1 deletion src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ for (t1, t2) in [
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
num_observed = count(x -> is_observed(sys, x), sym)
if num_observed == 0 || num_observed == 1 && sym isa Tuple
if all(Base.Fix1(is_parameter, sys), sym) &&
if !isempty(sym) && all(Base.Fix1(is_parameter, sys), sym) &&
all(!Base.Fix1(is_timeseries_parameter, sys), sym)
GetpAtStateTime(getp(sys, sym))
else
Expand Down
15 changes: 15 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,19 @@ for sys in [
@test buffer == collect(val)
end
end

getter = getp(sys, [])
@test getter(fi) == []
end
end

let
sc = SymbolCache(nothing, nothing, :t)
fi = FakeIntegrator(sc, nothing, 0.0, Ref(0))
getter = getp(sc, [])
@test getter(fi) == []
end

struct MyDiffEqArray
t::Vector{Float64}
u::Vector{Vector{Float64}}
Expand Down Expand Up @@ -387,6 +397,11 @@ end

@test_throws ErrorException getp(sys, :not_a_param)

let fs = fs, sys = sys
getter = getp(sys, [])
@test getter(fs) == []
end

struct FakeNoTimeSolution
sys::SymbolCache
u::Vector{Float64}
Expand Down
18 changes: 18 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ for (sym, val, check_inference) in [
@test get(fi) == val
end

let fi = fi, sys = sys
getter = getu(sys, [])
@test getter(fi) == []
sc = SymbolCache(nothing, [:a, :b], :t)
fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0)
getter = getu(sc, [])
@test getter(fi) == []
end

for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
Expand Down Expand Up @@ -232,6 +241,15 @@ for (sym, val) in [(:a, p[1])
@test get(sol) == val
end

let sol = sol, sys = sys
getter = getu(sys, [])
@test getter(sol) == [[] for _ in 1:length(sol.t)]
sc = SymbolCache(nothing, [:a, :b], :t)
sol = FakeSolution(sys, nothing, [1.0, 2.0], [0.0])
getter = getu(sc, [])
@test getter(sol) == [[]]
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c])
u = [1.0, 2.0, 3.0]
p = [10.0, 20.0, 30.0]
Expand Down

0 comments on commit dd67abb

Please sign in to comment.