Skip to content

Commit

Permalink
refactor: use new getp and setp for parameter indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2023
1 parent 1e2fbfb commit 7bc832a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 16 deletions.
15 changes: 4 additions & 11 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,6 @@ function DiffEqArray(vec::AbstractVector{VT},
sys)
end

function Base.getproperty(A::AbstractDiffEqArray, sym::Symbol)
if sym == SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME
return ParameterIndexingProxy(A)
else
return getfield(A, sym)
end
end

SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p
SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys

Expand Down Expand Up @@ -292,8 +284,8 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
end
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `$(nameof(typeof(A))).$(SymbolicIndexingInterface.PARAMETER_INDEXING_PROXY_PROPERTY_NAME)[$sym]` for parameter indexing.", :parameter_getindex)
A.ps[sym]
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_observed(A, sym)
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
else
Expand Down Expand Up @@ -333,7 +325,8 @@ end

Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
if all(x -> is_parameter(A, x), sym)
return getindex.((A,), sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
else
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
end
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, Test
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Test

include("../testutils.jl")

Expand All @@ -21,7 +21,7 @@ sol_new = DiffEqArray(sol.u[1:10],
@test sol_new[RHS] (1 .- sol_new[x]) ./ 3.0
@test sol_new[t] sol_new.t
@test sol_new[t, 1:5] sol_new.t[1:5]
@test sol.ps[τ] == sol_new.ps[τ] == 3.0
@test getp(sol, τ)(sol) == getp(sol_new, τ)(sol_new) == 3.0
@test_deprecated sol[τ]
@test_deprecated sol_new[τ]

Expand Down
8 changes: 5 additions & 3 deletions test/symbolic_indexing_interface_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
@test dx[[:a, :b], 3] [f(t[3]), f2(t[3])]
@test dx[[:a, :b], 4:5] vcat(f.(t[4:5])', f2.(t[4:5])')
@test dx.ps[[:p, :q]] == [1.0, 2.0]
@test dx.ps[:p] == 1.0
@test dx.ps[:q] == 2.0
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
@test getp(dx, :p)(dx) == 1.0
@test getp(dx, :q)(dx) == 2.0
@test_deprecated dx[:p]
@test_deprecated dx[[:p, :q]]
@test dx[:t] == t

@test symbolic_container(dx) isa SymbolCache
Expand Down

0 comments on commit 7bc832a

Please sign in to comment.