diff --git a/src/lib/array.jl b/src/lib/array.jl index fe0438b5a..6d914d272 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -170,11 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # So we keep axes(x) to restore gradient dx to its full length & correct shape. _tryaxes(x) = axes(x) _tryaxes(x::Tuple) = Val(length(x)) -_tryaxes(::Number) = Val(-1) +_tryaxes(x::Number) = x _restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax) _restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(map(length, ax)) - length(dx))), ax) _restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N) -_restore(dx, ::Val{-1}) = only(dx) +_restore(dx, ::Number) = only(dx) # Sometimes a pullback doesn't return a Tuple, but rather returns only a # single nothing to say "all arguments have zero cotangent". This function is needed to