Skip to content

Commit

Permalink
fix: fix Array(::AbstractVectorOfArray) adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 19, 2024
1 parent f226441 commit cc55ee6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ end
@adjoint function Base.Array(VA::AbstractVectorOfArray)
adj = let VA=VA
function Array_adjoint(y)
VA = copy(VA)
VA = recursivecopy(VA)

Check warning on line 113 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L113

Added line #L113 was not covered by tests
copyto!(VA, y)
return (VA,)
end
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end

function recursivecopy(a::AbstractVectorOfArray)
b = copy(a)
b.u = recursivecopy.(a.u)
b.u .= recursivecopy.(a.u)

Check warning on line 31 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L31

Added line #L31 was not covered by tests
return b
end

Expand Down

0 comments on commit cc55ee6

Please sign in to comment.