Skip to content

Commit

Permalink
rename to _project_grad and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Mar 24, 2024
1 parent 6c5b17b commit ccae706
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is no
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")

# Preserves output as tuple when gradients are collapsed
_project_sentinel(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
_project_sentinel(x::Tuple, dx::Tuple) = map(_project_sentinel, x, dx)
_project_sentinel(::Any, ::NoTangent) = nothing
_project_sentinel(::Any, ::ZeroTangent) = nothing
_project_sentinel(::Any, ::Nothing) = nothing
_project_sentinel(::Any, dx::Any) = dx
_project_grad(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
_project_grad(x::Tuple, dx::Tuple) = map(_project_grad, x, dx)
_project_grad(::Any, ::NoTangent) = nothing
_project_grad(::Any, ::ZeroTangent) = nothing
_project_grad(::Any, ::Nothing) = nothing
_project_grad(::Any, dx::Any) = dx
_project_grad(x::AbstractArray, dx::Tuple) = _project(x, dx)
_project_grad(x::Any, dx::Base.RefValue) = _project(x, dx)

"""
gradient(f, args...)
Expand Down Expand Up @@ -150,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
return _project_sentinel(args, grad)
return _project_grad(args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Expand Down Expand Up @@ -216,7 +218,7 @@ function withgradient(f, args...)
else
back(sensitivity(y))
end
results = _project_sentinel(args, grad)
results = _project_grad(args, grad)
(val=y, grad=results)
end

Expand Down Expand Up @@ -477,7 +479,7 @@ function pullback(f, ps::Params)
end

# No conversion required here
_project_sentinel(_, dx::Grads) = dx
_project_grad(_, dx::Grads) = dx

# Code Reflection

Expand Down

0 comments on commit ccae706

Please sign in to comment.