From 79bac77bc650cbc2917347faaf24be1e9140a285 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 Dec 2023 12:56:34 +0530 Subject: [PATCH] refactor: remove potential for StackOverflowError in AbstractVectorOfArray indexing --- src/vector_of_array.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index f54139eb..e277c583 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -230,7 +230,7 @@ end @deprecate Base.getindex(A::AbstractDiffEqArray, i::Int) Base.getindex(A, :, i) false __parameterless_type(T) = Base.typename(T).wrapper -Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N}, +Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, ::NotSymbolic, I::Colon...) where {T, N} @assert length(I) == ndims(A.u[1]) + 1 vecs = vec.(A.u) @@ -238,7 +238,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N}, reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u))) end -Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N}, +Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, ::NotSymbolic, I::AbstractArray{Bool}, J::Colon...) where {T, N} @assert length(J) == ndims(A.u[1]) + 1 - ndims(I) @@ -247,34 +247,34 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N}, end # Need two of each methods to avoid ambiguities -Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int) +Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int) A.u[I] end -Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...) +Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...) if last(I) isa Int A.u[last(I)][Base.front(I)...] else stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...)) end end -Base.@propagate_inbounds function Base.getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex) +Base.@propagate_inbounds function _getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex) ti = Tuple(ii) i = last(ti) jj = CartesianIndex(Base.front(ti)) return VA.u[i][jj] end -Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}}) +Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}}) VectorOfArray(A.u[I]) end -Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}}) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}}) DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A)) end # Symbolic Indexing Methods -Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym) if is_independent_variable(A, sym) return A.t elseif is_variable(A, sym) @@ -296,7 +296,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar end end -Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...) if is_independent_variable(A, sym) return A.t[args...] elseif is_variable(A, sym) @@ -319,11 +319,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar end -Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...) return getindex(A, collect(sym), args...) end -Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}) if all(x -> is_parameter(A, x), sym) Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex) return getp(A, sym)(A) @@ -332,7 +332,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar end end -Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...) return reduce(vcat, map(s -> A[s, args...]', sym)) end @@ -341,9 +341,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, elsymtype = symbolic_type(eltype(_arg)) if symtype != NotSymbolic() - return Base.getindex(A, symtype, _arg, args...) + return _getindex(A, symtype, _arg, args...) else - return Base.getindex(A, elsymtype, _arg, args...) + return _getindex(A, elsymtype, _arg, args...) end end