Skip to content

Commit

Permalink
Fix similar for VectorOfArray
Browse files Browse the repository at this point in the history
Signed-off-by: ErikQQY <[email protected]>
  • Loading branch information
ErikQQY committed Sep 11, 2024
1 parent 8e727d1 commit 1af01c9
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 34 deletions.
3 changes: 2 additions & 1 deletion ext/RecursiveArrayToolsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ end
return Array(VA), Array_adjoint
end

@adjoint function Base.view(A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
@adjoint function Base.view(
A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
Expand Down
5 changes: 3 additions & 2 deletions ext/RecursiveArrayToolsSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module RecursiveArrayToolsSparseArraysExt
import SparseArrays
import RecursiveArrayTools

function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition)
function Base.copyto!(
dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition)
@assert length(dest) == length(A)
cur = 1
@inbounds for i in 1:length(A.x)
Expand All @@ -17,4 +18,4 @@ function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveA
dest
end

end
end
1 change: 0 additions & 1 deletion ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ end
view(A, I...), view_adjoint
end


@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,
y::Union{Zygote.Numeric, AbstractVectorOfArray})
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
Expand Down
60 changes: 36 additions & 24 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ returns a vector of the series for each component, that is, `A[i,:]` for each `i
A plot recipe is provided, which plots the `A[i,:]` series.
There is also support for `VectorOfArray` constructed from multi-dimensional arrays
```julia
VectorOfArray(u::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
```
Expand Down Expand Up @@ -60,8 +61,9 @@ A[1, :] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
AbstractDiffEqArray{T, N, A}
mutable struct DiffEqArray{
T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
p::F
Expand Down Expand Up @@ -177,7 +179,9 @@ function DiffEqArray(vec::AbstractVector{T},
::NTuple{N, Int},
p = nothing,
sys = nothing; discretes = nothing) where {T, N}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
DiffEqArray{
eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(
vec,
ts,
p,
sys,
Expand All @@ -197,7 +201,8 @@ end
function DiffEqArray(vec::AbstractVector{VT},
ts::AbstractVector,
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
DiffEqArray{
eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
ts,
p,
nothing,
Expand Down Expand Up @@ -253,7 +258,7 @@ function DiffEqArray(vec::AbstractVector{VT},
typeof(ts),
typeof(p),
typeof(sys),
typeof(discretes),
typeof(discretes)
}(vec,
ts,
p,
Expand Down Expand Up @@ -375,19 +380,23 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
end

struct ParameterIndexingError <: Exception
sym
sym::Any
end

function Base.showerror(io::IO, pie::ParameterIndexingError)
print(io, "Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
print(io,
"Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
end

# Symbolic Indexing Methods
for (symtype, elsymtype, valtype, errcheck) in [
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any,
:(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any,
:(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait,
Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym)))
]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg...)
Expand All @@ -413,8 +422,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
elsymtype = symbolic_type(eltype(_arg))

if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
if _arg isa Union{Tuple, AbstractArray} &&
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
else
_getindex(A, symtype, _arg, args...)
end
Expand Down Expand Up @@ -707,30 +717,32 @@ end

# Tools for creating similar objects
Base.eltype(::Type{<:AbstractVectorOfArray{T}}) where {T} = T
# TODO: Is there a better way to do this?

@inline function Base.similar(VA::AbstractVectorOfArray, args...)
if args[end] isa Type
return Base.similar(eltype(VA)[], args..., size(VA))
return Base.similar(VA.u, args..., size(VA))
else
return Base.similar(eltype(VA)[], args...)
return Base.similar(VA.u, args...)
end
end
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
end

# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
return VectorOfArray(similar.(Base.parent(vec)))
end

# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}}
return Base.similar(vec, eltype(vec))
@inline function Base.similar(VA::VectorOfArray, ::Type{T}) where {T}
VectorOfArray(similar(VA.u, T))
end

@inline function Base.similar(VA::VectorOfArray, dims::N) where {N}
VectorOfArray(similar(VA.u, dims))
end

@inline function Base.similar(VA::VectorOfArray{T, N, AT},
dims::Tuple) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
VectorOfArray(similar(VA.u, dims))
end

# fill!
# For DiffEqArray it ignores ts and fills only u
Expand Down
4 changes: 2 additions & 2 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ function foo!(u)
end
foo!(u_matrix)
foo!(u_vector)
@test all(u_matrix .== [3, 10])
@test all(u_matrix .== [3, 10])
@test all(vec(u_matrix) .≈ vec(u_vector))

# test that, for VectorOfArray with multi-dimensional parent arrays,
# broadcast and `similar` preserve the structure of the parent array
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x->x).(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
num_allocs = @allocations foo!(u_matrix)
Expand Down
4 changes: 4 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ testva2 = similar(testva)
@test typeof(testva2) == typeof(testva)
@test size(testva2) == size(testva)

testva3 = similar(testva, 10)
@test typeof(testva3) == typeof(testva)
@test length(testva3) == 10

# Fill AbstractVectorOfArray and check all
testval = 3.0
fill!(testva2, testval)
Expand Down
7 changes: 3 additions & 4 deletions test/partitions_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
@test all(isnan, ArrayPartition([NaN], [NaN]))
@test all(isnan, ArrayPartition([NaN], ArrayPartition([NaN])))


# broadcasting
_scalar_op(y) = y + 1
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
Expand Down Expand Up @@ -303,7 +302,7 @@ end
end

@testset "Scalar copyto!" begin
u = [2.0,1.0]
copyto!(u, ArrayPartition(1.0,-1.2))
@test u == [1.0,-1.2]
u = [2.0, 1.0]
copyto!(u, ArrayPartition(1.0, -1.2))
@test u == [1.0, -1.2]
end

0 comments on commit 1af01c9

Please sign in to comment.