Skip to content

Commit

Permalink
Merge pull request #382 from AayushSabharwal/as/fix-adjoints
Browse files Browse the repository at this point in the history
fix: fix `Base.view` adjoints
  • Loading branch information
ChrisRackauckas authored May 30, 2024
2 parents 5e9e8f8 + 97edb58 commit e0b18dc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
27 changes: 17 additions & 10 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end

@adjoint function Base.copy(u::VectorOfArray)
copy(u),
y -> (copy(y),)
tuple copy
end

@adjoint function DiffEqArray(u, t)
Expand All @@ -115,26 +115,33 @@ end
adj = let VA = VA
function Array_adjoint(y)
VA = recursivecopy(VA)
copyto!(VA, y)
VA .= y
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
A .= y
(A, map(_ -> nothing, I)...)
end
end
return view(A, I...), adjoint
return view(A, I...), view_adjoint
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
function view_adjoint(y)
A = recursivecopy(parent(y))
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
recursivefill!(A, zero(eltype(A)))
v = view(A, I...)
v .= y
return (A, map(_ -> nothing, I)...)
end
end
view(A, I...), view_adjoint
end
Expand Down
12 changes: 12 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ function loss9(x)
return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5])
end

function loss10(x)
voa = VectorOfArray([i * x for i in 1:5])
return sum(view(voa, 2:4, 3:5))
end

function loss11(x)
voa = VectorOfArray([i * x for i in 1:5])
return sum(view(voa, :, :))
end

x = float.(6:10)
loss(x)
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
Expand All @@ -78,3 +88,5 @@ loss(x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
@test ForwardDiff.derivative(loss9, 0.0) ==
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)

0 comments on commit e0b18dc

Please sign in to comment.