Skip to content

Commit

Permalink
fix: view adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 22, 2024
1 parent cc55ee6 commit 2af86b7
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,19 @@ end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(A), map(_ -> nothing, I)...)

Check warning on line 123 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L121-L123

Added lines #L121 - L123 were not covered by tests
end
return view(A, I...), adjoint

Check warning on line 125 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L125

Added line #L125 was not covered by tests
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
A = recursivecopy(A)
recursivefill!(A, zero(eltype(A)))
A[I...] .= y

Check warning on line 133 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L131-L133

Added lines #L131 - L133 were not covered by tests
return (A, map(_ -> nothing, I)...)
end
end
Expand Down

0 comments on commit 2af86b7

Please sign in to comment.