Skip to content

Commit

Permalink
only project nothings in gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Mar 24, 2024
1 parent c0daccd commit bcbbfab
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is no
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")

# Preserves output as tuple when gradients are collapsed
# Pre-rename helpers (this commit renames them to `_project_nothings`): expand a
# fully collapsed `nothing` gradient into one `nothing` per argument, and
# otherwise project every component gradient back onto its input's type.
_project_all(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)
# A gradient that collapsed to a single `nothing` is expanded back into one
# `nothing` per argument, so the result always has one entry per input.
_project_nothings(x::NTuple{N}, ::Nothing) where {N} = map(_ -> nothing, x)
# Per-component case: only `nothing` entries are run through `_project`;
# concrete gradients are passed through untouched (preserving e.g. sparsity).
function _project_nothings(x::Tuple, dx::Tuple)
    return map((xi, dxi) -> dxi === nothing ? _project(xi, dxi) : dxi, x, dx)
end

"""
gradient(f, args...)
Expand Down Expand Up @@ -146,7 +148,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
"""
    gradient(f, args...)

Evaluate `f(args...)` and return the gradient with respect to each argument,
as a tuple with one entry per argument.

The output of `f` must be scalar; `sensitivity` throws an error for array or
other non-scalar outputs. A gradient that collapsed to `nothing` is expanded
to one `nothing` per argument by `_project_nothings`.
"""
function gradient(f, args...)
    # Forward pass, plus the pullback closure for the reverse pass.
    y, back = pullback(f, args...)
    # Seed the reverse pass with the scalar sensitivity (errors if `y` is
    # not a scalar).
    grad = back(sensitivity(y))
    # NOTE(review): the scraped diff showed both the stale
    # `return _project_all(args, grad)` and this line; only the renamed call
    # belongs in the post-commit code — the duplicate made it unreachable.
    return _project_nothings(args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Expand Down Expand Up @@ -212,7 +214,7 @@ function withgradient(f, args...)
else
back(sensitivity(y))
end
results = _project_all(args, grad)
results = _project_nothings(args, grad)
(val=y, grad=results)
end

Expand Down Expand Up @@ -473,7 +475,7 @@ function pullback(f, ps::Params)
end

# No conversion required here: `Grads` (the gradient container used with
# implicit `Params`) is passed through untouched. Both the pre-rename and
# renamed entry points keep this passthrough so existing call sites work.
_project_all(_, dx::Grads) = dx
_project_nothings(_, dx::Grads) = dx

# Code Reflection

Expand Down
6 changes: 6 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2125,3 +2125,9 @@ end
@test gradient(x -> @.(x * x * x), 2.0) == gradient(x -> x * (x * x), 2.0)
@test gradient(x -> @.(3.0*x*2.0*x), 2.0) == gradient(x -> 6(x^2), 2.0)
end

@testset "Sparse input" begin
    # The gradient of `sum` should be the same whether the input array is
    # dense or sparse.
    dense_grad = Zygote.gradient(sum, zeros(1, 1))[1]
    sparse_grad = Zygote.gradient(sum, spzeros(1, 1))[1]
    @test dense_grad == sparse_grad
end

0 comments on commit bcbbfab

Please sign in to comment.