Skip to content

Commit

Permalink
fixup! fix: fix view adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 30, 2024
1 parent 97edb58 commit e822993
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
version = "3.19.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down Expand Up @@ -36,6 +37,7 @@ RecursiveArrayToolsTrackerExt = "Tracker"
RecursiveArrayToolsZygoteExt = "Zygote"

[compat]
Accessors = "0.1"
Adapt = "3.4, 4"
Aqua = "0.8"
ArrayInterface = "7.6"
Expand Down
16 changes: 8 additions & 8 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module RecursiveArrayToolsZygoteExt

using RecursiveArrayTools
using RecursiveArrayTools.Accessors: @set, @reset

if isdefined(Base, :get_extension)
using Zygote
Expand Down Expand Up @@ -125,9 +126,10 @@ end
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
A .= y
(A, map(_ -> nothing, I)...)
u = collect.(eachslice(y, dims=ndims(y)))
B = @set A.u = u

Check warning on line 130 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L129-L130

Added lines #L129 - L130 were not covered by tests

(B, map(_ -> nothing, I)...)

Check warning on line 132 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L132

Added line #L132 was not covered by tests
end
end
return view(A, I...), view_adjoint
Expand All @@ -136,11 +138,9 @@ end
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
recursivefill!(A, zero(eltype(A)))
v = view(A, I...)
v .= y
return (A, map(_ -> nothing, I)...)
B = @set A .= zero(eltype(A))
@reset B[I...] = y
return (B, map(_ -> nothing, I)...)

Check warning on line 143 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L141-L143

Added lines #L141 - L143 were not covered by tests
end
end
view(A, I...), view_adjoint
Expand Down
1 change: 1 addition & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using DocStringExtensions
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterface, LinearAlgebra
using SymbolicIndexingInterface
import Accessors
using SparseArrays

import Adapt
Expand Down

0 comments on commit e822993

Please sign in to comment.