diff --git a/src/deprecated.jl b/src/deprecated.jl index 6fe88b5b1..c57fa0d7e 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -65,3 +65,9 @@ macro nograd(ex) end return blk end + +# Internal function used by some downstream packages. +# Removing this completely would require some tricky registry changes, +# but leaving it as a vestigial function is much easier. +# See https://github.com/FluxML/Zygote.jl/pull/1328 for more context. +function ∇getindex end diff --git a/src/lib/array.jl b/src/lib/array.jl index 37884cded..a182c037d 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -41,24 +41,6 @@ end @adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,) @adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,) -@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds) - -@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds) - -∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin - if inds isa NTuple{N,Int} && T <: Number - dx = OneElement(dy, inds, axes(x)) - elseif inds isa NTuple{<:Any, Integer} - dx = _zero(x, typeof(dy)) - dx[inds...] = dy - else - dx = _zero(x, eltype(dy)) - dxv = view(dx, inds...) - dxv .= accum.(dxv, _droplike(dy, dxv)) - end - return (_project(x, dx), map(_->nothing, inds)...) -end - """ OneElement(val, ind, axes) <: AbstractArray diff --git a/test/gradcheck.jl b/test/gradcheck.jl index b7fd5391f..06533ba78 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -174,11 +174,11 @@ end # Ensure that nothings work with numeric types. _, back = Zygote.pullback(getindex, randn(4), [1]) - @test back([nothing]) == (zeros(4), nothing) + @test back([nothing]) === nothing # Ensure that nothings work with non-numeric types. _, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1]) - @test back([nothing]) == (nothing, nothing) + @test back([nothing]) === nothing end @testset "view" begin diff --git a/test/utils.jl b/test/utils.jl index cb11437cf..4e7f4929b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,14 +2,8 @@ using ForwardDiff using Zygote: hessian_dual, hessian_reverse @testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse] - - if hess == hessian_dual - @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] - @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version - else - @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex - @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] - end + @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] + @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) @test hess(sin, pi/2) ≈ -1