Skip to content

Commit

Permalink
Merge pull request #359 from jlchan/jc/VectorOfArray_multidim_helper
Browse files Browse the repository at this point in the history
Specialize `Base.similar` for `VectorOfArray` with multidimensional parent
  • Loading branch information
ChrisRackauckas authored Feb 22, 2024
2 parents 239bd04 + 43072c8 commit b10cddb
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 57 deletions.
17 changes: 16 additions & 1 deletion src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,16 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
VectorOfArray{T, N + 1, typeof(vec)}(vec)
end

# allow multi-dimensional arrays as long as they're linearly indexed
# allow multi-dimensional arrays as long as they're linearly indexed.
# currently restricted to arrays whose elements are all the same type
function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
@assert IndexStyle(typeof(array)) isa IndexLinear

return VectorOfArray{T, N + 1, typeof(array)}(array)
end

Base.parent(vec::VectorOfArray) = vec.u

function DiffEqArray(vec::AbstractVector{T},
ts::AbstractVector,
::NTuple{N, Int},
Expand Down Expand Up @@ -721,6 +724,18 @@ end
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
end

# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
return VectorOfArray(similar(Base.parent(vec)))
end

# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}}
return Base.similar(vec, eltype(vec))
end

# fill!
# For DiffEqArray it ignores ts and fills only u
function Base.fill!(VA::AbstractVectorOfArray, x)
Expand Down
4 changes: 4 additions & 0 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ foo!(u_matrix)
foo!(u_vector)
@test u_matrix u_vector

# test that, for VectorOfArray with multi-dimensional parent arrays,
# `similar` preserves the structure of the parent array
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
num_allocs = @allocations foo!(u_matrix)
@test num_allocs == 0
4 changes: 3 additions & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

Expand All @@ -10,4 +11,5 @@ ModelingToolkit = "8.33"
MonteCarloMeasurements = "1.1"
OrdinaryDiffEq = "6.31"
Unitful = "1.17"
Tracker = "0.2"
Tracker = "0.2"
StaticArrays = "1"
13 changes: 12 additions & 1 deletion test/upstream.jl → test/downstream/odesolve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface, StaticArrays
function lorenz(du, u, p, t)
du[1] = 10.0 * (u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
Expand Down Expand Up @@ -49,3 +49,14 @@ end
ArrayPartition(zeros(1), [0.75])),
(0.0, 1.0)),
Rodas5()).retcode == ReturnCode.Success

function rhs!(duu::VectorOfArray, uu::VectorOfArray, p, t)
du = parent(duu)
u = parent(uu)
du .= u
end

u = fill(SVector{2}(ones(2)), 2, 3)
ode = ODEProblem(rhs!, VectorOfArray(u), (0.0, 1.0))
sol = solve(ode, Tsit5())
@test SciMLBase.successful_retcode(sol)
72 changes: 18 additions & 54 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,67 +19,31 @@ end

@time begin
if GROUP == "Core" || GROUP == "All"
@time @safetestset "Quality Assurance" begin
include("qa.jl")
end
@time @safetestset "Utils Tests" begin
include("utils_test.jl")
end
@time @safetestset "NamedArrayPartition Tests" begin
include("named_array_partition_tests.jl")
end
@time @safetestset "Partitions Tests" begin
include("partitions_test.jl")
end
@time @safetestset "VecOfArr Indexing Tests" begin
include("basic_indexing.jl")
end
@time @safetestset "SymbolicIndexingInterface API test" begin
include("symbolic_indexing_interface_test.jl")
end
@time @safetestset "VecOfArr Interface Tests" begin
include("interface_tests.jl")
end
@time @safetestset "Table traits" begin
include("tabletraits.jl")
end
@time @safetestset "StaticArrays Tests" begin
include("copy_static_array_test.jl")
end
@time @safetestset "Linear Algebra Tests" begin
include("linalg.jl")
end
@time @safetestset "Upstream Tests" begin
include("upstream.jl")
end
@time @safetestset "Adjoint Tests" begin
include("adjoints.jl")
end
@time @safetestset "Measurement Tests" begin
include("measurements.jl")
end
@time @safetestset "Quality Assurance" include("qa.jl")
@time @safetestset "Utils Tests" include("utils_test.jl")
@time @safetestset "NamedArrayPartition Tests" include("named_array_partition_tests.jl")
@time @safetestset "Partitions Tests" include("partitions_test.jl")
@time @safetestset "VecOfArr Indexing Tests" include("basic_indexing.jl")
@time @safetestset "SymbolicIndexingInterface API test" include("symbolic_indexing_interface_test.jl")
@time @safetestset "VecOfArr Interface Tests" include("interface_tests.jl")
@time @safetestset "Table traits" include("tabletraits.jl")
@time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl")
@time @safetestset "Linear Algebra Tests" include("linalg.jl")
@time @safetestset "Adjoint Tests" include("adjoints.jl")
@time @safetestset "Measurement Tests" include("measurements.jl")
end

if GROUP == "Downstream"
activate_downstream_env()
@time @safetestset "DiffEqArray Indexing Tests" begin
include("downstream/symbol_indexing.jl")
end
@time @safetestset "Event Tests with ArrayPartition" begin
include("downstream/downstream_events.jl")
end
@time @safetestset "Measurements and Units" begin
include("downstream/measurements_and_units.jl")
end
@time @safetestset "TrackerExt" begin
include("downstream/TrackerExt.jl")
end
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
@time @safetestset "ODE Solve Tests" include("downstream/odesolve.jl")
@time @safetestset "Event Tests with ArrayPartition" include("downstream/downstream_events.jl")
@time @safetestset "Measurements and Units" include("downstream/measurements_and_units.jl")
@time @safetestset "TrackerExt" include("downstream/TrackerExt.jl")
end

if GROUP == "GPU"
activate_gpu_env()
@time @safetestset "VectorOfArray GPU" begin
include("gpu/vectorofarray_gpu.jl")
end
@time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl")
end
end

0 comments on commit b10cddb

Please sign in to comment.