Skip to content

Commit

Permalink
Merge pull request #316 from AayushSabharwal/as/size
Browse files Browse the repository at this point in the history
fix: add support for adjoint of AbstractVectorOfArray
  • Loading branch information
ChrisRackauckas authored Dec 26, 2023
2 parents 2b69321 + cc14f2c commit e4a2044
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
end
end

Base.@propagate_inbounds function Base.getindex(A::Adjoint{T,<:AbstractVectorOfArray}, idxs...) where {T}
return getindex(A.parent, reverse(to_indices(A, idxs))...)
end

function _observed(A::AbstractDiffEqArray{T, N}, sym, i::Int) where {T, N}
observed(A, sym)(A.u[i], A.p, A.t[i])
end
Expand Down Expand Up @@ -395,6 +399,9 @@ end

# Interface for the two-dimensional indexing, a more standard AbstractArray interface
@inline Base.size(VA::AbstractVectorOfArray) = (size(VA.u[1])..., length(VA.u))
@inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i]
@inline Base.size(A::Adjoint{T,<:AbstractVectorOfArray}) where {T} = reverse(size(A.parent))
@inline Base.size(A::Adjoint{T,<:AbstractVectorOfArray}, i) where {T} = size(A)[i]
Base.axes(VA::AbstractVectorOfArray) = Base.OneTo.(size(VA))
Base.axes(VA::AbstractVectorOfArray, d::Int) = Base.OneTo(size(VA)[d])

Expand Down Expand Up @@ -592,6 +599,7 @@ end
@inline Statistics.var(VA::AbstractVectorOfArray; kwargs...) = var(Array(VA); kwargs...)
@inline Statistics.cov(VA::AbstractVectorOfArray; kwargs...) = cov(Array(VA); kwargs...)
@inline Statistics.cor(VA::AbstractVectorOfArray; kwargs...) = cor(Array(VA); kwargs...)
@inline Base.adjoint(VA::AbstractVectorOfArray) = Adjoint(VA)

# make it show just like its data
function Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray)
Expand Down
7 changes: 7 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ for T in (Array{Float64}, Array{ComplexF64})
@test d.x[i] == b.x[i] * c.x[i]
end
end

va = VectorOfArray([i * ones(3) for i in 1:4])
mat = Array(va)

@test size(va') == (size(va', 1), size(va', 2)) == (size(va, 2), size(va, 1))
@test all(va'[i] == mat'[i] for i in eachindex(mat'))
@test Array(va') == mat'

0 comments on commit e4a2044

Please sign in to comment.