diff --git a/src/lib/array.jl b/src/lib/array.jl index c23da62d1..c77e91687 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -213,6 +213,11 @@ end end end +function _forward(cx::Context, ::typeof(norm), x::AbstractArray, p::Real = 2) + fallback = (x, p) -> sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 + _forward(cx, fallback, x, p) +end + # LinAlg Matrix Types # ===================