Skip to content

Commit

Permalink
fix: fix view adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 30, 2024
1 parent b75de6e commit fc037f9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
24 changes: 15 additions & 9 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end

@adjoint function Base.copy(u::VectorOfArray)
copy(u),
y -> (copy(y),)
tuple copy
end

@adjoint function DiffEqArray(u, t)
Expand Down Expand Up @@ -123,18 +123,24 @@ end
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
copyto!(A, y)
(A, map(_ -> nothing, I)...)

Check warning on line 130 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L126-L130

Added lines #L126 - L130 were not covered by tests
end
end
return view(A, I...), adjoint
return view(A, I...), view_adjoint

Check warning on line 133 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L133

Added line #L133 was not covered by tests
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
function view_adjoint(y)
A = recursivecopy(parent(y))
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)

Check warning on line 142 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L137-L142

Added lines #L137 - L142 were not covered by tests
end
end
view(A, I...), view_adjoint
end
Expand Down
12 changes: 12 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ function loss9(x)
return VectorOfArray([collect((3i):(3i + 3)) .* x for i in 1:5])
end

function loss10(x)
voa = VectorOfArray([i * x for i in 1:5])
return sum(view(voa, 2:4, 3:5))
end

function loss11(x)
voa = VectorOfArray([i * x for i in 1:5])
return sum(view(voa, :, :))
end

x = float.(6:10)
loss(x)
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
Expand All @@ -78,3 +88,5 @@ loss(x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
@test ForwardDiff.derivative(loss9, 0.0) ==
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)

0 comments on commit fc037f9

Please sign in to comment.