diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index d214fa31..7dc160c2 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -105,30 +105,35 @@ function getp(sys, p) end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) - _getter = let p = p + return let p = p function _getter(::NotTimeseries, prob) parameter_values(prob, p) end function _getter(::Timeseries, prob) parameter_values(prob, p) end - function _getter(::Timeseries, prob, i) - parameter_values(parameter_values_at_time(prob, i), p) + function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex}) + parameter_values(parameter_values_at_time(prob, only(to_indices(parameter_timeseries(prob), (i,)))), p) end - function _getter(::Timeseries, prob, ::Colon) - parameter_values.((parameter_values_at_time(prob, i) for i in eachindex(parameter_timeseries(prob))), (p,)) + function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon}) + parameter_values.(parameter_values_at_time.((prob,), (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), p) end - end - return let _getter = _getter - function getter(prob, args...) - return _getter(is_timeseries(prob), prob, args...) + function _getter(::Timeseries, prob, i) + parameter_values.(parameter_values_at_time.((prob,), i), (p,)) + end + getter = let _getter = _getter + function getter(prob, args...) + return _getter(is_timeseries(prob), prob, args...) + end end + getter end end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return getp(sys, idx) + return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, sys, NotSymbolic(), NotSymbolic(), idx) + return _getp(sys, NotSymbolic(), NotSymbolic(), idx) end for (t1, t2) in [ @@ -139,43 +144,54 @@ for (t1, t2) in [ @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) getters = getp.((sys,), p) - _getter = return let getters = getters + return let getters = getters function _getter(::NotTimeseries, prob) map(g -> g(prob), getters) end function _getter(::Timeseries, prob) map(g -> g(prob), getters) end - function _getter(::Timeseries, prob, i) + function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex}) map(g -> g(prob, i), getters) end - function _getter(::Timeseries, prob, ::Colon) - [map(g -> g(prob, i), getters) for i in eachindex(parameter_timeseries(prob))] - end - function _getter(buffer, ::NotTimeseries, prob) - map!(g -> g(prob), buffer, getters) + function _getter(::Timeseries, prob, i) + [map(g -> g(prob, j), getters) for j in only(to_indices(parameter_timeseries(prob), (i,)))] end - function _getter(buffer, ::Timeseries, prob) - map!(g -> g(prob), buffer, getters) + function _getter!(buffer, ::NotTimeseries, prob) + for (g, bufi) in zip(getters, eachindex(buffer)) + buffer[bufi] = g(prob) + end + buffer end - function _getter(buffer, ::Timeseries, prob, i) - map!(g -> g(prob, i), buffer, getters) + function _getter!(buffer, ::Timeseries, prob) + for (g, bufi) in zip(getters, eachindex(buffer)) + buffer[bufi] = g(prob) + end + buffer end - function _getter(buffer, ::Timeseries, prob, ::Colon) - for (bufi, tsi) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob))) - map!(g -> g(prob, tsi), buffer[bufi], getters) + function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex}) + for (g, bufi) in zip(getters, eachindex(buffer)) + buffer[bufi] = g(prob, i) end buffer end - _getter - end - - return let _getter = _getter - function getter(prob, i...) - return _getter(is_timeseries(prob), prob, i...) + function _getter!(buffer, ::Timeseries, prob, i) + for (bufi, tsi) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,)))) + for (g, bufj) in zip(getters, eachindex(buffer[bufi])) + buffer[bufi][bufj] = g(prob, tsi) + end + end + buffer end - function getter(buffer::AbstractArray, prob, i...) - return _getter(buffer, is_timeseries(prob), prob, i...) + _getter, _getter! + getter = let _getter = _getter, _getter! = _getter! + function getter(prob, i...) + return _getter(is_timeseries(prob), prob, i...) + end + function getter(buffer::AbstractArray, prob, i...) + return _getter!(buffer, is_timeseries(prob), prob, i...) + end + getter end getter end diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index 0b92587d..ab365ed0 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -10,8 +10,8 @@ struct ParameterIndexingProxy{T} wrapped::T end -function Base.getindex(p::ParameterIndexingProxy, idx) - return getp(p.wrapped, idx)(p.wrapped) +function Base.getindex(p::ParameterIndexingProxy, idx, args...) + getp(p.wrapped, idx)(p.wrapped, args...) end function Base.setindex!(p::ParameterIndexingProxy, val, idx) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index adac52f4..f785ac0d 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -6,6 +6,9 @@ struct FakeIntegrator{S, P} p::P end +function Base.getproperty(fi::FakeIntegrator, s::Symbol) + s === :ps ? ParameterIndexingProxy(fi) : getfield(fi, s) +end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p @@ -13,6 +16,7 @@ sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) p = [1.0, 2.0, 3.0] fi = FakeIntegrator(sys, copy(p)) new_p = [4.0, 5.0, 6.0] +@test parameter_timeseries(fi) == [0] for (sym, oldval, newval, check_inference) in [ (:a, p[1], new_p[1], true), (1, p[1], new_p[1], true), @@ -33,6 +37,7 @@ for (sym, oldval, newval, check_inference) in [ if check_inference @inferred get(fi) end + @test get(fi) == fi.ps[sym] @test get(fi) == oldval if check_inference @inferred set!(fi, newval) @@ -43,6 +48,11 @@ for (sym, oldval, newval, check_inference) in [ set!(fi, oldval) @test get(fi) == oldval + fi.ps[sym] = newval + @test get(fi) == newval + fi.ps[sym] = oldval + @test get(fi) == oldval + if check_inference @inferred get(p) end @@ -68,3 +78,71 @@ for (sym, val) in [ @inferred get(buffer, fi) @test buffer == val end + +struct FakeSolution + sys::SymbolCache + u::Vector{Vector{Float64}} + t::Vector{Float64} + p::Vector{Vector{Float64}} + pt::Vector{Float64} +end + +function Base.getproperty(fs::FakeSolution, s::Symbol) + s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s) +end +SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys +SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end] +SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i] +function SymbolicIndexingInterface.parameter_values_at_time(fs::FakeSolution, t) + fs.p[t] +end +function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSolution, t) + ptind = searchsortedfirst(fs.pt, fs.t[t]) + fs.p[ptind - 1] +end +SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution) = fs.pt +SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries() +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +fs = FakeSolution( + sys, + [i * ones(3) for i in 1:5], + [0.2i for i in 1:5], + [2i * ones(3) for i in 1:10], + [0.1i for i in 1:10], +) +ps = fs.p +p = fs.p[end] +avals = getindex.(ps, 1) +bvals = getindex.(ps, 2) +cvals = getindex.(ps, 3) +@test parameter_timeseries(fs) == fs.pt +for (sym, val, arrval, check_inference) in [ + (:a, p[1], avals, true), + (1, p[1], avals, true), + ([:a, :b], p[1:2], vcat.(avals, bvals), true), + (1:2, p[1:2], vcat.(avals, bvals), true), + ((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), + ((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true), + ([1, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), + ((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true) +] + get = getp(sys, sym) + if check_inference + @inferred get(fs) + end + @test get(fs) == fs.ps[sym] + @test get(fs) == val + + for sub_inds in [:, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))] + if check_inference + @inferred get(fs, sub_inds) + end + @test get(fs, sub_inds) == fs.ps[sym, sub_inds] + @test get(fs, sub_inds) == arrval[sub_inds] + end +end