diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 2c15b6e8..2b50a60a 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -118,11 +118,19 @@ end Array(VA), adj end +@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...) + function adjoint(y) + (recursivecopy(A), map(_ -> nothing, I)...) + end + return view(A, I...), adjoint +end + @adjoint function Base.view(A::AbstractVectorOfArray, I...) adj = let A = A, I = I function view_adjoint(y) - A = zero(A) - view(A, I...) .= y + A = recursivecopy(A) + recursivefill!(A, zero(eltype(A))) + A[I...] .= y return (A, map(_ -> nothing, I)...) end end