Skip to content

Commit

Permalink
refactor: remove potential for StackOverflowError in AbstractVectorOf…
Browse files Browse the repository at this point in the history
…Array indexing
  • Loading branch information
AayushSabharwal committed Dec 20, 2023
1 parent 9790424 commit 79bac77
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,15 @@ end
@deprecate Base.getindex(A::AbstractDiffEqArray, i::Int) Base.getindex(A, :, i) false

__parameterless_type(T) = Base.typename(T).wrapper
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},

Check warning on line 233 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L233

Added line #L233 was not covered by tests
::NotSymbolic, I::Colon...) where {T, N}
@assert length(I) == ndims(A.u[1]) + 1
vecs = vec.(A.u)
return Adapt.adapt(__parameterless_type(T),
reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u)))
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},

Check warning on line 241 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L241

Added line #L241 was not covered by tests
::NotSymbolic, I::AbstractArray{Bool},
J::Colon...) where {T, N}
@assert length(J) == ndims(A.u[1]) + 1 - ndims(I)
Expand All @@ -247,34 +247,34 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
end

# Need two of each methods to avoid ambiguities
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
A.u[I]
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...)
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...)
if last(I) isa Int
A.u[last(I)][Base.front(I)...]
else
stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...))
end
end
Base.@propagate_inbounds function Base.getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex)
Base.@propagate_inbounds function _getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex)
ti = Tuple(ii)
i = last(ti)
jj = CartesianIndex(Base.front(ti))
return VA.u[i][jj]
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
VectorOfArray(A.u[I])
end

Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A))
end

# Symbolic Indexing Methods
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
if is_independent_variable(A, sym)
return A.t
elseif is_variable(A, sym)
Expand All @@ -296,7 +296,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
if is_independent_variable(A, sym)
return A.t[args...]
elseif is_variable(A, sym)
Expand All @@ -319,11 +319,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
end


Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)

Check warning on line 322 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L322

Added line #L322 was not covered by tests
return getindex(A, collect(sym), args...)
end

Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
if all(x -> is_parameter(A, x), sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
Expand All @@ -332,7 +332,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...)
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...)
return reduce(vcat, map(s -> A[s, args...]', sym))
end

Expand All @@ -341,9 +341,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
elsymtype = symbolic_type(eltype(_arg))

if symtype != NotSymbolic()
return Base.getindex(A, symtype, _arg, args...)
return _getindex(A, symtype, _arg, args...)
else
return Base.getindex(A, elsymtype, _arg, args...)
return _getindex(A, elsymtype, _arg, args...)
end
end

Expand Down

0 comments on commit 79bac77

Please sign in to comment.