Skip to content

Commit

Permalink
Merge pull request #317 from AayushSabharwal/as/all-symbols
Browse files Browse the repository at this point in the history
feat: add support for new *_symbols methods in SII
  • Loading branch information
ChrisRackauckas authored Dec 27, 2023
2 parents 63421ad + 501387d commit bb5a2c8
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ StaticArrays = "1.6"
StaticArraysCore = "1.1"
Statistics = "1"
StructArrays = "0.6"
SymbolicIndexingInterface = "0.3"
SymbolicIndexingInterface = "0.3.1"
Tables = "1"
Test = "1"
Tracker = "0.2"
Expand Down
14 changes: 10 additions & 4 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,7 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
end
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
elseif is_observed(A, sym)
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
else
Expand Down Expand Up @@ -325,8 +324,7 @@ end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
if all(x -> is_parameter(A, x), sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
else
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
end
Expand All @@ -336,6 +334,14 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
return reduce(vcat, map(s -> A[s, args...]', sym))
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, ::SymbolicIndexingInterface.SolvedVariables, args...)
return getindex(A, variable_symbols(A), args...)
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, ::SymbolicIndexingInterface.AllVariables, args...)
return getindex(A, all_variable_symbols(A), args...)
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
symtype = symbolic_type(_arg)
elsymtype = symbolic_type(eltype(_arg))
Expand Down
9 changes: 7 additions & 2 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ sol_new = DiffEqArray(sol.u[1:10],
@test sol_new[t] sol_new.t
@test sol_new[t, 1:5] sol_new.t[1:5]
@test getp(sol, τ)(sol) == getp(sol_new, τ)(sol_new) == 3.0
@test_deprecated sol[τ]
@test_deprecated sol_new[τ]
@test variable_symbols(sol) == variable_symbols(sol_new) == [x]
@test all_variable_symbols(sol) == all_variable_symbols(sol_new) == [x, RHS]
@test all_symbols(sol) == all_symbols(sol_new) == [x, RHS, τ, t]
@test sol[solvedvariables, 1:10] == sol_new[solvedvariables] == sol_new[[x]]
@test sol[allvariables, 1:10] == sol_new[allvariables] == sol_new[[x, RHS]]
@test_throws Exception sol[τ]
@test_throws Exception sol_new[τ]

# Tables interface
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))
Expand Down
2 changes: 1 addition & 1 deletion test/qa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using RecursiveArrayTools, Aqua
Aqua.find_persistent_tasks_deps(RecursiveArrayTools)
ambs = Aqua.detect_ambiguities(RecursiveArrayTools; recursive = true)
@warn "Number of method ambiguities: $(length(ambs))"
@test length(ambs) <= 1
@test length(ambs) <= 11
Aqua.test_deps_compat(RecursiveArrayTools)
Aqua.test_piracies(RecursiveArrayTools)
Aqua.test_project_extras(RecursiveArrayTools)
Expand Down
9 changes: 6 additions & 3 deletions test/symbolic_indexing_interface_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
@test dx[[:a, :b], 3] [f(t[3]), f2(t[3])]
@test dx[[:a, :b], 4:5] vcat(f.(t[4:5])', f2.(t[4:5])')
@test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]]
@test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3]
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
@test getp(dx, :p)(dx) == 1.0
@test getp(dx, :q)(dx) == 2.0
@test_deprecated dx[:p]
@test_deprecated dx[[:p, :q]]
@test_throws Exception dx[:p]
@test_throws Exception dx[[:p, :q]]
@test dx[:t] == t

@test symbolic_container(dx) isa SymbolCache
Expand All @@ -35,11 +37,12 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
@test is_parameter.((dx,), [:a, :b, :p, :q, :t]) == [false, false, true, true, false]
@test parameter_index.((dx,), [:a, :b, :p, :q, :t]) == [nothing, nothing, 1, 2, nothing]
@test is_independent_variable.((dx,), [:a, :b, :p, :q, :t]) == [false, false, false, false, true]
@test variable_symbols(dx) == [:a, :b]
@test variable_symbols(dx) == all_variable_symbols(dx) == [:a, :b]
@test parameter_symbols(dx) == [:p, :q]
@test independent_variable_symbols(dx) == [:t]
@test is_time_dependent(dx)
@test constant_structure(dx)
@test all_symbols(dx) == [:a, :b, :p, :q, :t]

dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b])
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym
Expand Down

0 comments on commit bb5a2c8

Please sign in to comment.