From cc55ee68c85e8c370da93ed5a87d1f7d1cacb385 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 19 Jan 2024 18:19:22 +0530 Subject: [PATCH] fix: fix `Array(::AbstractVectorOfArray)` adjoint --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- src/utils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index c4611137..2c15b6e8 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -110,7 +110,7 @@ end @adjoint function Base.Array(VA::AbstractVectorOfArray) adj = let VA=VA function Array_adjoint(y) - VA = copy(VA) + VA = recursivecopy(VA) copyto!(VA, y) return (VA,) end diff --git a/src/utils.jl b/src/utils.jl index 658e8418..4af362c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,7 +28,7 @@ end function recursivecopy(a::AbstractVectorOfArray) b = copy(a) - b.u = recursivecopy.(a.u) + b.u .= recursivecopy.(a.u) return b end