diff --git a/Project.toml b/Project.toml index f1e581ef..bf423d8c 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ StaticArrays = "1.6" StaticArraysCore = "1.4" Statistics = "1.10" StructArrays = "0.6.11" -SymbolicIndexingInterface = "0.3.19" +SymbolicIndexingInterface = "0.3.20" Tables = "1.11" Test = "1" Tracker = "0.2.15" diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 764a42e6..7a402c0f 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -351,38 +351,58 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A)) end +struct ParameterIndexingError <: Exception + sym +end + +function Base.showerror(io::IO, pie::ParameterIndexingError) + print(io, "Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.") +end + # Symbolic Indexing Methods -for symtype in [ScalarSymbolic, ArraySymbolic] - paramcheck = quote - if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym) - error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") - end - end - @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym) - $paramcheck - getu(A, sym)(A) +for (symtype, elsymtype, valtype, errcheck) in [ + (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), + (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), + (NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray}, + :(all(x -> is_parameter(A, x), sym))), +] +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype) + if $errcheck + throw(ParameterIndexingError(sym)) end - @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg) - $paramcheck - getu(A, sym)(A, arg) + getu(A, sym)(A) +end +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, arg) + if $errcheck + throw(ParameterIndexingError(sym)) end - @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}}) - $paramcheck - getu(A, sym).((A,), arg) + getu(A, sym)(A, arg) +end +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}}) + if $errcheck + throw(ParameterIndexingError(sym)) end - @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon) - $paramcheck - getu(A, sym)(A) + getu(A, sym).((A,), arg) +end +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, ::Colon) + if $errcheck + throw(ParameterIndexingError(sym)) end + getu(A, sym)(A) +end end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, - ::SymbolicIndexingInterface.SolvedVariables, args...) + ::NotSymbolic, ::SymbolicIndexingInterface.SolvedVariables, args...) return getindex(A, variable_symbols(A), args...) end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, - ::SymbolicIndexingInterface.AllVariables, args...) + ::NotSymbolic, ::SymbolicIndexingInterface.AllVariables, args...) return getindex(A, all_variable_symbols(A), args...) end @@ -390,10 +410,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, symtype = symbolic_type(_arg) elsymtype = symbolic_type(eltype(_arg)) - if symtype != NotSymbolic() - return _getindex(A, symtype, _arg, args...) + if symtype == NotSymbolic() && elsymtype == NotSymbolic() + if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg) + _getindex(A, symtype, elsymtype, _arg, args...) + else + _getindex(A, symtype, _arg, args...) + end else - return _getindex(A, elsymtype, _arg, args...) + _getindex(A, symtype, elsymtype, _arg, args...) end end diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 5ef7e3ac..d398a886 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -61,3 +61,15 @@ sol_ts = sol(ts) @assert sol_ts isa DiffEqArray test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")], hcat(ts, Array(sol_ts)')) + +# Array variables +using LinearAlgebra +sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0 +ps = @parameters p[1:3] = [1, 2, 3] +eqs = [collect(D.(x) .~ x) + D(y) ~ norm(collect(x)) * y - x[1]] +@mtkbuild sys = ODESystem(eqs, t, sts, ps) +prob = ODEProblem(sys, [], (0, 1.0)) +sol = solve(prob, Tsit5()) +@test sol[x .+ [y, 2y, 3y]] ≈ vcat.(getindex.((sol,), [x[1] + y, x[2] + 2y, x[3] + 3y])...) +@test sol[x, :] ≈ sol[x]