Skip to content

Commit

Permalink
fix: view adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 22, 2024
1 parent cc55ee6 commit 680e756
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,21 @@ end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(parent(y)), map(_ -> nothing, I)...)

Check warning on line 123 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L121-L123

Added lines #L121 - L123 were not covered by tests
end
return view(A, I...), adjoint

Check warning on line 125 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L125

Added line #L125 was not covered by tests
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
return (A, map(_ -> nothing, I)...)
end
function view_adjoint(y)
A = recursivecopy(parent(y))
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)

Check warning on line 133 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L129-L133

Added lines #L129 - L133 were not covered by tests
end
view(A, I...), adj
view(A, I...), view_adjoint

Check warning on line 135 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L135

Added line #L135 was not covered by tests
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
Expand Down

0 comments on commit 680e756

Please sign in to comment.