Skip to content

Commit

Permalink
Merge pull request #1350 from marius311/buffer_widening
Browse files Browse the repository at this point in the history
fix forward-over-reverse with Buffer in some cases
  • Loading branch information
ToucheSir authored Jan 8, 2023
2 parents c2fb256 + 7beb46d commit 82a12c9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
12 changes: 7 additions & 5 deletions src/lib/buffer.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing)
grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0)
grad_mut(cx::Context, b::Buffer, ::Type=Union{}) =
_get!(() -> fill!(similar(b.data, Any), nothing), cache(cx), b)
grad_mut(cx::Context, b::Buffer{T}, ::Type{S}=Union{}) where {T<:Number, S} =
_get!(() -> fill!(similar(b.data, float(promote_type(T, S))), 0), cache(cx), b)

@non_differentiable Buffer(::Any...)

@adjoint function getindex(b::Buffer, i...)
b[i...], function)
grad = grad_mut(__context__, b)
b[i...], function::S) where {S}
grad = grad_mut(__context__, b, S)
grad[i...] = accum(grad[i...], Δ)
return
end
Expand Down Expand Up @@ -48,7 +50,7 @@ _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::Abstract
res = copy(b)

function copy_sensitivity(b̄)
grad_mut(__context__, b)[:] .= vec(b̄)
grad_mut(__context__, b, eltype(b̄))[:] .= vec(b̄)
return
end

Expand Down
9 changes: 6 additions & 3 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,18 @@ end

grad_mut(x) = Ref{Any}(nt_nothing(x))

function grad_mut(cx::Context, x)
ch = cache(cx)
grad_mut(cx::Context, x) = _get!(() -> grad_mut(x), cache(cx), x)

# needed for reverse-over-reverse pending rrule for Base.get!
function _get!(default::Base.Callable, ch, x)
if haskey(ch, x)
ch[x]
else
ch[x] = grad_mut(x)
ch[x] = default()
end
end


@adjoint! function setfield!(x, f, val)
y = setfield!(x, f, val)
g = grad_mut(__context__, x)
Expand Down
12 changes: 12 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,18 @@ using Zygote: Buffer
prod(copy(b))
end == (3,)

# backwards pass Buffer widening (#1349)
@test Zygote.hessian(1.) do A
buf = Zygote.Buffer([0, 0])
buf[:] = [1, 2]
sum(A^2 .* copy(buf))
end == 6
@test Zygote.hessian(1.) do A
buf = Zygote.Buffer([0, 0])
buf[1] = 1
A^2 * buf[1]
end == 2

# Buffer storing arrays test
W1 = ones(3, 3)
W2 = ones(3, 3)
Expand Down

0 comments on commit 82a12c9

Please sign in to comment.