From 758ac849f12b7e0d1f8ffceef8fe35d17265c65a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 21 Feb 2024 04:48:43 -0500 Subject: [PATCH] format --- docs/pages.jl | 2 +- ext/RecursiveArrayToolsFastBroadcastExt.jl | 10 +- ext/RecursiveArrayToolsMeasurementsExt.jl | 2 +- ...siveArrayToolsMonteCarloMeasurementsExt.jl | 4 +- ext/RecursiveArrayToolsReverseDiffExt.jl | 12 +- ext/RecursiveArrayToolsTrackerExt.jl | 12 +- ext/RecursiveArrayToolsZygoteExt.jl | 190 ++++++++----- src/RecursiveArrayTools.jl | 8 +- src/array_partition.jl | 100 ++++--- src/named_array_partition.jl | 80 +++--- src/tabletraits.jl | 6 +- src/utils.jl | 42 +-- src/vector_of_array.jl | 268 ++++++++++-------- test/adjoints.jl | 5 +- test/basic_indexing.jl | 12 +- test/interface_tests.jl | 45 +-- test/named_array_partition_tests.jl | 17 +- test/partitions_test.jl | 17 +- test/runtests.jl | 10 +- test/symbolic_indexing_interface_test.jl | 4 +- test/testutils.jl | 2 +- test/upstream.jl | 12 +- test/utils_test.jl | 14 +- 23 files changed, 491 insertions(+), 383 deletions(-) diff --git a/docs/pages.jl b/docs/pages.jl index 9ad69ca9..5d9ec136 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -3,5 +3,5 @@ pages = [ "Home" => "index.md", "array_types.md", - "recursive_array_functions.md", + "recursive_array_functions.md" ] diff --git a/ext/RecursiveArrayToolsFastBroadcastExt.jl b/ext/RecursiveArrayToolsFastBroadcastExt.jl index fd2a6e74..a2f840c8 100644 --- a/ext/RecursiveArrayToolsFastBroadcastExt.jl +++ b/ext/RecursiveArrayToolsFastBroadcastExt.jl @@ -4,13 +4,17 @@ using RecursiveArrayTools using FastBroadcast using StaticArraysCore -const AbstractVectorOfSArray = AbstractVectorOfArray{T,N,<:AbstractVector{<:StaticArraysCore.SArray}} where {T,N} +const AbstractVectorOfSArray = AbstractVectorOfArray{ + T, N, <:AbstractVector{<:StaticArraysCore.SArray}} where {T, N} -@inline function FastBroadcast.fast_materialize!(::FastBroadcast.Static.False, ::DB, dst::AbstractVectorOfSArray, bc::Broadcast.Broadcasted{S}) where {S,DB} +@inline function FastBroadcast.fast_materialize!( + ::FastBroadcast.Static.False, ::DB, dst::AbstractVectorOfSArray, + bc::Broadcast.Broadcasted{S}) where {S, DB} if FastBroadcast.use_fast_broadcast(S) for i in 1:length(dst.u) unpacked = RecursiveArrayTools.unpack_voa(bc, i) - dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(unpacked[j] for j in eachindex(unpacked)) + dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(unpacked[j] + for j in eachindex(unpacked)) end else Broadcast.materialize!(dst, bc) diff --git a/ext/RecursiveArrayToolsMeasurementsExt.jl b/ext/RecursiveArrayToolsMeasurementsExt.jl index d631655e..4b18430e 100644 --- a/ext/RecursiveArrayToolsMeasurementsExt.jl +++ b/ext/RecursiveArrayToolsMeasurementsExt.jl @@ -4,7 +4,7 @@ import RecursiveArrayTools isdefined(Base, :get_extension) ? (import Measurements) : (import ..Measurements) function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{ - <:Measurements.Measurement, + <:Measurements.Measurement, }) typeof(oneunit(a)) end diff --git a/ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl b/ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl index 35250da6..7237a2da 100644 --- a/ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl +++ b/ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl @@ -5,13 +5,13 @@ isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) : (import ..MonteCarloMeasurements) function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{ - <:MonteCarloMeasurements.Particles, + <:MonteCarloMeasurements.Particles, }) typeof(one(a)) end function RecursiveArrayTools.recursive_unitless_eltype(a::Type{ - <:MonteCarloMeasurements.Particles, + <:MonteCarloMeasurements.Particles, }) typeof(one(a)) end diff --git a/ext/RecursiveArrayToolsReverseDiffExt.jl b/ext/RecursiveArrayToolsReverseDiffExt.jl index 115949a1..79b59418 100644 --- a/ext/RecursiveArrayToolsReverseDiffExt.jl +++ b/ext/RecursiveArrayToolsReverseDiffExt.jl @@ -5,13 +5,13 @@ using ReverseDiff using Zygote: @adjoint function trackedarraycopyto!(dest, src) - for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims=ndims(src))) - if dest.u[i] isa AbstractArray - dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i])) - else - trackedarraycopyto!(dest.u[i], slice) - end + for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims = ndims(src))) + if dest.u[i] isa AbstractArray + dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i])) + else + trackedarraycopyto!(dest.u[i], slice) end + end end @adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal}) diff --git a/ext/RecursiveArrayToolsTrackerExt.jl b/ext/RecursiveArrayToolsTrackerExt.jl index 845aff72..75d05ba7 100644 --- a/ext/RecursiveArrayToolsTrackerExt.jl +++ b/ext/RecursiveArrayToolsTrackerExt.jl @@ -4,12 +4,12 @@ import RecursiveArrayTools isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker) function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N}, - a::AbstractArray{T2, N}) where { - T <: - Tracker.TrackedArray, - T2 <: - Tracker.TrackedArray, - N} + a::AbstractArray{T2, N}) where { + T <: + Tracker.TrackedArray, + T2 <: + Tracker.TrackedArray, + N} @inbounds for i in eachindex(a) b[i] = copy(a[i]) end diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 0b75593f..a5c9f4a5 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -13,13 +13,14 @@ end # Define a new species of projection operator for this type: # ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() -function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray}, - xs::AbstractVectorOfArray) +function ChainRulesCore.rrule( + T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray}, + xs::AbstractVectorOfArray) T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ) end @adjoint function getindex(VA::AbstractVectorOfArray, - i::Union{BitArray, AbstractArray{Bool}}) + i::Union{BitArray, AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -46,8 +47,8 @@ end end @adjoint function getindex(VA::AbstractVectorOfArray, i::Int, - j::Union{Int, AbstractArray{Int}, CartesianIndex, - Colon, BitArray, AbstractArray{Bool}}...) + j::Union{Int, AbstractArray{Int}, CartesianIndex, + Colon, BitArray, AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) if isempty(j) @@ -61,10 +62,10 @@ end end @adjoint function ArrayPartition(x::S, - ::Type{Val{copy_x}} = Val{false}) where { - S <: - Tuple, - copy_x, + ::Type{Val{copy_x}} = Val{false}) where { + S <: + Tuple, + copy_x } function ArrayPartition_adjoint(_y) y = Array(_y) @@ -87,16 +88,19 @@ end @adjoint function Base.copy(u::VectorOfArray) copy(u), - y -> (copy(y),) + y -> (copy(y),) end @adjoint function DiffEqArray(u, t) DiffEqArray(u, t), y -> begin y isa Ref && (y = VectorOfArray(y[].u)) - (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] - for i in 1:size(y)[end]], - t), nothing) + ( + DiffEqArray( + [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] + for i in 1:size(y)[end]], + t), + nothing) end end @@ -108,7 +112,7 @@ end end @adjoint function Base.Array(VA::AbstractVectorOfArray) - adj = let VA=VA + adj = let VA = VA function Array_adjoint(y) VA = recursivecopy(VA) copyto!(VA, y) @@ -135,133 +139,162 @@ end view(A, I...), view_adjoint end -ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) +function ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) + ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) +end -function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray}) +function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ + AbstractArray, AbstractVectorOfArray}) arr = reshape(x, p.sz) return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) end -@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray}) +@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, + y::Union{Zygote.Numeric, AbstractVectorOfArray}) broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...) end -@adjoint function Broadcast.broadcasted(::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray) +@adjoint function Broadcast.broadcasted( + ::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray) broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...) end _minus(Δ) = .-Δ _minus(::Nothing) = nothing -@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric}) +@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, + y::Union{AbstractVectorOfArray, Zygote.Numeric}) x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ))) end -@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric}) +@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, + y::Union{AbstractVectorOfArray, Zygote.Numeric}) ( - x.*y, - Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x))) + x .* y, + Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), + Zygote.unbroadcast(y, Δ .* conj.(x))) ) end -@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric}) - res = x ./ y - res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y))) +@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, + y::Union{AbstractVectorOfArray, Zygote.Numeric}) + res = x ./ y + res, + Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), + Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y))) end -@adjoint function Broadcast.broadcasted(::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray) +@adjoint function Broadcast.broadcasted( + ::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray) x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ))) end -@adjoint function Broadcast.broadcasted(::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray) +@adjoint function Broadcast.broadcasted( + ::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray) ( - x.*y, - Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x))) + x .* y, + Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), + Zygote.unbroadcast(y, Δ .* conj.(x))) ) end -@adjoint function Broadcast.broadcasted(::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray) - res = x ./ y - res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y))) +@adjoint function Broadcast.broadcasted( + ::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray) + res = x ./ y + res, + Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), + Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y))) end @adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray) .-x, Δ -> (nothing, _minus(Δ)) end -@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::AbstractVectorOfArray, exp::Val{p}) where p - y = Base.literal_pow.(^, x, exp) - y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing) +@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), + x::AbstractVectorOfArray, exp::Val{p}) where {p} + y = Base.literal_pow.(^, x, exp) + y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing) end -@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x, Δ -> (nothing, Δ) +@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x, +Δ -> (nothing, Δ) @adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray) - y = tanh.(x) - y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2)) + y = tanh.(x) + y, ȳ -> (nothing, ȳ .* conj.(1 .- y .^ 2)) end -@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) = - conj.(x), z̄ -> (nothing, conj.(z̄)) +@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) = conj.(x), +z̄ -> (nothing, conj.(z̄)) -@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) = - real.(x), z̄ -> (nothing, real.(z̄)) +@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) = real.(x), +z̄ -> (nothing, real.(z̄)) -@adjoint Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray) = - imag.(x), z̄ -> (nothing, im .* real.(z̄)) +@adjoint Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray) = imag.(x), +z̄ -> (nothing, im .* real.(z̄)) -@adjoint Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray) = - abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) +@adjoint Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray) = abs2.(x), +z̄ -> (nothing, 2 .* real.(z̄) .* x) -@adjoint function Broadcast.broadcasted(::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool) - y = b === false ? a : a .+ b - y, Δ -> (nothing, Δ, nothing) +@adjoint function Broadcast.broadcasted( + ::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool) + y = b === false ? a : a .+ b + y, Δ -> (nothing, Δ, nothing) end -@adjoint function Broadcast.broadcasted(::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number}) - y = b === false ? a : b .+ a - y, Δ -> (nothing, nothing, Δ) +@adjoint function Broadcast.broadcasted( + ::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number}) + y = b === false ? a : b .+ a + y, Δ -> (nothing, nothing, Δ) end -@adjoint function Broadcast.broadcasted(::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool) - y = b === false ? a : a .- b - y, Δ -> (nothing, Δ, nothing) +@adjoint function Broadcast.broadcasted( + ::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool) + y = b === false ? a : a .- b + y, Δ -> (nothing, Δ, nothing) end -@adjoint function Broadcast.broadcasted(::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number}) - b .- a, Δ -> (nothing, nothing, .-Δ) +@adjoint function Broadcast.broadcasted( + ::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number}) + b .- a, Δ -> (nothing, nothing, .-Δ) end -@adjoint function Broadcast.broadcasted(::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool) - if b === false - zero(a), Δ -> (nothing, zero(Δ), nothing) - else - a, Δ -> (nothing, Δ, nothing) - end +@adjoint function Broadcast.broadcasted( + ::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool) + if b === false + zero(a), Δ -> (nothing, zero(Δ), nothing) + else + a, Δ -> (nothing, Δ, nothing) + end end -@adjoint function Broadcast.broadcasted(::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number}) - if b === false - zero(a), Δ -> (nothing, nothing, zero(Δ)) - else - a, Δ -> (nothing, nothing, Δ) - end +@adjoint function Broadcast.broadcasted( + ::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number}) + if b === false + zero(a), Δ -> (nothing, nothing, zero(Δ)) + else + a, Δ -> (nothing, nothing, Δ) + end end -@adjoint Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T<:Number} = - T.(x), ȳ -> (nothing, Zygote._project(x, ȳ),) +@adjoint Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T <: Number} = T.(x), +ȳ -> (nothing, Zygote._project(x, ȳ)) function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄) N = ndims(x̄) if length(x) == length(x̄) Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors else - dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄)) + dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄) + 1, ndims(x̄)) Zygote._project(x, Zygote.accum_sum(x̄; dims = dims)) end end -@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic(__context__, f, a, b) -@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b) -@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b) +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic( + __context__, f, a, b) +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic( + __context__, f, a, b) +@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic( + __context__, f, a, b) @inline function _broadcast_generic(__context__, f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) # Avoid generic broadcasting in two easy cases: if T == Bool return (f.(args...), _ -> nothing) - elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && all(Zygote._dual_safearg, args) && !Zygote.isderiving() - return Zygote.broadcast_forward(f, args...) + elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && + all(Zygote._dual_safearg, args) && !Zygote.isderiving() + return Zygote.broadcast_forward(f, args...) end len = Zygote.inclen(args) y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...) @@ -272,7 +305,8 @@ end dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ) getters = ntuple(i -> Zygote.StaticGetter{i}(), len) dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters) - (nothing, Zygote.accum_sum(dxs[1]), map(Zygote.unbroadcast, args, Base.tail(dxs))...) + (nothing, Zygote.accum_sum(dxs[1]), + map(Zygote.unbroadcast, args, Base.tail(dxs))...) end return y, ∇broadcasted end diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 5f2e7b5c..71cc01f9 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -6,7 +6,7 @@ module RecursiveArrayTools using DocStringExtensions using RecipesBase, StaticArraysCore, Statistics, - ArrayInterface, LinearAlgebra + ArrayInterface, LinearAlgebra using SymbolicIndexingInterface using SparseArrays @@ -32,11 +32,11 @@ Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = (T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA)) export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray, - AllObserved, vecarr_to_vectors, tuples + AllObserved, vecarr_to_vectors, tuples export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!, - vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype, - recursive_unitless_bottom_eltype, recursive_unitless_eltype + vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype, + recursive_unitless_bottom_eltype, recursive_unitless_eltype export ArrayPartition, NamedArrayPartition diff --git a/src/array_partition.jl b/src/array_partition.jl index 1a194d7a..e8fe296e 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -92,13 +92,13 @@ Base.zero(A::ArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) Base.Array(A::ArrayPartition) = reduce(vcat, Array.(A.x)) function Base.Array(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: AbstractVector{ - <:ArrayPartition, - }} + A <: AbstractVector{ + <:ArrayPartition, + }} reduce(hcat, Array.(VA.u)) end @@ -251,7 +251,7 @@ end # workaround for https://github.com/SciML/RecursiveArrayTools.jl/issues/49 function Base._unsafe_getindex(::IndexStyle, A::ArrayPartition, - I::Vararg{Union{Real, AbstractArray}, N}) where {N} + I::Vararg{Union{Real, AbstractArray}, N}) where {N} # This is specifically not inlined to prevent excessive allocations in type unstable code shape = Base.index_shape(I...) dest = similar(A.x[1], shape) @@ -260,8 +260,8 @@ function Base._unsafe_getindex(::IndexStyle, A::ArrayPartition, end function Base._maybe_reshape(::IndexCartesian, - A::ArrayPartition, - I::Vararg{Union{Real, AbstractArray}, N}) where {N} + A::ArrayPartition, + I::Vararg{Union{Real, AbstractArray}, N}) where {N} Vector(A) end @@ -308,25 +308,27 @@ end # promotion rules @inline function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, - ::ArrayPartitionStyle{BStyle}) where {AStyle, - BStyle} + ::ArrayPartitionStyle{BStyle}) where {AStyle, + BStyle} ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle())) end function Broadcast.BroadcastStyle(::ArrayPartitionStyle{Style}, - ::Broadcast.DefaultArrayStyle{0}) where { - Style <: - Broadcast.BroadcastStyle, + ::Broadcast.DefaultArrayStyle{0}) where { + Style <: + Broadcast.BroadcastStyle, } ArrayPartitionStyle{Style}() end function Broadcast.BroadcastStyle(::ArrayPartitionStyle, - ::Broadcast.DefaultArrayStyle{N}) where {N} + ::Broadcast.DefaultArrayStyle{N}) where {N} Broadcast.DefaultArrayStyle{N}() end combine_styles(::Type{Tuple{}}) = Broadcast.DefaultArrayStyle{0}() -combine_styles(::Type{T}) where {T} = Broadcast.result_style(Broadcast.BroadcastStyle(T.parameters[1]), combine_styles(Tuple{Base.tail((T.parameters...,))...})) - +function combine_styles(::Type{T}) where {T} + Broadcast.result_style(Broadcast.BroadcastStyle(T.parameters[1]), + combine_styles(Tuple{Base.tail((T.parameters...,))...})) +end function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S} Style = combine_styles(S) @@ -334,9 +336,9 @@ function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S} end @inline function Base.copy(bc::Broadcast.Broadcasted{ - ArrayPartitionStyle{Style}, + ArrayPartitionStyle{Style}, }) where { - Style, + Style, } N = npartitions(bc) @inline function f(i) @@ -346,8 +348,8 @@ end end @inline function Base.copyto!(dest::ArrayPartition, - bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where { - Style, + bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where { + Style, } N = npartitions(dest, bc) @inline function f(i) @@ -381,15 +383,15 @@ _npartitions(args::Tuple{}) = 0 Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) end @inline function unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, - i) where {Style} + i) where {Style} Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) end @inline function unpack(bc::Broadcast.Broadcasted{Style}, - i) where {Style <: Broadcast.DefaultArrayStyle} + i) where {Style <: Broadcast.DefaultArrayStyle} Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) end @inline function unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, - i) where {Style <: Broadcast.DefaultArrayStyle} + i) where {Style <: Broadcast.DefaultArrayStyle} Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) end unpack(x, ::Any) = x @@ -413,35 +415,40 @@ end ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A)) -function __get_subtypes_in_module(mod, supertype; include_supertype = true, all=false, except=[]) +function __get_subtypes_in_module( + mod, supertype; include_supertype = true, all = false, except = []) return filter([getproperty(mod, name) for name in names(mod; all) if !in(name, except)]) do value - return value != Union{} && value isa Type && (value <: supertype) && (include_supertype || value != supertype) && !in(value, except) - end + return value != Union{} && value isa Type && (value <: supertype) && + (include_supertype || value != supertype) && !in(value, except) + end end -for factorization in vcat(__get_subtypes_in_module(LinearAlgebra, Factorization; include_supertype = false, all=true, except=[:LU, :LAPACKFactorizations]), LDLt{T,<:SymTridiagonal{T,V} where {V<:AbstractVector{T}}} where {T}) - @eval function LinearAlgebra.ldiv!(A::T, b::ArrayPartition) where {T<:$factorization} +for factorization in vcat( + __get_subtypes_in_module(LinearAlgebra, Factorization; include_supertype = false, + all = true, except = [:LU, :LAPACKFactorizations]), + LDLt{T, <:SymTridiagonal{T, V} where {V <: AbstractVector{T}}} where {T}) + @eval function LinearAlgebra.ldiv!(A::T, b::ArrayPartition) where {T <: $factorization} (x = ldiv!(A, Array(b)); copyto!(b, x)) end end function LinearAlgebra.ldiv!( - A::LinearAlgebra.QRPivoted{T, <: StridedMatrix{T}, <: AbstractVector{T}}, - b::ArrayPartition{T}) where {T <: Union{Float32, Float64, ComplexF64, ComplexF32}} + A::LinearAlgebra.QRPivoted{T, <:StridedMatrix{T}, <:AbstractVector{T}}, + b::ArrayPartition{T}) where {T <: Union{Float32, Float64, ComplexF64, ComplexF32}} x = ldiv!(A, Array(b)) copyto!(b, x) end function LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY{T, M, C}, - b::ArrayPartition) where { - T <: Union{Float32, Float64, ComplexF64, ComplexF32}, - M <: AbstractMatrix{T}, - C <: AbstractMatrix{T}, + b::ArrayPartition) where { + T <: Union{Float32, Float64, ComplexF64, ComplexF32}, + M <: AbstractMatrix{T}, + C <: AbstractMatrix{T} } (x = ldiv!(A, Array(b)); copyto!(b, x)) end -for type in [LU, LU{T,Tridiagonal{T,V}} where {T,V}] +for type in [LU, LU{T, Tridiagonal{T, V}} where {T, V}] @eval function LinearAlgebra.ldiv!(A::$type, b::ArrayPartition) LinearAlgebra._ipiv_rows!(A, 1:length(A.ipiv), b) ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b)) @@ -463,11 +470,12 @@ end # [ 0 U22 U23] \ [ b2 ] # [ 0 0 U33] [ b3 ] for basetype in [UnitUpperTriangular, UpperTriangular, UnitLowerTriangular, LowerTriangular] - for type in [basetype, basetype{T, <:Adjoint{T}} where {T}, basetype{T, <:Transpose{T}} where {T}] + for type in [basetype, basetype{T, <:Adjoint{T}} where {T}, + basetype{T, <:Transpose{T}} where {T}] j_iter, i_iter = if basetype <: UnitUpperTriangular || basetype <: UpperTriangular - (:(n:-1:1), :(j-1:-1:1)) + (:(n:-1:1), :((j - 1):-1:1)) else - (:(1:n), :((j+1):n)) + (:(1:n), :((j + 1):n)) end @eval function LinearAlgebra.ldiv!(A::$type, bb::ArrayPartition) A = A.data @@ -529,19 +537,19 @@ function LinearAlgebra.mul!(C::ArrayPartition, A::ArrayPartition, B::ArrayPartit end function Base.convert(::Type{ArrayPartition{T, S}}, - A::ArrayPartition{<:Any, <:NTuple{N, Any}}) where {N, T, - S <: - NTuple{N, Any}} + A::ArrayPartition{<:Any, <:NTuple{N, Any}}) where {N, T, + S <: + NTuple{N, Any}} return ArrayPartition{T, S}(ntuple((@inline i -> convert(S.parameters[i], A.x[i])), Val(N))) end @generated function Base.length(::Type{ - <:ArrayPartition{F, T}, + <:ArrayPartition{F, T}, }) where {F, N, - T <: NTuple{N, - StaticArraysCore.StaticArray, - }} + T <: NTuple{N, + StaticArraysCore.StaticArray + }} sum_expr = Expr(:call, :+) for param in T.parameters push!(sum_expr.args, :(length($param))) diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index 873fc49d..ae16c66a 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -5,17 +5,18 @@ Similar to an `ArrayPartition` but the individual arrays can be accessed via the constructor-specified names. However, unlike `ArrayPartition`, each individual array must have the same element type. -""" -struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T} +""" +struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractVector{T} array_partition::A names_to_indices::NT end NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs)) -function NamedArrayPartition(x::NamedTuple) - names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x))) +function NamedArrayPartition(x::NamedTuple) + names_to_indices = NamedTuple(Pair(symbol, index) + for (index, symbol) in enumerate(keys(x))) # enforce homogeneity of eltypes - @assert all(eltype.(values(x)) .== eltype(first(x))) + @assert all(eltype.(values(x)) .== eltype(first(x))) T = eltype(first(x)) S = typeof(values(x)) return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) @@ -27,61 +28,73 @@ ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) -Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} = +function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) +end Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors - Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) -Base.getproperty(x::NamedArrayPartition, s::Symbol) = +function Base.getproperty(x::NamedArrayPartition, s::Symbol) getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) +end # this enables x.s = some_array. -@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) +@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) index = getproperty(getfield(x, :names_to_indices), s) ArrayPartition(x).x[index] .= v end # print out NamedArrayPartition as a NamedTuple Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:") -Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) = - show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) +function Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) + show( + io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) +end Base.size(x::NamedArrayPartition) = size(ArrayPartition(x)) Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) -Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) +function Base.map(f, x::NamedArrayPartition) + NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) +end Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) # Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x)) -Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} = - NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) +function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} + NamedArrayPartition{T, S, NT}( + similar(ArrayPartition(x)), getfield(x, :names_to_indices)) +end # broadcasting -Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}() +function Base.BroadcastStyle(::Type{<:NamedArrayPartition}) + Broadcast.ArrayStyle{NamedArrayPartition}() +end function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, - ::Type{ElType}) where {ElType} + ::Type{ElType}) where {ElType} x = find_NamedArrayPartition(bc) return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) end # when broadcasting with ArrayPartition + another array type, the output is the other array tupe -Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) = +function Base.BroadcastStyle( + ::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) Broadcast.DefaultArrayStyle{1}() +end # hook into ArrayPartition broadcasting routines @inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x)) -@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = - Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) +@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = Broadcast.Broadcasted( + bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) @inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i) -Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} = - NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) +function Base.copy(A::NamedArrayPartition{T, S, NT}) where {T, S, NT} + NamedArrayPartition{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) +end -@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function = - NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices) +@inline NamedArrayPartition(f::F, N, names_to_indices) where {F <: Function} = NamedArrayPartition( + ArrayPartition(ntuple(f, Val(N))), names_to_indices) @inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) N = npartitions(bc) @@ -92,23 +105,22 @@ Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} = NamedArrayPartition(f, N, getfield(x, :names_to_indices)) end -@inline function Base.copyto!(dest::NamedArrayPartition, - bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) - N = npartitions(dest, bc) - @inline function f(i) - copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) - end - ntuple(f, Val(N)) - return dest +@inline function Base.copyto!(dest::NamedArrayPartition, + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) + N = npartitions(dest, bc) + @inline function f(i) + copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) + end + ntuple(f, Val(N)) + return dest end # `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) -find_NamedArrayPartition(args::Tuple) = +function find_NamedArrayPartition(args::Tuple) find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args)) +end find_NamedArrayPartition(x) = x find_NamedArrayPartition(::Tuple{}) = nothing find_NamedArrayPartition(x::NamedArrayPartition, rest) = x find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest) - - diff --git a/src/tabletraits.jl b/src/tabletraits.jl index cd6e9dd8..4fded222 100644 --- a/src/tabletraits.jl +++ b/src/tabletraits.jl @@ -9,13 +9,13 @@ function Tables.rows(A::AbstractDiffEqArray) :timestamp, (isempty(variable_symbols(A)) ? (Symbol("value", i) for i in 1:N) : - Symbol.(variable_symbols(A)))..., + Symbol.(variable_symbols(A)))... ] types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...] else names = [ :timestamp, - (isempty(variable_symbols(A)) ? :value : Symbol(variable_symbols(A)[1])), + (isempty(variable_symbols(A)) ? :value : Symbol(variable_symbols(A)[1])) ] types = Type[eltype(A.t), VT] end @@ -44,7 +44,7 @@ function Base.eltype(::Type{AbstractDiffEqArrayRows{T, U}}) where {T, U} AbstractDiffEqArrayRow{eltype(T), eltype(U)} end function Base.iterate(x::AbstractDiffEqArrayRows, - (t_state, u_state) = (iterate(x.t), iterate(x.u))) + (t_state, u_state) = (iterate(x.t), iterate(x.u))) t_state === nothing && return nothing u_state === nothing && return nothing t, _t_state = t_state diff --git a/src/utils.jl b/src/utils.jl index 4af362c9..a1b23208 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,7 +10,7 @@ function recursivecopy(a) deepcopy(a) end function recursivecopy(a::Union{StaticArraysCore.SVector, StaticArraysCore.SMatrix, - StaticArraysCore.SArray, Number}) + StaticArraysCore.SArray, Number}) copy(a) end function recursivecopy(a::AbstractArray{T, N}) where {T <: Number, N} @@ -42,9 +42,10 @@ like `copy!` on arrays of scalars. """ function recursivecopy! end -function recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T2, N}) where {T <: StaticArraysCore.StaticArray, - T2 <: StaticArraysCore.StaticArray, - N} +function recursivecopy!(b::AbstractArray{T, N}, + a::AbstractArray{T2, N}) where {T <: StaticArraysCore.StaticArray, + T2 <: StaticArraysCore.StaticArray, + N} @inbounds for i in eachindex(a) # TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19 b[i] = copy(a[i]) @@ -52,18 +53,18 @@ function recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T2, N}) wher end function recursivecopy!(b::AbstractArray{T, N}, - a::AbstractArray{T2, N}) where {T <: Enum, T2 <: Enum, N} + a::AbstractArray{T2, N}) where {T <: Enum, T2 <: Enum, N} copyto!(b, a) end function recursivecopy!(b::AbstractArray{T, N}, - a::AbstractArray{T2, N}) where {T <: Number, T2 <: Number, N} + a::AbstractArray{T2, N}) where {T <: Number, T2 <: Number, N} copyto!(b, a) end function recursivecopy!(b::AbstractArray{T, N}, - a::AbstractArray{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray}, - T2 <: Union{AbstractArray, AbstractVectorOfArray}, N} + a::AbstractArray{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray}, + T2 <: Union{AbstractArray, AbstractVectorOfArray}, N} if ArrayInterface.ismutable(T) @inbounds for i in eachindex(b, a) recursivecopy!(b[i], a[i]) @@ -95,32 +96,32 @@ A recursive `fill!` function. function recursivefill! end function recursivefill!(b::AbstractArray{T, N}, - a::T2) where {T <: StaticArraysCore.StaticArray, - T2 <: StaticArraysCore.StaticArray, N} + a::T2) where {T <: StaticArraysCore.StaticArray, + T2 <: StaticArraysCore.StaticArray, N} @inbounds for i in eachindex(b) b[i] = copy(a) end end function recursivefill!(bs::AbstractVectorOfArray{T, N}, - a::T2) where {T <: StaticArraysCore.StaticArray, - T2 <: StaticArraysCore.StaticArray, N} + a::T2) where {T <: StaticArraysCore.StaticArray, + T2 <: StaticArraysCore.StaticArray, N} @inbounds for b in bs, i in eachindex(b) b[i] = copy(a) end end function recursivefill!(b::AbstractArray{T, N}, - a::T2) where {T <: StaticArraysCore.SArray, - T2 <: Union{Number, Bool}, N} + a::T2) where {T <: StaticArraysCore.SArray, + T2 <: Union{Number, Bool}, N} @inbounds for i in eachindex(b) b[i] = fill(a, typeof(b[i])) end end function recursivefill!(bs::AbstractVectorOfArray{T, N}, - a::T2) where {T <: StaticArraysCore.SArray, - T2 <: Union{Number, Bool}, N} + a::T2) where {T <: StaticArraysCore.SArray, + T2 <: Union{Number, Bool}, N} @inbounds for b in bs, i in eachindex(b) b[i] = fill(a, typeof(b[i])) end @@ -132,13 +133,14 @@ for type in [AbstractArray, AbstractVectorOfArray] end @eval function recursivefill!(b::$type{T, N}, - a::T2) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N + a::T2) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N } fill!(b, a) end for type2 in [Any, StaticArraysCore.StaticArray] - @eval function recursivefill!(b::$type{T, N}, a::$type2) where {T <: StaticArraysCore.MArray, N} + @eval function recursivefill!( + b::$type{T, N}, a::$type2) where {T <: StaticArraysCore.MArray, N} @inbounds for i in eachindex(b) if isassigned(b, i) recursivefill!(b[i], a) @@ -149,7 +151,7 @@ for type in [AbstractArray, AbstractVectorOfArray] end end end - + @eval function recursivefill!(b::$type{T, N}, a) where {T <: AbstractArray, N} @inbounds for i in eachindex(b) recursivefill!(b[i], a) @@ -226,7 +228,7 @@ function copyat_or_push!(a::AbstractVector{T}, i::Int, x, perform_copy = true) w end function copyat_or_push!(a::AbstractVector{T}, i::Int, x, - nc::Type{Val{perform_copy}}) where {T, perform_copy} + nc::Type{Val{perform_copy}}) where {T, perform_copy} copyat_or_push!(a, i, x, perform_copy) end diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index c563ce61..3bb47200 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -64,60 +64,60 @@ struct AllObserved end function Base.Array(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: AbstractVector{ - <:AbstractVector, - }} + A <: AbstractVector{ + <:AbstractVector, + }} reduce(hcat, VA.u) end function Base.Array(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: - AbstractVector{<:Number}} + A <: + AbstractVector{<:Number}} VA.u end function Base.Matrix(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: AbstractVector{ - <:AbstractVector, - }} + A <: AbstractVector{ + <:AbstractVector, + }} reduce(hcat, VA.u) end function Base.Matrix(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: - AbstractVector{<:Number}} + A <: + AbstractVector{<:Number}} Matrix(VA.u) end function Base.Vector(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: AbstractVector{ - <:AbstractVector, - }} + A <: AbstractVector{ + <:AbstractVector, + }} vec(reduce(hcat, VA.u)) end function Base.Vector(VA::AbstractVectorOfArray{ - T, - N, - A, + T, + N, + A }) where {T, N, - A <: - AbstractVector{<:Number}} + A <: + AbstractVector{<:Number}} VA.u end function Base.Array(VA::AbstractVectorOfArray) @@ -151,10 +151,10 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray end function DiffEqArray(vec::AbstractVector{T}, - ts::AbstractVector, - ::NTuple{N, Int}, - p = nothing, - sys = nothing) where {T, N} + ts::AbstractVector, + ::NTuple{N, Int}, + p = nothing, + sys = nothing) where {T, N} DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec, ts, p, @@ -163,16 +163,16 @@ end # ambiguity resolution function DiffEqArray(vec::AbstractVector{VT}, - ts::AbstractVector, - ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} + ts::AbstractVector, + ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec, ts, nothing, nothing) end function DiffEqArray(vec::AbstractVector{VT}, - ts::AbstractVector, - ::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}} + ts::AbstractVector, + ::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}} DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec, ts, p, @@ -181,12 +181,12 @@ end # Assume that the first element is representative of all other elements function DiffEqArray(vec::AbstractVector, - ts::AbstractVector, - p = nothing, - sys = nothing; - variables = nothing, - parameters = nothing, - independent_variables = nothing) + ts::AbstractVector, + p = nothing, + sys = nothing; + variables = nothing, + parameters = nothing, + independent_variables = nothing) sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -199,7 +199,7 @@ function DiffEqArray(vec::AbstractVector, typeof(vec), typeof(ts), typeof(p), - typeof(sys), + typeof(sys) }(vec, ts, p, @@ -207,22 +207,23 @@ function DiffEqArray(vec::AbstractVector, end function DiffEqArray(vec::AbstractVector{VT}, - ts::AbstractVector, - p = nothing, - sys = nothing; - variables = nothing, - parameters = nothing, - independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} - sys = something(sys, SymbolCache(something(variables, []), - something(parameters, []), - something(independent_variables, []))) + ts::AbstractVector, + p = nothing, + sys = nothing; + variables = nothing, + parameters = nothing, + independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} + sys = something(sys, + SymbolCache(something(variables, []), + something(parameters, []), + something(independent_variables, []))) return DiffEqArray{ eltype(eltype(vec)), N + 1, typeof(vec), typeof(ts), typeof(p), - typeof(sys), + typeof(sys) }(vec, ts, p, @@ -242,33 +243,40 @@ Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian() @inline function Base.eachindex(VA::AbstractVectorOfArray) return eachindex(VA.u) end -@inline function Base.eachindex(::IndexLinear, VA::AbstractVectorOfArray{T,N,<:AbstractVector{T}}) where {T, N} +@inline function Base.eachindex( + ::IndexLinear, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}}) where {T, N} return eachindex(IndexLinear(), VA.u) end @inline Base.IteratorSize(::Type{<:AbstractVectorOfArray}) = Base.HasLength() @inline Base.first(VA::AbstractVectorOfArray) = first(VA.u) @inline Base.last(VA::AbstractVectorOfArray) = last(VA.u) function Base.firstindex(VA::AbstractVectorOfArray) - Base.depwarn("Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ", :firstindex) + Base.depwarn( + "Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ", + :firstindex) return firstindex(VA.u) end function Base.lastindex(VA::AbstractVectorOfArray) - Base.depwarn("Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ", :lastindex) + Base.depwarn( + "Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ", + :lastindex) return lastindex(VA.u) end @deprecate Base.getindex(A::AbstractVectorOfArray, I::Int) Base.getindex(A, :, I) false -@deprecate Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) Base.getindex(A, :, I) false +@deprecate Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) Base.getindex( + A, :, I) false -@deprecate Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) Base.getindex(A, :, I) false +@deprecate Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) Base.getindex( + A, :, I) false @deprecate Base.getindex(A::AbstractDiffEqArray, i::Int) Base.getindex(A, :, i) false __parameterless_type(T) = Base.typename(T).wrapper Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, - ::NotSymbolic, I::Colon...) where {T, N} + ::NotSymbolic, I::Colon...) where {T, N} @assert length(I) == ndims(A.u[1]) + 1 vecs = if N == 1 A.u @@ -280,37 +288,43 @@ Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, end Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N}, - ::NotSymbolic, I::AbstractArray{Bool}, - J::Colon...) where {T, N} + ::NotSymbolic, I::AbstractArray{Bool}, + J::Colon...) where {T, N} @assert length(J) == ndims(A.u[1]) + 1 - ndims(I) @assert size(I) == size(A)[1:(ndims(A) - length(J))] return A[ntuple(x -> Colon(), ndims(A))...][I, J...] end # Need two of each methods to avoid ambiguities -Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int) +Base.@propagate_inbounds function _getindex( + A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int) A.u[I] end -Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...) +Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, + I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...) if last(I) isa Int A.u[last(I)][Base.front(I)...] else stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...)) end end -Base.@propagate_inbounds function _getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex) +Base.@propagate_inbounds function _getindex( + VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex) ti = Tuple(ii) i = last(ti) jj = CartesianIndex(Base.front(ti)) return VA.u[i][jj] end -Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}}) +Base.@propagate_inbounds function _getindex( + A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, + I::Union{AbstractArray{Int}, AbstractArray{Bool}}) VectorOfArray(A.u[I]) end -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}}) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, + I::Union{AbstractArray{Int}, AbstractArray{Bool}}) DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A)) end @@ -336,7 +350,8 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb end end -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...) +Base.@propagate_inbounds function _getindex( + A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...) if is_independent_variable(A, sym) return A.t[args...] elseif is_variable(A, sym) @@ -358,12 +373,13 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb end end - -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...) +Base.@propagate_inbounds function _getindex( + A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...) return getindex(A, collect(sym), args...) end -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}) +Base.@propagate_inbounds function _getindex( + A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}) if all(x -> is_parameter(A, x), sym) error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") else @@ -371,15 +387,18 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb end end -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...) +Base.@propagate_inbounds function _getindex( + A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...) return reduce(vcat, map(s -> A[s, args...]', sym)) end -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, ::SymbolicIndexingInterface.SolvedVariables, args...) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, + ::SymbolicIndexingInterface.SolvedVariables, args...) return getindex(A, variable_symbols(A), args...) end -Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, ::SymbolicIndexingInterface.AllVariables, args...) +Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, + ::SymbolicIndexingInterface.AllVariables, args...) return getindex(A, all_variable_symbols(A), args...) end @@ -394,7 +413,8 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, end end -Base.@propagate_inbounds function Base.getindex(A::Adjoint{T,<:AbstractVectorOfArray}, idxs...) where {T} +Base.@propagate_inbounds function Base.getindex( + A::Adjoint{T, <:AbstractVectorOfArray}, idxs...) where {T} return getindex(A.parent, reverse(to_indices(A, idxs))...) end @@ -409,42 +429,48 @@ function _observed(A::AbstractDiffEqArray{T, N}, sym, ::Colon) where {T, N} end Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, - ::Colon, I::Int) where {T, N} + ::Colon, I::Int) where {T, N} VA.u[I] = v end @deprecate Base.setindex!(VA::AbstractVectorOfArray, v, I::Int) Base.setindex!(VA, v, :, I) false Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, - ::Colon, I::Colon) where {T, N} + ::Colon, I::Colon) where {T, N} VA.u[I] = v end -@deprecate Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon) Base.setindex!(VA, v, :, I) false +@deprecate Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon) Base.setindex!( + VA, v, :, I) false Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, - ::Colon, I::AbstractArray{Int}) where {T, N} + ::Colon, I::AbstractArray{Int}) where {T, N} VA.u[I] = v end -@deprecate Base.setindex!(VA::AbstractVectorOfArray, v, I::AbstractArray{Int}) Base.setindex!(VA, v, :, I) false +@deprecate Base.setindex!(VA::AbstractVectorOfArray, v, I::AbstractArray{Int}) Base.setindex!( + VA, v, :, I) false -Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, i::Int, - ::Colon) where {T, N} +Base.@propagate_inbounds function Base.setindex!( + VA::AbstractVectorOfArray{T, N}, v, i::Int, + ::Colon) where {T, N} for j in 1:length(VA.u) VA.u[j][i] = v[j] end return v end Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, x, - ii::CartesianIndex) where {T, N} + ii::CartesianIndex) where {T, N} ti = Tuple(ii) i = last(ti) jj = CartesianIndex(Base.front(ti)) return VA.u[i][jj] = x end -Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, x, idxs::Union{Int,Colon,CartesianIndex,AbstractArray{Int},AbstractArray{Bool}}...) where {T, N} +Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, + x, + idxs::Union{Int, Colon, CartesianIndex, AbstractArray{Int}, AbstractArray{Bool}}...) where { + T, N} v = view(VA, idxs...) # error message copied from Base by running `ones(3, 3, 3)[:, 2, :] = 2` if length(v) != length(x) @@ -459,13 +485,13 @@ end # Interface for the two-dimensional indexing, a more standard AbstractArray interface @inline Base.size(VA::AbstractVectorOfArray) = (size(VA.u[1])..., length(VA.u)) @inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i] -@inline Base.size(A::Adjoint{T,<:AbstractVectorOfArray}) where {T} = reverse(size(A.parent)) -@inline Base.size(A::Adjoint{T,<:AbstractVectorOfArray}, i) where {T} = size(A)[i] +@inline Base.size(A::Adjoint{T, <:AbstractVectorOfArray}) where {T} = reverse(size(A.parent)) +@inline Base.size(A::Adjoint{T, <:AbstractVectorOfArray}, i) where {T} = size(A)[i] Base.axes(VA::AbstractVectorOfArray) = Base.OneTo.(size(VA)) Base.axes(VA::AbstractVectorOfArray, d::Int) = Base.OneTo(size(VA)[d]) Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, - I::Int...) where {T, N} + I::Int...) where {T, N} VA.u[I[end]][Base.front(I)...] = v end @@ -533,7 +559,7 @@ function Base.push!(VA::AbstractVectorOfArray{T, N}, new_item::AbstractArray) wh end function Base.append!(VA::AbstractVectorOfArray{T, N}, - new_item::AbstractVectorOfArray{T, N}) where {T, N} + new_item::AbstractVectorOfArray{T, N}) where {T, N} for item in copy(new_item) push!(VA, item) end @@ -545,31 +571,38 @@ function Base.stack(VA::AbstractVectorOfArray; dims = :) end # AbstractArray methods -function Base.view(A::AbstractVectorOfArray{T,N,<:AbstractVector{T}}, I::Vararg{Any, M}) where {T,N,M} +function Base.view(A::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, + I::Vararg{Any, M}) where {T, N, M} @inline if length(I) == 1 - J = map(i->Base.unalias(A,i), to_indices(A, I)) + J = map(i -> Base.unalias(A, i), to_indices(A, I)) elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1) - J = map(i->Base.unalias(A,i), to_indices(A, Base.tail(I))) + J = map(i -> Base.unalias(A, i), to_indices(A, Base.tail(I))) end @boundscheck checkbounds(A, J...) SubArray(A, J) end -function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M} +function Base.view(A::AbstractVectorOfArray, I::Vararg{Any, M}) where {M} @inline - J = map(i->Base.unalias(A,i), to_indices(A, I)) + J = map(i -> Base.unalias(A, i), to_indices(A, I)) @boundscheck checkbounds(A, J...) SubArray(A, J) end function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple) @inline - SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, Base.ensure_indexable(indices), Base.index_dimsum(indices...)) + SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, + Base.ensure_indexable(indices), Base.index_dimsum(indices...)) end Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...) -Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing +function Base.check_parent_index_match( + ::RecursiveArrayTools.AbstractVectorOfArray{T, N}, ::NTuple{N, Bool}) where {T, N} + nothing +end Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N -function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, idxs...) where {T, N} +function Base.checkbounds( + ::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, + idxs...) where {T, N} if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1) return checkbounds(Bool, VA.u, idxs[2]) end @@ -585,7 +618,8 @@ end function Base.checkbounds(VA::AbstractVectorOfArray, idx...) checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx)) end -function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T2,N}) where {T, T2, N} +function Base.copyto!(dest::AbstractVectorOfArray{T, N}, + src::AbstractVectorOfArray{T2, N}) where {T, T2, N} for (i, j) in zip(eachindex(dest.u), eachindex(src.u)) if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray copyto!(dest.u[i], src.u[j]) @@ -594,7 +628,8 @@ function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArr end end end -function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T2, N}) where {T, T2, N} +function Base.copyto!( + dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T2, N}) where {T, T2, N} for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims = ndims(src))) if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray copyto!(dest.u[i], slice) @@ -604,7 +639,8 @@ function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T2, end dest end -function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T2}) where {T, T2, N} +function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, + src::AbstractVector{T2}) where {T, T2, N} copyto!(dest.u, src) dest end @@ -617,8 +653,8 @@ end # Operations function Base.isapprox(A::AbstractVectorOfArray, - B::Union{AbstractVectorOfArray, AbstractArray}; - kwargs...) + B::Union{AbstractVectorOfArray, AbstractArray}; + kwargs...) return all(isapprox.(A, B; kwargs...)) end @@ -631,11 +667,12 @@ for op in [:(Base.:-), :(Base.:+)] ($op).(A, B) end @eval Base.@propagate_inbounds function ($op)(A::AbstractVectorOfArray, - B::AbstractArray) + B::AbstractArray) @boundscheck length(A) == length(B) VectorOfArray([($op).(a, b) for (a, b) in zip(A, B)]) end - @eval Base.@propagate_inbounds function ($op)(A::AbstractArray, B::AbstractVectorOfArray) + @eval Base.@propagate_inbounds function ($op)( + A::AbstractArray, B::AbstractVectorOfArray) @boundscheck length(A) == length(B) VectorOfArray([($op).(a, b) for (a, b) in zip(A, B)]) end @@ -753,7 +790,8 @@ Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u) function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...) mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...) end -function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T} +function Base.mapreduce( + f, op, A::AbstractVectorOfArray{T, 1, <:AbstractVector{T}}; kwargs...) where {T} mapreduce(f, op, A.u; kwargs...) end @@ -765,15 +803,15 @@ VectorOfArrayStyle(::Val{N}) where {N} = VectorOfArrayStyle{N}() # The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle. Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a function Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, - a::Base.Broadcast.DefaultArrayStyle{M}) where {M, N} + a::Base.Broadcast.DefaultArrayStyle{M}) where {M, N} Base.Broadcast.DefaultArrayStyle(Val(max(M, N))) end function Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, - a::Base.Broadcast.AbstractArrayStyle{M}) where {M, N} + a::Base.Broadcast.AbstractArrayStyle{M}) where {M, N} typeof(a)(Val(max(M, N))) end function Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, - ::VectorOfArrayStyle{N}) where {M, N} + ::VectorOfArrayStyle{N}) where {M, N} VectorOfArrayStyle(Val(max(M, N))) end function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} @@ -791,11 +829,11 @@ Broadcast.broadcastable(x::AbstractVectorOfArray) = x end for (type, N_expr) in [ - (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))), - (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u))) - ] + (Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))), + (Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u))) +] @eval @inline function Base.copyto!(dest::AbstractVectorOfArray, - bc::$type) + bc::$type) bc = Broadcast.flatten(bc) N = $N_expr @inbounds for i in 1:N diff --git a/test/adjoints.jl b/test/adjoints.jl index becdc9f2..1e5ee3c3 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -63,7 +63,7 @@ function loss8(x) end function loss9(x) - return VectorOfArray([collect(3i:3i+3) .* x for i in 1:5]) + return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5]) end x = float.(6:10) @@ -76,4 +76,5 @@ loss(x) @test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x) @test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x) @test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x) -@test ForwardDiff.derivative(loss9, 0.0) == VectorOfArray([collect(3i:3i+3) for i in 1:5]) +@test ForwardDiff.derivative(loss9, 0.0) == + VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index b3767c30..a6e0c77c 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -16,8 +16,8 @@ mulX .= sqrt.(abs.(testva .* X)) @test mulX == ref @test Array(testva) == [1 4 7 - 2 5 8 - 3 6 9] + 2 5 8 + 3 6 9] @test testa[1:2, 1:2] == [1 4; 2 5] @test testva[1:2, 1:2] == [1 4; 2 5] @@ -26,8 +26,8 @@ mulX .= sqrt.(abs.(testva .* X)) t = [1, 2, 3] diffeq = DiffEqArray(recs, t) @test Array(diffeq) == [1 4 7 - 2 5 8 - 3 6 9] + 2 5 8 + 3 6 9] @test diffeq[1:2, 1:2] == [1 4; 2 5] # # ndims == 2 @@ -65,7 +65,7 @@ testvb = deepcopy(testva) @test testva[:, end] == last(testva) @test testa[:, 1] == recs[1] @test testva.u == recs -@test testva[: ,2:end] == VectorOfArray([recs[i] for i in 2:length(recs)]) +@test testva[:, 2:end] == VectorOfArray([recs[i] for i in 2:length(recs)]) diffeq = DiffEqArray(recs, t) @test_deprecated diffeq[1] @@ -205,7 +205,7 @@ for i in 1:2:5 end testva[CartesianIndex(3, 3, 5)] = 64.0 @test testva[:, 5][3, 3] == 64.0 -@test_throws ArgumentError testva[2, 1:2, :] = 108.0 +@test_throws ArgumentError testva[2, 1:2, :]=108.0 testva[2, 1:2, :] .= 108.0 for i in 1:5 @test all(testva[:, i][2, 1:2] .== 108.0) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index bb47337f..e8790a09 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -60,7 +60,7 @@ push!(testda, [-1, -2, -3, -4]) # Type inference @inferred sum(testva) -@inferred sum(VectorOfArray([VectorOfArray([zeros(4,4)])])) +@inferred sum(VectorOfArray([VectorOfArray([zeros(4, 4)])])) @inferred mapreduce(string, *, testva) # mapreduce @@ -70,9 +70,9 @@ testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) testvb = VectorOfArray([rand(1:10, 3, 3, 3) for _ in 1:4]) arrvb = Array(testvb) for i in 1:ndims(arrvb) - @test sum(arrvb; dims=i) == sum(testvb; dims=i) - @test prod(arrvb; dims=i) == prod(testvb; dims=i) - @test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i) + @test sum(arrvb; dims = i) == sum(testvb; dims = i) + @test prod(arrvb; dims = i) == prod(testvb; dims = i) + @test mapreduce(string, *, arrvb; dims = i) == mapreduce(string, *, testvb; dims = i) end # Test when ndims == 1 @@ -85,7 +85,8 @@ arrvb = Array(testvb) # view testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3]) arrvc = Array(testvc) -for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :), (1:2, 1:2, Bool[1, 0, 1]), (1:2, Bool[1, 0, 1], 1:2), (Bool[1, 0, 1], 1:2, 1:2)] +for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, :, :), (:, :, :), + (1:2, 1:2, Bool[1, 0, 1]), (1:2, Bool[1, 0, 1], 1:2), (Bool[1, 0, 1], 1:2, 1:2)] arr_view = view(arrvc, idxs...) voa_view = view(testvc, idxs...) @test size(arr_view) == size(voa_view) @@ -110,8 +111,9 @@ end @test stack(testva) == [1 4 7; 2 5 8; 3 6 9] @test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9] -testva = VectorOfArray([VectorOfArray([ones(2,2), 2ones(2, 2)]), 3ones(2, 2, 2)]) -@test stack(testva) == [1.0 1.0; 1.0 1.0;;; 2.0 2.0; 2.0 2.0;;;; 3.0 3.0; 3.0 3.0;;; 3.0 3.0; 3.0 3.0] +testva = VectorOfArray([VectorOfArray([ones(2, 2), 2ones(2, 2)]), 3ones(2, 2, 2)]) +@test stack(testva) == + [1.0 1.0; 1.0 1.0;;; 2.0 2.0; 2.0 2.0;;;; 3.0 3.0; 3.0 3.0;;; 3.0 3.0; 3.0 3.0] # convert array from VectorOfArray/DiffEqArray t = 1:8 @@ -164,12 +166,13 @@ testva = VectorOfArray([ VectorOfArray([3ones(3), 4ones(3)]) ]), DiffEqArray([ - 5ones(3, 2), - VectorOfArray([ - 6ones(3), - 7ones(3), - ]), - ], [0.1, 0.2], [100.0, 200.0], SymbolCache([:x, :y], [:a, :b], :t)) + 5ones(3, 2), + VectorOfArray([ + 6ones(3), + 7ones(3) + ]) + ], [0.1, 0.2], + [100.0, 200.0], SymbolCache([:x, :y], [:a, :b], :t)) ]) arr = rand(3, 2, 2, 3) copyto!(testva, arr) @@ -208,7 +211,7 @@ DA = DiffEqArray(map(i -> rand(2, 4), 1:7), 1:7) u = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})]) @test typeof(zero(u)) <: typeof(u) -resize!(u,3) +resize!(u, 3) @test pointer(u) === pointer(u.u) # Ensure broadcast (including assignment) works with StaticArrays @@ -223,11 +226,11 @@ u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})]) u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) u3 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) -function f(u1,u2,u3) +function f(u1, u2, u3) u3 .= u1 .+ u2 end -f(u1,u2,u3) -@test (@allocated f(u1,u2,u3)) == 0 +f(u1, u2, u3) +@test (@allocated f(u1, u2, u3)) == 0 yy = [2.0 1.0; 2.0 1.0] zz = x .+ yy @@ -237,11 +240,11 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) z .= zz @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) -function f!(z,zz) +function f!(z, zz) z .= zz end -f!(z,zz) -@test (@allocated f!(z,zz)) == 0 +f!(z, zz) +@test (@allocated f!(z, zz)) == 0 z .= 0.1 @test z == VectorOfArray([fill(0.1, SVector{2, Float64}), fill(0.1, SVector{2, Float64})]) @@ -253,7 +256,7 @@ f2!(z) @test (@allocated f2!(z)) == 0 function f3!(z, zz) - @.. broadcast=false z = zz + @.. broadcast=false z=zz end f3!(z, zz) @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index e823246a..d8164edf 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -3,20 +3,20 @@ using RecursiveArrayTools, Test @testset "NamedArrayPartition tests" begin x = NamedArrayPartition(a = ones(10), b = rand(20)) @test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition - @test typeof(x.^2) <: NamedArrayPartition + @test typeof(x .^ 2) <: NamedArrayPartition @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence - @test all(x .== x[1:end]) + @test all(x .== x[1:end]) y = copy(x) @test zero(x, (10, 20)) == zero(x) # test that ignoring dims works @test typeof(zero(x)) <: NamedArrayPartition @test (y .*= 2).a[1] ≈ 2 # test in-place bcast - @test length(Array(x))==30 + @test length(Array(x)) == 30 @test typeof(Array(x)) <: Array @test propertynames(x) == (:a, :b) - x = NamedArrayPartition(a = ones(1), b = 2*ones(1)) + x = NamedArrayPartition(a = ones(1), b = 2 * ones(1)) @test Base.summary(x) == string(typeof(x), " with arrays:") io = IOBuffer() Base.show(io, MIME"text/plain"(), x) @@ -24,11 +24,10 @@ using RecursiveArrayTools, Test using StructArrays using StaticArrays: SVector - x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))), - b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2)))) + x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2 * ones(5))), + b = StructArray{SVector{2, Float64}}((3 * ones(2, 2), 4 * ones(2, 2)))) @test typeof(x.a) <: StructVector{<:SVector{2}} @test typeof(x.b) <: StructArray{<:SVector{2}, 2} - @test typeof((x->x[1]).(x)) <: NamedArrayPartition - @test typeof(map(x->x[1], x)) <: NamedArrayPartition + @test typeof((x -> x[1]).(x)) <: NamedArrayPartition + @test typeof(map(x -> x[1], x)) <: NamedArrayPartition end - diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 15c31655..5f46b34b 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -126,9 +126,9 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0])) # mapreduce @inferred Union{Int, Float64} sum(x) -@inferred sum(ArrayPartition(ArrayPartition(zeros(4,4)))) +@inferred sum(ArrayPartition(ArrayPartition(zeros(4, 4)))) @inferred sum(ArrayPartition(ArrayPartition(zeros(4)))) -@inferred sum(ArrayPartition(zeros(4,4))) +@inferred sum(ArrayPartition(zeros(4, 4))) @inferred mapreduce(string, *, x) @test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q" @@ -150,7 +150,7 @@ _broadcast_wrapper(y) = _scalar_op.(y) S = [ ((1,), (2,)) => ((1,), (2,)), ((3, 2), (2,)) => ((3, 2), (2,)), - ((3, 2), (2,)) => ((3,), (3,), (2,)), + ((3, 2), (2,)) => ((3,), (3,), (2,)) ] for sizes in S @@ -194,7 +194,7 @@ Base.IndexStyle(::MyType) = IndexLinear() Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.ArrayStyle{MyType}() function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}}, - ::Type{T}) where {T} + ::Type{T}) where {T} similar(find_mt(bc), T) end @@ -224,19 +224,20 @@ up = ap .+ 1 up = 2 .* ap .+ 1 @test typeof(ap) == typeof(up) -@testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1, +@testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ( + (1, 2, false), ([ - 1, + 1 ], 2, false), ([ - 1, + 1 ], [ - 2, + 2 ], true)) @test ArrayInterface.ismutable(ArrayPartition(a, b)) == r diff --git a/test/runtests.jl b/test/runtests.jl index 967933d1..69fda8b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,12 +25,12 @@ end @time @safetestset "Utils Tests" begin include("utils_test.jl") end - @time @safetestset "NamedArrayPartition Tests" begin - include("named_array_partition_tests.jl") + @time @safetestset "NamedArrayPartition Tests" begin + include("named_array_partition_tests.jl") end @time @safetestset "Partitions Tests" begin include("partitions_test.jl") - end + end @time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end @@ -52,7 +52,9 @@ end @time @safetestset "Upstream Tests" begin include("upstream.jl") end - @time @safetestset "Adjoint Tests" begin include("adjoints.jl") end + @time @safetestset "Adjoint Tests" begin + include("adjoints.jl") + end @time @safetestset "Measurement Tests" begin include("measurements.jl") end diff --git a/test/symbolic_indexing_interface_test.jl b/test/symbolic_indexing_interface_test.jl index a2a1ae30..74ab04a6 100644 --- a/test/symbolic_indexing_interface_test.jl +++ b/test/symbolic_indexing_interface_test.jl @@ -42,7 +42,8 @@ get_tuple = getu(dx, (:a, :b)) @test variable_index.((dx,), [:a, :b, :p, :q, :t]) == [1, 2, nothing, nothing, nothing] @test is_parameter.((dx,), [:a, :b, :p, :q, :t]) == [false, false, true, true, false] @test parameter_index.((dx,), [:a, :b, :p, :q, :t]) == [nothing, nothing, 1, 2, nothing] -@test is_independent_variable.((dx,), [:a, :b, :p, :q, :t]) == [false, false, false, false, true] +@test is_independent_variable.((dx,), [:a, :b, :p, :q, :t]) == + [false, false, false, false, true] @test variable_symbols(dx) == all_variable_symbols(dx) == [:a, :b] @test parameter_symbols(dx) == [:p, :q] @test independent_variable_symbols(dx) == [:t] @@ -57,4 +58,3 @@ ABC = @SLVector (:a, :b, :c); A = ABC(1, 2, 3); B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]); @test getindex(B, :a) == [1, 1] - diff --git a/test/testutils.jl b/test/testutils.jl index 6917b9c7..dd77f38f 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -4,7 +4,7 @@ using RecursiveArrayTools: Tables, IteratorInterfaceExtensions # Test Tables interface with row access + IteratorInterfaceExtensions for QueryVerse # (see https://tables.juliadata.org/stable/#Testing-Tables.jl-Implementations) function test_tables_interface(x::AbstractDiffEqArray, names::Vector{Symbol}, - values::Matrix) + values::Matrix) @assert length(names) == size(values, 2) # AbstractDiffEqArray is a table with row access diff --git a/test/upstream.jl b/test/upstream.jl index 37fe0341..31f27a05 100644 --- a/test/upstream.jl +++ b/test/upstream.jl @@ -36,12 +36,16 @@ function dyn(u, p, t) ArrayPartition(zeros(1), [0.0])) end -@test solve(ODEProblem(dyn, +@test solve( + ODEProblem(dyn, ArrayPartition(ArrayPartition(zeros(1), [-1.0]), ArrayPartition(zeros(1), [0.75])), - (0.0, 1.0)), AutoTsit5(Rodas5())).retcode == ReturnCode.Success + (0.0, 1.0)), + AutoTsit5(Rodas5())).retcode == ReturnCode.Success -@test_broken solve(ODEProblem(dyn, +@test_broken solve( + ODEProblem(dyn, ArrayPartition(ArrayPartition(zeros(1), [-1.0]), ArrayPartition(zeros(1), [0.75])), - (0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success + (0.0, 1.0)), + Rodas5()).retcode == ReturnCode.Success diff --git a/test/utils_test.jl b/test/utils_test.jl index 8bda7442..5ccda25a 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -9,7 +9,7 @@ data = convert(Array, randomized) ## Test means A = [[1 2; 3 4], [1 3; 4 6], [5 6; 7 8]] @test recursive_mean(A) ≈ [2.33333333 3.666666666 - 4.6666666666 6.0] + 4.6666666666 6.0] A = zeros(5, 5) @test recursive_unitless_eltype(A) == Float64 @@ -124,17 +124,17 @@ end @testset "VectorOfArray recursivecopy!" begin u1 = VectorOfArray([fill(2, MVector{2, Float64}), ones(MVector{2, Float64})]) u2 = VectorOfArray([fill(4, MVector{2, Float64}), 2 .* ones(MVector{2, Float64})]) - recursivecopy!(u1,u2) - @test u1.u[1] == [4.0,4.0] - @test u1.u[2] == [2.0,2.0] + recursivecopy!(u1, u2) + @test u1.u[1] == [4.0, 4.0] + @test u1.u[2] == [2.0, 2.0] @test u1.u[1] isa MVector @test u1.u[2] isa MVector u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})]) u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) - recursivecopy!(u1,u2) - @test u1.u[1] == [4.0,4.0] - @test u1.u[2] == [2.0,2.0] + recursivecopy!(u1, u2) + @test u1.u[1] == [4.0, 4.0] + @test u1.u[2] == [2.0, 2.0] @test u1.u[1] isa SVector @test u1.u[2] isa SVector end