diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 47a718a8c..cd5698a3c 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -76,6 +76,11 @@ Numeric{T<:Number} = Union{T,AbstractArray{<:T}} res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y)) end +@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p + y = Base.literal_pow.(^, x, exp) + y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing) +end + @adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ) @adjoint function broadcasted(::typeof(σ), x::Numeric)