From 680e756812252394585c241cad778b7a11211d01 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 22 Jan 2024 11:00:16 +0530 Subject: [PATCH] fix: view adjoints --- ext/RecursiveArrayToolsZygoteExt.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 2c15b6e8..0b75593f 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -118,15 +118,21 @@ end Array(VA), adj end +@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...) + function adjoint(y) + (recursivecopy(parent(y)), map(_ -> nothing, I)...) + end + return view(A, I...), adjoint +end + @adjoint function Base.view(A::AbstractVectorOfArray, I...) - adj = let A = A, I = I - function view_adjoint(y) - A = zero(A) - view(A, I...) .= y - return (A, map(_ -> nothing, I)...) - end + function view_adjoint(y) + A = recursivecopy(parent(y)) + recursivefill!(A, zero(eltype(A))) + A[I...] .= y + return (A, map(_ -> nothing, I)...) end - view(A, I...), adj + view(A, I...), view_adjoint end ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))