diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 47a718a8c..cd5698a3c 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -76,6 +76,11 @@ Numeric{T<:Number} = Union{T,AbstractArray{<:T}}
   res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y))
 end
 
+@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
+  y = Base.literal_pow.(^, x, exp)
+  y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
+end
+
 @adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ)
 
 @adjoint function broadcasted(::typeof(σ), x::Numeric)
diff --git a/src/lib/number.jl b/src/lib/number.jl
index 8e293b49a..e28139a3a 100644
--- a/src/lib/number.jl
+++ b/src/lib/number.jl
@@ -22,6 +22,7 @@ end
 
 for (M, f, arity) in DiffRules.diffrules()
   arity == 2 || continue
+  f == :^ && continue
   da, db = DiffRules.diffrule(M, f, :a, :b)
   @eval begin
     @adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b),
@@ -29,6 +30,12 @@ for (M, f, arity) in DiffRules.diffrules()
   end
 end
 
+@adjoint Base.:^(x::Number, p::Number) = x^p,
+  Δ -> (Δ * conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x))))
+@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
+  Base.literal_pow(^, x, Val(p)),
+  Δ -> (nothing, Δ * conj(p * Base.literal_pow(^, x, Val(p-1))), nothing)
+
 @adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ)
 
 @adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index cb1cbc882..40f44f5d8 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -42,6 +42,22 @@ Random.seed!(0)
 
 @test gradient(//, 2, 3) === (1//3, -2//9)
 
+@testset "power" begin
+  @test gradient(x -> x^2, -2) == (-4,)
+  @test gradient(x -> x^10, -1.0) == (-10,) # literal_pow
+  pow = 10
+  @test gradient(x -> x^pow, -1.0) == (-pow,)
+  @test gradient(p -> real(2^p), 2)[1] ≈ 4*log(2)
+
+  @test gradient(xs -> sum(xs .^ 2), [2, -1]) == ([4, -2],)
+  @test gradient(xs -> sum(xs .^ 10), [3, -1]) == ([10*3^9, -10],)
+  @test gradient(xs -> sum(xs .^ pow), [4, -1]) == ([pow*4^9, -10],)
+
+  @test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
+  @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
+  # D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate
+end
+
 @test gradtest((a,b)->sum(reim(acosh(complex(a[1], b[1])))), [-2.0], [1.0])
 
 @test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2)
@@ -63,6 +79,7 @@ Random.seed!(0)
 @test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
 @test gradtest(x -> sum(x[i] for i in 1:length(x)), randn(10))
 @test_broken gradtest(x -> sum(i->x[i], 1:length(x)), randn(10)) # https://github.com/FluxML/Zygote.jl/issues/231
+@test gradtest(x -> sum((i->x[i]).(1:length(x))), randn(10))
 
 @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
 @test gradtest(x -> prod(x), (3,4))
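
Quick sanity check (illustration only, not part of the patch; assumes Zygote with
this change applied — the expected values mirror the new tests above):

    using Zygote

    # Before this patch, the generic pullback generated from DiffRules evaluated
    # the exponent partial x^p * log(x) even when only the base gradient was
    # needed, and log(x) throws a DomainError for negative real x.
    gradient(x -> x^2, -2)        # (-4,)    -- literal exponent, via the Base.literal_pow adjoint
    pow = 10
    gradient(x -> x^pow, -1.0)    # (-10.0,) -- runtime exponent, via the new Base.:^ adjoint

    # The exponent partial is now defined through log(complex(x)), so it can
    # come back as a complex number with zero imaginary part:
    gradient(p -> real(2^p), 2)   # first element ≈ 4*log(2)

The conj(...) wrappers follow Zygote's convention for complex gradients; the last
new test checks the exponent gradient against a conjugated Mathematica derivative
(see the `// Conjugate` comment in the testset).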