From 0e3b04757afef18d59f7623d6d092a30db748c45 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 May 2024 14:00:39 +0530 Subject: [PATCH] test: test new parameter indexing, SymbolCache, ParameterTimeseriesCollection --- test/example_test.jl | 3 + test/parameter_indexing_test.jl | 436 +++++++++++++------ test/parameter_timeseries_collection_test.jl | 47 ++ test/runtests.jl | 3 + test/state_indexing_test.jl | 64 ++- test/symbol_cache_test.jl | 21 +- 6 files changed, 436 insertions(+), 138 deletions(-) create mode 100644 test/parameter_timeseries_collection_test.jl diff --git a/test/example_test.jl b/test/example_test.jl index 5d3caa0..521de32 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -71,6 +71,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test all(.!is_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) @test all(parameter_index.((sys,), [:c, :a, :b]) .== [3, 1, 2]) @test all(parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) +@test all(.!is_timeseries_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) # fallback even if not implemented +@test all(timeseries_parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) # fallback @test is_independent_variable(sys, :t) @test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) @test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t])) @@ -88,6 +90,7 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test independent_variable_symbols(sys) == [:t] @test all_variable_symbols(sys) == [:x, :y, :z] @test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z] +@test default_values(sys) == Dict() # fallback even if not implemented sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 67dff5a..3bebd08 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -1,6 +1,16 @@ using SymbolicIndexingInterface +using SymbolicIndexingInterface: IndexerTimeseries, IndexerNotTimeseries, IndexerBoth, + is_indexer_timeseries, indexer_timeseries_index, + ParameterTimeseriesValueIndexMismatchError, + MixedParameterTimeseriesIndexError using Test +arr = [1.0, 2.0, 3.0] +@test parameter_values(arr) == arr +@test current_time(arr) == arr +tp = (1.0, 2.0, 3.0) +@test parameter_values(tp) == tp + struct FakeIntegrator{S, P} sys::S p::P @@ -16,95 +26,115 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, fi.counter[] += 1 end -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -for pType in [Vector, Tuple] - p = [1.0, 2.0, 3.0] - fi = FakeIntegrator(sys, pType(copy(p)), Ref(0)) - new_p = [4.0, 5.0, 6.0] - @test parameter_timeseries(fi) == [0] - for (sym, oldval, newval, check_inference) in [ - (:a, p[1], new_p[1], true), - (1, p[1], new_p[1], true), - ([:a, :b], p[1:2], new_p[1:2], true), - (1:2, p[1:2], new_p[1:2], true), - ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), - ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), - ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) - ] - get = getp(sys, sym) - set! = setp(sys, sym) - if check_inference - @inferred get(fi) +for sys in [ + SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t]), + SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + [:t], + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) +] + has_ts = sys.timeseries_parameters !== nothing + for pType in [Vector, Tuple] + p = [1.0, 2.0, 3.0, 4.0] + fi = FakeIntegrator(sys, pType(copy(p)), Ref(0)) + new_p = [4.0, 5.0, 6.0, 7.0] + for i in [7, CartesianIndex(5)] + @test parameter_values_at_state_time(fi, i) == parameter_values(fi) end - @test get(fi) == fi.ps[sym] - @test get(fi) == oldval + for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], !has_ts), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) + ] + get = getp(sys, sym) + set! = setp(sys, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == fi.ps[sym] + @test get(fi) == oldval - if pType === Tuple - @test_throws MethodError set!(fi, newval) - continue - end + if pType === Tuple + @test_throws MethodError set!(fi, newval) + continue + end - @test fi.counter[] == 0 - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) - end - @test fi.counter[] == 1 + @test fi.counter[] == 0 + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test fi.counter[] == 1 - @test get(fi) == newval - set!(fi, oldval) - @test get(fi) == oldval - @test fi.counter[] == 2 + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval + @test fi.counter[] == 2 - fi.ps[sym] = newval - @test get(fi) == newval - @test fi.counter[] == 3 - fi.ps[sym] = oldval - @test get(fi) == oldval - @test fi.counter[] == 4 + fi.ps[sym] = newval + @test get(fi) == newval + @test fi.counter[] == 3 + fi.ps[sym] = oldval + @test get(fi) == oldval + @test fi.counter[] == 4 - if check_inference - @inferred get(p) + if check_inference + @inferred get(p) + end + @test get(p) == oldval + if check_inference + @inferred set!(p, newval) + else + set!(p, newval) + end + @test get(p) == newval + set!(p, oldval) + @test get(p) == oldval + @test fi.counter[] == 4 + fi.counter[] = 0 end - @test get(p) == oldval - if check_inference - @inferred set!(p, newval) - else - set!(p, newval) + + for (sym, val) in [ + ([:a, :b, :c, :d], p), + ([:c, :a], p[[3, 1]]), + ((:b, :a), p[[2, 1]]), + ((1, :c), p[[1, 3]]) + ] + buffer = zeros(length(sym)) + get = getp(sys, sym) + @inferred get(buffer, fi) + @test buffer == val end - @test get(p) == newval - set!(p, oldval) - @test get(p) == oldval - @test fi.counter[] == 4 - fi.counter[] = 0 end +end - for (sym, val) in [ - ([:a, :b, :c], p), - ([:c, :a], p[[3, 1]]), - ((:b, :a), p[[2, 1]]), - ((1, :c), p[[1, 3]]) - ] - buffer = zeros(length(sym)) - get = getp(sys, sym) - @inferred get(buffer, fi) - @test buffer == val - end +struct MyDiffEqArray + t::Vector{Float64} + u::Vector{Vector{Float64}} end +SymbolicIndexingInterface.current_time(mda::MyDiffEqArray) = mda.t +SymbolicIndexingInterface.state_values(mda::MyDiffEqArray) = mda.u +SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries() struct FakeSolution sys::SymbolCache u::Vector{Vector{Float64}} t::Vector{Float64} - p::Vector{Vector{Float64}} - pt::Vector{Float64} + p::Vector{Float64} + p_idxs::Vector{Vector{Int}} + p_ts::ParameterTimeseriesCollection{Vector{MyDiffEqArray}} end function Base.getproperty(fs::FakeSolution, s::Symbol) @@ -113,79 +143,227 @@ end SymbolicIndexingInterface.state_values(fs::FakeSolution) = fs.u SymbolicIndexingInterface.current_time(fs::FakeSolution) = fs.t SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys -SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end] -SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i] -function SymbolicIndexingInterface.parameter_values_at_time(fs::FakeSolution, t) - fs.p[t] +SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p +SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[i] +function SymbolicIndexingInterface.parameter_values( + fs::FakeSolution, i::ParameterTimeseriesIndex, j) + parameter_values(fs.p_ts, i, j) end function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSolution, t) - ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=) - fs.p[ptind - 1] + state_time = fs.t[t] + p = copy(fs.p) + for (i, p_idxs) in enumerate(fs.p_idxs) + p_times = parameter_timeseries(fs, i) + p_timeseries_idx = searchsortedlast(p_times, state_time) + p[p_idxs] = fs.p_ts[i, p_timeseries_idx] + end + return p +end +function SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution, idx) + parameter_timeseries(fs.p_ts, idx) end -SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution) = fs.pt SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries() SymbolicIndexingInterface.is_parameter_timeseries(::Type{FakeSolution}) = Timeseries() -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +sys = SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + :t; + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) +b_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i] for i in 1:10]) +c_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i] for i in 1:4]) fs = FakeSolution( sys, [i * ones(3) for i in 1:5], [0.2i for i in 1:5], - [2i * ones(3) for i in 1:10], - [0.1i for i in 1:10] + [20.0, b_timeseries.u[end][1], c_timeseries.u[end][1], 30.0], + [[2], [3]], + ParameterTimeseriesCollection([b_timeseries, c_timeseries]) ) -ps = fs.p -p = fs.p[end] -avals = getindex.(ps, 1) -bvals = getindex.(ps, 2) -cvals = getindex.(ps, 3) -@test parameter_timeseries(fs) == fs.pt -for (sym, val, arrval, check_inference) in [ - (:a, p[1], avals, true), - (1, p[1], avals, true), - ([:a, :b], p[1:2], vcat.(avals, bvals), true), - (1:2, p[1:2], vcat.(avals, bvals), true), - ((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true), - ([:a, [:b, :c]], [p[1], p[2:3]], - [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), - ((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true), - ([1, [:b, :c]], [p[1], p[2:3]], - [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), - ((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true) +aval = fs.p[1] +bval = getindex.(b_timeseries.u) +cval = getindex.(c_timeseries.u) +dval = fs.p[4] +bidx = timeseries_parameter_index(sys, :b) +cidx = timeseries_parameter_index(sys, :c) + +for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ + (:a, IndexerNotTimeseries, 0, aval, nothing, true), + (1, IndexerNotTimeseries, 0, aval, nothing, true), + ([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + ([1, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((1, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + ([:a, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + (:b, IndexerBoth, 1, bval, zeros(length(bval)), true), + (bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true), + ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, bidx), IndexerTimeseries, 1, + tuple.(bval, bval), map(_ -> zeros(2), bval), true) ] - get = getp(sys, sym) + getter = getp(sys, sym) + @test is_indexer_timeseries(getter) isa indexer_trait + if indexer_trait <: Union{IndexerTimeseries, IndexerBoth} + @test indexer_timeseries_index(getter) == timeseries_index + end + test_inplace = buffer !== nothing + test_non_timeseries = indexer_trait !== IndexerTimeseries + if test_inplace && test_non_timeseries + non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] + non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : + deepcopy(buffer[end]) + test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray + end if check_inference - @inferred get(fs) + @inferred getter(fs) + if test_inplace + @inferred getter(deepcopy(buffer), fs) + end + if test_non_timeseries + @inferred getter(parameter_values(fs)) + if test_inplace && test_non_timeseries_inplace && test_non_timeseries_inplace + @inferred getter(deepcopy(non_timeseries_buffer), parameter_values(fs)) + end + end end - @test get(fs) == fs.ps[sym] - @test get(fs) == val - - for sub_inds in [ - :, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))] - if check_inference - @inferred get(fs, sub_inds) + @test getter(fs) == val + if test_inplace + tmp = deepcopy(buffer) + getter(tmp, fs) + if val isa Tuple + target = collect(val) + elseif eltype(val) <: Tuple + target = collect.(val) + else + target = val + end + @test tmp == target + end + if test_non_timeseries + non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] + @test getter(parameter_values(fs)) == non_timeseries_val + if test_inplace && test_non_timeseries && test_non_timeseries_inplace + getter(non_timeseries_buffer, parameter_values(fs)) + if non_timeseries_val isa Tuple + target = collect(non_timeseries_val) + else + target = non_timeseries_val + end + @test non_timeseries_buffer == target + end + else + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) + if test_inplace + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( + [], parameter_values(fs)) end - @test get(fs, sub_inds) == fs.ps[sym, sub_inds] - @test get(fs, sub_inds) == arrval[sub_inds] end + for subidx in [ + 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + if indexer_trait <: IndexerNotTimeseries + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( + fs, subidx) + if test_inplace + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( + [], fs, subidx) + end + else + if check_inference + @inferred getter(fs, subidx) + if test_inplace && buffer[subidx] isa AbstractArray + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) + end + end + @test getter(fs, subidx) == val[subidx] + if test_inplace && buffer[subidx] isa AbstractArray + tmp = deepcopy(buffer[subidx]) + getter(tmp, fs, subidx) + if val[subidx] isa Tuple + target = collect(val[subidx]) + elseif eltype(val) <: Tuple + target = collect.(val[subidx]) + else + target = val[subidx] + end + @test tmp == target + end + end + end +end + +for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx), + [bidx, :c], (bidx, :c), [bidx, cidx], (bidx, cidx)] + @test_throws ArgumentError getp(sys, sym) +end + +for (sym, val) in [ + ([:b, :c], [bval[end], cval[end]]), + ((:b, :c), (bval[end], cval[end])) +] + getter = getp(sys, sym) + @test is_indexer_timeseries(getter) == IndexerNotTimeseries() + @test_throws MixedParameterTimeseriesIndexError getter(fs) + @test getter(parameter_values(fs)) == val end -ps = fs.p[2:2:end] -avals = getindex.(ps, 1) -bvals = getindex.(ps, 2) -cvals = getindex.(ps, 3) -for (sym, val, arrval) in [ - (:a, p[1], avals), - ((:b, :c), p[2:3], tuple.(bvals, cvals)), - ([:c, :a], p[[3, 1]], vcat.(cvals, avals)) +bval_state = [b_timeseries.u[searchsortedlast(b_timeseries.t, t)][] for t in fs.t] +cval_state = [c_timeseries.u[searchsortedlast(c_timeseries.t, t)][] for t in fs.t] +xval = getindex.(fs.u, 1) + +for (sym, val_is_timeseries, val, check_inference) in [ + (:a, false, aval, true), + ([:a, :d], false, [aval, dval], true), + ((:a, :d), false, (aval, dval), true), + (:b, true, bval_state, true), + ([:a, :b], true, vcat.(aval, bval_state), false), + ((:a, :b), true, tuple.(aval, bval_state), true), + ([:b, :c], true, vcat.(bval_state, cval_state), true), + ((:b, :c), true, tuple.(bval_state, cval_state), true), + ([:a, :b, :c], true, vcat.(aval, bval_state, cval_state), false), + ((:a, :b, :c), true, tuple.(aval, bval_state, cval_state), true), + ([:x, :b], true, vcat.(xval, bval_state), false), + ((:x, :b), true, tuple.(xval, bval_state), true), + ([:x, :b, :c], true, vcat.(xval, bval_state, cval_state), false), + ((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true), + ([:a, :b, :x], true, vcat.(aval, bval_state, xval), false), + ((:a, :b, :x), true, tuple.(aval, bval_state, xval), true), + (:(2b), true, 2 .* bval_state, true), + ([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true), + ((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true) ] - get = getu(sys, sym) - @inferred get(fs) - @test get(fs) == arrval - for i in eachindex(ps) - @test get(fs, i) == arrval[i] + getter = getu(sys, sym) + if val isa DataType + @test_throws val getter(fs) + continue + end + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val + + for subidx in [ + 1, CartesianIndex(2), :, rand(Bool, length(fs.t)), rand(eachindex(fs.t), 3), 1:2] + if check_inference + @inferred getter(fs, subidx) + end + target = if val_is_timeseries + val[subidx] + else + if fs.t[subidx] isa AbstractArray + len = length(fs.t[subidx]) + fill(val, len) + else + val + end + end + @test getter(fs, subidx) == target end end diff --git a/test/parameter_timeseries_collection_test.jl b/test/parameter_timeseries_collection_test.jl new file mode 100644 index 0000000..4fcb11e --- /dev/null +++ b/test/parameter_timeseries_collection_test.jl @@ -0,0 +1,47 @@ +using SymbolicIndexingInterface +using Test + +struct MyDiffEqArray + t::Vector{Float64} + u::Vector{Vector{Float64}} +end +SymbolicIndexingInterface.current_time(mda::MyDiffEqArray) = mda.t +SymbolicIndexingInterface.state_values(mda::MyDiffEqArray) = mda.u +SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries() + +@test_throws ArgumentError ParameterTimeseriesCollection((ones(3), 2ones(3))) + +a_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i, sin(0.2i)] for i in 1:10]) +b_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i, log(1.3i)] for i in 1:4]) +c_timeseries = MyDiffEqArray(collect(0:0.17:0.90), [[4.3i] for i in 1:5]) +collection = (a_timeseries, b_timeseries, c_timeseries) +ptc = ParameterTimeseriesCollection(collection) + +@test collect(eachindex(ptc)) == [1, 2, 3] +@test [x for x in ptc] == [a_timeseries, b_timeseries, c_timeseries] +@test length(ptc) == 3 +@test parent(ptc) === collection + +for i in 1:3 + @test ptc[i] === collection[i] + @test parameter_timeseries(ptc, i) == collection[i].t + for j in eachindex(collection[i].u[1]) + pti = ParameterTimeseriesIndex(i, j) + @test ptc[pti] == getindex.(collection[i].u, j) + for k in eachindex(collection[i].u) + rhs = collection[i].u[k][j] + @test ptc[pti, CartesianIndex(k)] == rhs + @test ptc[pti, k] == rhs + @test ptc[i, k] == collection[i].u[k] + @test ptc[i, k, j] == rhs + @test parameter_values(ptc, pti, k) == rhs + end + allidxs = eachindex(collection[i].u) + for subidx in [:, rand(allidxs, 3), rand(Bool, length(allidxs))] + rhs = getindex.(collection[i].u[subidx], j) + @test ptc[pti, subidx] == rhs + @test ptc[i, subidx, j] == rhs + @test parameter_values(ptc, pti, subidx) == rhs + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a295448..eaf2ebe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,9 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Fallback test" begin @time include("fallback_test.jl") end + @safetestset "ParameterTimeseriesCollection test" begin + @time include("parameter_timeseries_collection_test.jl") + end @safetestset "Parameter indexing test" begin @time include("parameter_indexing_test.jl") end diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 16c8a71..d839263 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -13,6 +13,10 @@ SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) + +@test_throws ErrorException getu(sys, :q) +@test_throws ErrorException setu(sys, :q) + u = [1.0, 2.0, 3.0] p = [11.0, 12.0, 13.0] t = 0.5 @@ -130,14 +134,18 @@ struct FakeSolution{S, U, P, T} end SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() +function SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution{ + S, U, P, Nothing}}) where {S, U, P} + NotTimeseries() +end SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -t = [1.5, 2.0] +u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]] +t = [1.5, 2.0, 2.3, 4.0] sol = FakeSolution(sys, u, p, t) xvals = getindex.(sol.u, 1) @@ -150,7 +158,7 @@ for (sym, ans, check_inference) in [(:x, xvals, true) (1, xvals, true) ([:x, :y], vcat.(xvals, yvals), true) (1:2, vcat.(xvals, yvals), true) - ([:x, 2], vcat.(xvals, yvals), false) + ([:x, 2], vcat.(xvals, yvals), true) ((:z, :y), tuple.(zvals, yvals), true) ((3, 2), tuple.(zvals, yvals), true) ([:x, [:y, :z]], @@ -186,7 +194,8 @@ for (sym, ans, check_inference) in [(:x, xvals, true) @inferred get(sol) end @test get(sol) == ans - for i in eachindex(u) + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] if check_inference @inferred get(sol, i) end @@ -204,6 +213,13 @@ for (sym, val, check_inference) in [ @inferred get(sol) end @test get(sol) == val + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + if check_inference + @inferred get(sol, i) + end + @test get(sol, i) == val[i] + end end for (sym, val) in [(:a, p[1]) @@ -211,7 +227,41 @@ for (sym, val) in [(:a, p[1]) (:c, p[3]) ([:a, :b], p[1:2]) ((:c, :b), (p[3], p[2]))] - get = getu(fi, sym) - @inferred get(fi) - @test get(fi) == val + get = getu(sys, sym) + @inferred get(sol) + @test get(sol) == val +end + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) +u = [1.0, 2.0, 3.0] +p = [10.0, 20.0, 30.0] +fs = FakeSolution(sys, u, p, nothing) +@test is_timeseries(fs) == NotTimeseries() + +for (sym, val, check_inference) in [ + (:x, u[1], true), + (1, u[1], true), + ([:x, :y], u[1:2], true), + ((:x, :y), Tuple(u[1:2]), true), + (1:2, u[1:2], true), + ([:x, 2], u[1:2], true), + ((:x, 2), Tuple(u[1:2]), true), + ([1, 2], u[1:2], true), + ((1, 2), Tuple(u[1:2]), true), + (:a, p[1], true), + ([:a, :b], p[1:2], true), + ((:a, :b), Tuple(p[1:2]), true), + ([:x, :a], [u[1], p[1]], false), + ((:x, :a), (u[1], p[1]), true), + ([1, :a], [u[1], p[1]], false), + ((1, :a), (u[1], p[1]), true), + (:(x + y + a + b), u[1] + u[2] + p[1] + p[2], true), + ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), + ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) +] + getter = getu(sys, sym) + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index f8a861b..136bb0b 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -16,10 +16,10 @@ sc = SymbolCache( @test is_time_dependent(sc) @test constant_structure(sc) @test variable_symbols(sc) == [:x, :y, :z] -@test parameter_symbols(sc) == [:a, :b] +@test sort(parameter_symbols(sc)) == [:a, :b] @test independent_variable_symbols(sc) == [:t] @test all_variable_symbols(sc) == [:x, :y, :z] -@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z] +@test sort(sort(all_symbols(sc))) == [:a, :b, :t, :x, :y, :z] @test default_values(sc)[:x] == 1 @test default_values(sc)[:y] == :(2b) @test default_values(sc)[:b] == :(2a + x) @@ -45,6 +45,16 @@ obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)]) obsfn5 = observed(sc, (:(x + a), :(y + b))) @test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0)) +@test_throws TypeError observed(sc, [:(x + a), 2]) +@test_throws TypeError observed(sc, (:(x + a), 2)) + +@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; + timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) +@test_throws TypeError SymbolCache( + [:x, :y], [:a, :c], :t; timeseries_parameters = Dict(:c => (1, 1))) +@test_nowarn SymbolCache([:x, :y], [:a, :c], :t; + timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) + sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) @test sort(all_symbols(sc)) == [:a, :b, :x, :y] @@ -54,6 +64,9 @@ obsfn = observed(sc, :(x + b)) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) +@test_throws ArgumentError SymbolCache( + [:x, :y], [:a, :b]; timeseries_parameters = Dict(:b => ParameterTimeseriesIndex(1, 1))) + sc = SymbolCache() @test all(.!is_variable.((sc,), [:x, :y, :a, :b, :t])) @test all(variable_index.((sc,), [:x, :y, :a, :b, :t]) .== nothing) @@ -77,6 +90,10 @@ sc = SymbolCache(nothing, nothing, :t) @test all_symbols(sc) == [:t] @test isempty(default_values(sc)) +sc = SymbolCache(nothing, nothing, [:t1, :t2, :t3]) +@test all(is_independent_variable.((sc,), [:t1, :t2, :t3])) +@test independent_variable_symbols(sc) == [:t1, :t2, :t3] + sc2 = copy(sc) @test sc.variables == sc2.variables @test sc.parameters == sc2.parameters