Skip to content

Commit

Permalink
Merge pull request #330 from AayushSabharwal/as/setindex
Browse files Browse the repository at this point in the history
fix: add setindex! for higher dimensional VoA, fix checkbounds allocations
  • Loading branch information
ChrisRackauckas authored Jan 9, 2024
2 parents bc59b23 + 477ba0a commit 5e1d5ee
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
33 changes: 24 additions & 9 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,18 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}
return VA.u[i][jj] = x
end

Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, x, idxs::Union{Int,Colon,CartesianIndex,AbstractArray{Int},AbstractArray{Bool}}...) where {T, N}
v = view(VA, idxs...)
# error message copied from Base by running `ones(3, 3, 3)[:, 2, :] = 2`
if length(v) != length(x)
throw(ArgumentError("indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?"))
end
for (i, j) in zip(eachindex(v), eachindex(x))
v[i] = x[j]
end
return x
end

# Interface for the two-dimensional indexing, a more standard AbstractArray interface
@inline Base.size(VA::AbstractVectorOfArray) = (size(VA.u[1])..., length(VA.u))
@inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i]
Expand Down Expand Up @@ -534,21 +546,24 @@ function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:Abstra
return checkbounds(Bool, VA.u, idxs...)
end
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if checkbounds(Bool, VA.u, last(idx))
if last(idx) isa Integer
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...))
else
return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...))
end
checkbounds(Bool, VA.u, last(idx)) || return false
for i in last(idx)
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
end
return false
return true
end
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
end
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
copyto!.(dest.u, src.u)
end
# Required for broadcasted setindex! when slicing across subarrays
# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])`
# Need this method for `va[2, :, :] .= 3.0`
Base.@propagate_inbounds function Base.maybeview(A::AbstractVectorOfArray, I...)
return view(A, I...)
end

# Operations
function Base.isapprox(A::AbstractVectorOfArray,
Expand Down Expand Up @@ -619,7 +634,7 @@ function Base.fill!(VA::AbstractVectorOfArray, x)
return VA
end

Base.reshape(A::VectorOfArray, dims...) = Base.reshape(Array(A), dims...)
Base.reshape(A::AbstractVectorOfArray, dims...) = Base.reshape(Array(A), dims...)

# Need this for ODE_DEFAULT_UNSTABLE_CHECK from DiffEqBase to work properly
@inline Base.any(f, VA::AbstractVectorOfArray) = any(any(f, u) for u in VA.u)
Expand All @@ -633,7 +648,7 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
if !allequal(size.(VA.u))
error("Can only convert non-ragged VectorOfArray to Array")
end
return stack(VA)
return Array(VA)
end

# statistics
Expand Down
26 changes: 26 additions & 0 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,32 @@ w = v .+ 1
@test_broken w isa DiffEqArray # FIXME
@test w.u == map(x -> x .+ 1, v.u)

# setindex!
testva = VectorOfArray([i * ones(3, 3) for i in 1:5])
testva[:, 2] = 7ones(3, 3)
@test testva[:, 2] == 7ones(3, 3)
testva[:, :] = [2i * ones(3, 3) for i in 1:5]
for i in 1:5
@test testva[:, i] == 2i * ones(3, 3)
end
testva[:, 1:2:5] = [5i * ones(3, 3) for i in 1:2:5]
for i in 1:2:5
@test testva[:, i] == 5i * ones(3, 3)
end
testva[CartesianIndex(3, 3, 5)] = 64.0
@test testva[:, 5][3, 3] == 64.0
@test_throws ArgumentError testva[2, 1:2, :] = 108.0
testva[2, 1:2, :] .= 108.0
for i in 1:5
@test all(testva[:, i][2, 1:2] .== 108.0)
end
testva[:, 3, :] = [3i / 7j for i in 1:3, j in 1:5]
for j in 1:5
for i in 1:3
@test testva[i, 3, j] == 3i / 7j
end
end

# edges cases
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
testva = DiffEqArray(x, x)
Expand Down

0 comments on commit 5e1d5ee

Please sign in to comment.