diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 9b253a4f..e668676f 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -88,7 +88,7 @@ end @adjoint function Base.copy(u::VectorOfArray) copy(u), - y -> (copy(y),) + tuple ∘ copy end @adjoint function DiffEqArray(u, t) @@ -115,7 +115,7 @@ end adj = let VA = VA function Array_adjoint(y) VA = recursivecopy(VA) - copyto!(VA, y) + VA .= y return (VA,) end end @@ -123,18 +123,25 @@ end end @adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...) - function adjoint(y) - (recursivecopy(parent(y)), map(_ -> nothing, I)...) + view_adjoint = let A = A, I = I + function (y) + A = recursivecopy(A) + A .= y + (A, map(_ -> nothing, I)...) + end end - return view(A, I...), adjoint + return view(A, I...), view_adjoint end @adjoint function Base.view(A::AbstractVectorOfArray, I...) - function view_adjoint(y) - A = recursivecopy(parent(y)) - recursivefill!(A, zero(eltype(A))) - A[I...] .= y - return (A, map(_ -> nothing, I)...) + view_adjoint = let A = A, I = I + function (y) + A = recursivecopy(A) + recursivefill!(A, zero(eltype(A))) + v = view(A, I...) + v .= y + return (A, map(_ -> nothing, I)...) + end end view(A, I...), view_adjoint end diff --git a/test/adjoints.jl b/test/adjoints.jl index 1e5ee3c3..e5a1fc50 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -66,6 +66,16 @@ function loss9(x) return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5]) end +function loss10(x) + voa = VectorOfArray([i * x for i in 1:5]) + return sum(view(voa, 2:4, 3:5)) +end + +function loss11(x) + voa = VectorOfArray([i * x for i in 1:5]) + return sum(view(voa, :, :)) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x) @@ -78,3 +88,5 @@ loss(x) @test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x) @test ForwardDiff.derivative(loss9, 0.0) == VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) +@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x) +@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)