Skip to content

Commit

Permalink
Merge pull request #371 from AayushSabharwal/as/broadcast-index
Browse files Browse the repository at this point in the history
fix: fix indexing using array symbolics, `Colon`
  • Loading branch information
ChrisRackauckas authored May 3, 2024
2 parents 39ae861 + 0879183 commit d6eb1ad
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ StaticArrays = "1.6"
StaticArraysCore = "1.4"
Statistics = "1.10"
StructArrays = "0.6.11"
SymbolicIndexingInterface = "0.3.19"
SymbolicIndexingInterface = "0.3.20"
Tables = "1.11"
Test = "1"
Tracker = "0.2.15"
Expand Down
70 changes: 47 additions & 23 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,49 +351,73 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A))
end

struct ParameterIndexingError <: Exception
sym
end

function Base.showerror(io::IO, pie::ParameterIndexingError)
print(io, "Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
end

# Symbolic Indexing Methods
for symtype in [ScalarSymbolic, ArraySymbolic]
paramcheck = quote
if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym)
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
end
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym)
$paramcheck
getu(A, sym)(A)
for (symtype, elsymtype, valtype, errcheck) in [
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x), sym))),
]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype)
if $errcheck
throw(ParameterIndexingError(sym))
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg)
$paramcheck
getu(A, sym)(A, arg)
getu(A, sym)(A)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg)
if $errcheck
throw(ParameterIndexingError(sym))
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
$paramcheck
getu(A, sym).((A,), arg)
getu(A, sym)(A, arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
if $errcheck
throw(ParameterIndexingError(sym))
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon)
$paramcheck
getu(A, sym)(A)
getu(A, sym).((A,), arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, ::Colon)
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A)
end
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
::SymbolicIndexingInterface.SolvedVariables, args...)
::NotSymbolic, ::SymbolicIndexingInterface.SolvedVariables, args...)
return getindex(A, variable_symbols(A), args...)
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
::SymbolicIndexingInterface.AllVariables, args...)
::NotSymbolic, ::SymbolicIndexingInterface.AllVariables, args...)
return getindex(A, all_variable_symbols(A), args...)
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
symtype = symbolic_type(_arg)
elsymtype = symbolic_type(eltype(_arg))

if symtype != NotSymbolic()
return _getindex(A, symtype, _arg, args...)
if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
else
_getindex(A, symtype, _arg, args...)
end
else
return _getindex(A, elsymtype, _arg, args...)
_getindex(A, symtype, elsymtype, _arg, args...)
end
end

Expand Down
12 changes: 12 additions & 0 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,15 @@ sol_ts = sol(ts)
@assert sol_ts isa DiffEqArray
test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")],
hcat(ts, Array(sol_ts)'))

# Array variables
using LinearAlgebra
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
ps = @parameters p[1:3] = [1, 2, 3]
eqs = [collect(D.(x) .~ x)
D(y) ~ norm(collect(x)) * y - x[1]]
@mtkbuild sys = ODESystem(eqs, t, sts, ps)
prob = ODEProblem(sys, [], (0, 1.0))
sol = solve(prob, Tsit5())
@test sol[x .+ [y, 2y, 3y]] vcat.(getindex.((sol,), [x[1] + y, x[2] + 2y, x[3] + 3y])...)
@test sol[x, :] sol[x]

0 comments on commit d6eb1ad

Please sign in to comment.