diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 2cdfbf4f..44fd15bc 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -422,6 +422,18 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N} return VA.u[i][jj] = x end +Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, x, idxs::Union{Int,Colon,CartesianIndex,AbstractArray{Int},AbstractArray{Bool}}...) where {T, N} + v = view(VA, idxs...) + # error message copied from Base by running `ones(3, 3, 3)[:, 2, :] = 2` + if length(v) != length(x) + throw(ArgumentError("indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?")) + end + for (i, j) in zip(eachindex(v), eachindex(x)) + v[i] = x[j] + end + return x +end + # Interface for the two-dimensional indexing, a more standard AbstractArray interface @inline Base.size(VA::AbstractVectorOfArray) = (size(VA.u[1])..., length(VA.u)) @inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i] @@ -534,14 +546,11 @@ function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:Abstra return checkbounds(Bool, VA.u, idxs...) end function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...) - if checkbounds(Bool, VA.u, last(idx)) - if last(idx) isa Integer - return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...)) - else - return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...)) - end + checkbounds(Bool, VA.u, last(idx)) || return false + for i in last(idx) + checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false end - return false + return true end function Base.checkbounds(VA::AbstractVectorOfArray, idx...) checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx)) @@ -549,6 +558,12 @@ end function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N} copyto!.(dest.u, src.u) end +# Required for broadcasted setindex! when slicing across subarrays +# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])` +# Need this method for `va[2, :, :] .= 3.0` +Base.@propagate_inbounds function Base.maybeview(A::AbstractVectorOfArray, I...) + return view(A, I...) +end # Operations function Base.isapprox(A::AbstractVectorOfArray, @@ -619,7 +634,7 @@ function Base.fill!(VA::AbstractVectorOfArray, x) return VA end -Base.reshape(A::VectorOfArray, dims...) = Base.reshape(Array(A), dims...) +Base.reshape(A::AbstractVectorOfArray, dims...) = Base.reshape(Array(A), dims...) # Need this for ODE_DEFAULT_UNSTABLE_CHECK from DiffEqBase to work properly @inline Base.any(f, VA::AbstractVectorOfArray) = any(any(f, u) for u in VA.u) @@ -633,7 +648,7 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray) if !allequal(size.(VA.u)) error("Can only convert non-ragged VectorOfArray to Array") end - return stack(VA) + return Array(VA) end # statistics diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 7a5acbe6..5abaee07 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -191,6 +191,32 @@ w = v .+ 1 @test_broken w isa DiffEqArray # FIXME @test w.u == map(x -> x .+ 1, v.u) +# setindex! +testva = VectorOfArray([i * ones(3, 3) for i in 1:5]) +testva[:, 2] = 7ones(3, 3) +@test testva[:, 2] == 7ones(3, 3) +testva[:, :] = [2i * ones(3, 3) for i in 1:5] +for i in 1:5 + @test testva[:, i] == 2i * ones(3, 3) +end +testva[:, 1:2:5] = [5i * ones(3, 3) for i in 1:2:5] +for i in 1:2:5 + @test testva[:, i] == 5i * ones(3, 3) +end +testva[CartesianIndex(3, 3, 5)] = 64.0 +@test testva[:, 5][3, 3] == 64.0 +@test_throws ArgumentError testva[2, 1:2, :] = 108.0 +testva[2, 1:2, :] .= 108.0 +for i in 1:5 + @test all(testva[:, i][2, 1:2] .== 108.0) +end +testva[:, 3, :] = [3i / 7j for i in 1:3, j in 1:5] +for j in 1:5 + for i in 1:3 + @test testva[i, 3, j] == 3i / 7j + end +end + # edges cases x = [1, 2, 3, 4, 5, 6, 7, 8, 9] testva = DiffEqArray(x, x)