diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 5e5239743..dac1221ce 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -114,10 +114,11 @@ sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is no sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") # Preserves output as tuple when gradients are collapsed -_project_nothings(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) -_project_nothings(x::Tuple, dx::Tuple) = map(x, dx) do _x, _dx - return _dx === nothing ? _project(_x, _dx) : _dx -end +_project_sentinel(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) +_project_sentinel(x::Tuple, dx::Tuple) = map(_project_sentinel, x, dx) +_project_sentinel(::Any, ::NoTangent) = nothing +_project_sentinel(::Any, ::ZeroTangent) = nothing +_project_sentinel(::Any, ::Nothing) = nothing """ gradient(f, args...) @@ -148,7 +149,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - return _project_nothings(args, grad) + return _project_sentinel(args, grad) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -214,7 +215,7 @@ function withgradient(f, args...) else back(sensitivity(y)) end - results = _project_nothings(args, grad) + results = _project_sentinel(args, grad) (val=y, grad=results) end @@ -475,7 +476,7 @@ function pullback(f, ps::Params) end # No conversion required here -_project_nothings(_, dx::Grads) = dx +_project_sentinel(_, dx::Grads) = dx # Code Reflection