From c207bc0f34b5fb107a030201834ee51e3a9e2a90 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 30 Oct 2024 06:46:48 -0100 Subject: [PATCH] Simplify VectorOfArray indexing Relying on stack removes the need to `adapt`, which should make the GPUs more efficient. With `stack`, those extra dispatches were unnecessary. They were rarely hit and one had a bug too! So they can just be removed. Seems to fix downstream. --- Project.toml | 2 +- src/vector_of_array.jl | 25 ------------------------- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 32e4a821..c0b801fb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "3.27.1" +version = "3.27.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index eef8625f..95a2a930 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -321,32 +321,7 @@ end @deprecate Base.getindex(A::AbstractDiffEqArray, i::Int) Base.getindex(A, :, i) false __parameterless_type(T) = Base.typename(T).wrapper -Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, - ::NotSymbolic, I::Colon...) where {T, N} - @assert length(I) == ndims(A.u[1]) + 1 - vecs = if N == 1 - A.u - else - vec.(A.u) - end - return Adapt.adapt(__parameterless_type(T), - reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u))) -end -Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, - ::NotSymbolic, I::Colon...) where {T <: Number, N} - @assert length(I) == ndims(A.u) - return A.u[I...] -end - -Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, - ::NotSymbolic, I::AbstractArray{Bool}, - J::Colon...) where {T, N} - @assert length(J) == ndims(A.u[1]) + 1 - ndims(I) - @assert size(I) == size(A)[1:(ndims(A) - length(J))] - return A[ntuple(x -> Colon(), ndims(A))...][I, J...] -end -# Need two of each methods to avoid ambiguities Base.@propagate_inbounds function _getindex( A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int) A.u[I]