From 036cbf79b21505155bd7858e99136c8d25731428 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 May 2024 17:34:44 +0530 Subject: [PATCH] fix: fix `Base.view` adjoints --- ext/RecursiveArrayToolsReverseDiffExt.jl | 12 ++++++++++++ ext/RecursiveArrayToolsZygoteExt.jl | 8 ++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/ext/RecursiveArrayToolsReverseDiffExt.jl b/ext/RecursiveArrayToolsReverseDiffExt.jl index a35a842b..d551992d 100644 --- a/ext/RecursiveArrayToolsReverseDiffExt.jl +++ b/ext/RecursiveArrayToolsReverseDiffExt.jl @@ -23,4 +23,16 @@ end end return Array(VA), Array_adjoint end + +@adjoint function Base.view(A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N} + view_adjoint = let A = A, I = I + function (y) + A = recursivecopy(A) + trackedarraycopyto!(A, y) + (A, map(_ -> nothing, I)...) + end + end + return view(A, I...), view_adjoint +end + end # module diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index e668676f..415c7f92 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -115,7 +115,7 @@ end adj = let VA = VA function Array_adjoint(y) VA = recursivecopy(VA) - VA .= y + copyto!(VA, y) return (VA,) end end @@ -126,8 +126,8 @@ end view_adjoint = let A = A, I = I function (y) A = recursivecopy(A) - A .= y - (A, map(_ -> nothing, I)...) + copyto!(A, y) + return (A, map(_ -> nothing, I)...) end end return view(A, I...), view_adjoint @@ -139,7 +139,7 @@ end A = recursivecopy(A) recursivefill!(A, zero(eltype(A))) v = view(A, I...) - v .= y + copyto!(v, y) return (A, map(_ -> nothing, I)...) end end