From f1e9526721db87ea1d88375dc8ac1cb4676d7169 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 17 Jan 2024 19:32:00 +0530 Subject: [PATCH] fix: fix several adjoints, copy and zero methods for VoA --- ext/RecursiveArrayToolsZygoteExt.jl | 26 ++++++++++++++++---- src/vector_of_array.jl | 38 ++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index c4ddc1a7..c4b4b520 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -50,7 +50,11 @@ end Colon, BitArray, AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) - Δ′[i, j...] = Δ + if isempty(j) + Δ′.u[i] = Δ + else + Δ′[i, j...] = Δ + end (Δ′, nothing, map(_ -> nothing, j)...) end VA[i, j...], AbstractVectorOfArray_getindex_adjoint @@ -104,13 +108,25 @@ end end @adjoint function Base.Array(VA::AbstractVectorOfArray) - Array(VA), - y -> (Array(y),) + adj = let VA=VA + function Array_adjoint(y) + VA = copy(VA) + VA .= y + return (VA,) + end + end + Array(VA), adj end @adjoint function Base.view(A::AbstractVectorOfArray, I...) - view(A, I...), - y -> (view(y, I...), ntuple(_ -> nothing, length(I))...) + adj = let A = A, I = I + function view_adjoint(y) + A = zero(A) + view(A, I...) .= y + return (A, map(_ -> nothing, I)...) + end + end + view(A, I...), adj end ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 16b58279..5f3d47ef 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -160,6 +160,16 @@ function DiffEqArray(vec::AbstractVector{T}, p, sys) end +function DiffEqArray(vec::AbstractVector{VT}, + ts::AbstractVector, + ::NTuple{N, Int}, + p = nothing, + sys = nothing) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec, + ts, + p, + sys) +end # Assume that the first element is representative of all other elements function DiffEqArray(vec::AbstractVector, @@ -466,19 +476,25 @@ end tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u) # Growing the array simply adds to the container vector -function Base.copy(VA::AbstractDiffEqArray) - typeof(VA)(copy(VA.u), - copy(VA.t), - (VA.p === nothing) ? nothing : copy(VA.p), - (VA.sys === nothing) ? nothing : copy(VA.sys)) +function _copyfield(VA, fname) + if fname == :u + copy(VA.u) + elseif fname == :t + copy(VA.t) + else + getfield(VA, fname) + end +end +function Base.copy(VA::AbstractVectorOfArray) + typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...) end -Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u)) - -Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u)) -function Base.zero(VA::AbstractDiffEqArray) - u = Base.zero.(VA.u) - DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA)) +function Base.zero(VA::AbstractVectorOfArray) + val = copy(VA) + for i in eachindex(VA.u) + val.u[i] = zero(VA[i]) + end + return val end Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)