Skip to content

Commit

Permalink
Merge pull request #86 from SciML/as/empty-getu
Browse files Browse the repository at this point in the history
fix: fix and test `getu` and `getp` with empty arrays
  • Loading branch information
AayushSabharwal authored Jun 17, 2024
2 parents e7dd822 + 8a1d5f7 commit a328556
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <:
end

function MultipleParametersGetter(getters)
if isempty(getters)
return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}(
getters, nothing)
end
has_timeseries_indexers = any(getters) do g
is_indexer_timeseries(g) == IndexerTimeseries()
end
Expand Down
2 changes: 1 addition & 1 deletion src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ for (t1, t2) in [
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
num_observed = count(x -> is_observed(sys, x), sym)
if num_observed == 0 || num_observed == 1 && sym isa Tuple
if all(Base.Fix1(is_parameter, sys), sym) &&
if !isempty(sym) && all(Base.Fix1(is_parameter, sys), sym) &&
all(!Base.Fix1(is_timeseries_parameter, sys), sym)
GetpAtStateTime(getp(sys, sym))
else
Expand Down
21 changes: 21 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,23 @@ for sys in [
@test buffer == collect(val)
end
end

getter = getp(sys, [])
@test getter(fi) == []
getter = getp(sys, ())
@test getter(fi) == ()
end
end

let
sc = SymbolCache(nothing, nothing, :t)
fi = FakeIntegrator(sc, nothing, 0.0, Ref(0))
getter = getp(sc, [])
@test getter(fi) == []
getter = getp(sc, ())
@test getter(fi) == ()
end

struct MyDiffEqArray
t::Vector{Float64}
u::Vector{Vector{Float64}}
Expand Down Expand Up @@ -387,6 +401,13 @@ end

@test_throws ErrorException getp(sys, :not_a_param)

let fs = fs, sys = sys
getter = getp(sys, [])
@test getter(fs) == []
getter = getp(sys, ())
@test getter(fs) == ()
end

struct FakeNoTimeSolution
sys::SymbolCache
u::Vector{Float64}
Expand Down
26 changes: 26 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,19 @@ for (sym, val, check_inference) in [
@test get(fi) == val
end

let fi = fi, sys = sys
getter = getu(sys, [])
@test getter(fi) == []
getter = getu(sys, ())
@test getter(fi) == ()
sc = SymbolCache(nothing, [:a, :b], :t)
fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0)
getter = getu(sc, [])
@test getter(fi) == []
getter = getu(sc, ())
@test getter(fi) == ()
end

for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
(:b, p[2], 5.0, true)
(:c, p[3], 6.0, true)
Expand Down Expand Up @@ -232,6 +245,19 @@ for (sym, val) in [(:a, p[1])
@test get(sol) == val
end

let sol = sol, sys = sys
getter = getu(sys, [])
@test getter(sol) == [[] for _ in 1:length(sol.t)]
getter = getu(sys, ())
@test getter(sol) == [() for _ in 1:length(sol.t)]
sc = SymbolCache(nothing, [:a, :b], :t)
sol = FakeSolution(sys, [], [1.0, 2.0], [])
getter = getu(sc, [])
@test getter(sol) == []
getter = getu(sc, ())
@test getter(sol) == []
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c])
u = [1.0, 2.0, 3.0]
p = [10.0, 20.0, 30.0]
Expand Down

0 comments on commit a328556

Please sign in to comment.