Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding complex broadcasting for gradients on the GPU #1324

Merged
merged 25 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -34,7 +34,7 @@ ChainRulesTestUtils = "1"
DiffRules = "1.4"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
ForwardDiff = "0.10"
GPUArrays = "8.4.2" # not loaded, just a version bound
GPUArrays = "8.4.2"
GPUArraysCore = "0.1.1"
IRTools = "0.4.4"
LogExpFunctions = "0.3.1"
Expand Down
101 changes: 89 additions & 12 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ end
@adjoint broadcasted(::typeof(imag), x::Numeric) =
imag.(x), z̄ -> (nothing, im .* real.(z̄))

@adjoint broadcasted(::typeof(abs2), x::Numeric) =
abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)

@adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool)
y = b === false ? a : a .+ b
y, Δ -> (nothing, Δ, nothing)
Expand Down Expand Up @@ -190,7 +193,7 @@ _dual_safearg(x) = false
# Avoid generic broadcasting in two easy cases:
if T == Bool
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
return broadcast_forward(f, args...)
end
len = inclen(args)
Expand Down Expand Up @@ -232,23 +235,44 @@ end
import ForwardDiff
using ForwardDiff: Dual

dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
dual(x::Bool, p) = x

# We do this because it ensures type stability so it compiles nicely on the gpu
dual(x, i, N) = x
dual(x::Bool, i, ::Val{N}) where {N} = x
ptiede marked this conversation as resolved.
Show resolved Hide resolved
dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(j-> i==j, Val(N)))
# For complex since ForwardDiff.jl doesn't play nicely with complex numbers we
# construct a Complex dual number and tag the real and imaginary parts separately
function dual(x::Complex, i, ::Val{N}) where {N}
re_dual = Dual(real(x), ntuple(j->i==j, Val(2N)))
im_dual = Dual(imag(x), ntuple(j->(N+i)==j, Val(2N)))
ptiede marked this conversation as resolved.
Show resolved Hide resolved
return Complex(re_dual, im_dual)
end

function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
ds = map(args, ntuple(identity,Val(N))) do x, i
dual(x, ntuple(j -> i==j, Val(N)))
function (args::Vararg{Any,N}) where N
ds = map(args, ntuple(identity,Val(N))) do x, i
tmp = dual(x, i, Val(N))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
return tmp
end
return f(ds...)
end
return f(ds...)
end
end


@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
valN = Val(N)
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
T = eltype(out)
T <: Union{Dual, Complex} || return (out, _ -> nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be Union{Dual, Dual{<:Complex}}? You'd have to try pretty hard but I think the Complex path expects Dual inside.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought is was the other way around? At least that is what I am constructing in the dual_function. ForwardDiff.jl also defines Dual <: Real so I think defining it the other way would break things. However, I probably want to be a little more specific here and do

Suggested change
T <: Union{Dual, Complex} || return (out, _ -> nothing)
T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry, that's what I was thinking but didn't type...

if any(eltype(a) <: Complex for a in args)
_broadcast_forward_complex(T, out, args...)
else
_broadcast_forward(T, out, args...)
end
end

# Real input and real output
function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> x.value, out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
Expand All @@ -259,6 +283,60 @@ end
return y, bc_fwd_back
end

# This handles complex output and real input
function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex.(real(x).value, imag(x).value), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

# This handles complex input and real output we use the gradient definition from ChainRules here
# since it agrees with what Zygote did for real(x).
function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> x.value, out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N]), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

# # # This is for complex input and complex output
# # # I am a little confused what derivative we want to use here but this should match
# what is done for all the tests

# If we assume that
# f(x + iy) = u(x,y) + iv(x,y)
# then we do the following for the adjoint
# Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y )
# this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html
function _adjoint_complex(Δz, df, i)
Δu, Δv = reim(Δz)
du, dv = reim(df)
return Complex(Δu*du.partials[i] + Δv*dv.partials[i], Δu*du.partials[i+N] + Δv*dv.partials[i+N])
end

function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex.(real(x).value, imag(x).value), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(y1, o1, i), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame

# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
Expand Down Expand Up @@ -287,4 +365,3 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
end

pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]

1 change: 0 additions & 1 deletion test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,3 @@ end
end
@test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
end

60 changes: 39 additions & 21 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
g_gpu = gradient(x -> v(x, 7), a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

w(x) = sum(broadcast(log, x))
g = gradient(x -> w(x), a)[1]
g_gpu = gradient(x -> w(x), a_gpu)[1]
Expand All @@ -38,7 +38,7 @@ end
@test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

# Projection: eltype preservation:
@test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
Expand Down Expand Up @@ -90,40 +90,40 @@ end
@testset "gradient algebra" begin
w, b = rand(2) |> cu, rand(2) |> cu
x1, x2 = rand(2) |> cu, rand(2) |> cu
gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test .- gs1 isa Grads
@test gs1 .- gs2 isa Grads
@test gs1 .- gs2 isa Grads
@test .+ gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test (2 .* gs1)[w] ≈ 2 * gs1[w]
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

gs12 = gs1 .+ gs2
gs1 .+= gs2
@test gs12[w] ≈ gs1[w]
@test gs12[w] ≈ gs1[w]

gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

@test .- gs3 isa Grads
@test gs3 .- gs4 isa Grads
@test gs3 .- gs4 isa Grads
@test .+ gs3 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]
@test (gs3 .+ gs4)[b] ≈ gs4[b]

@test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
@test gs3 isa Grads
@test gs3 isa Grads

@test_throws ArgumentError gs1 .+ gs4
end
Expand All @@ -140,3 +140,21 @@ end
@test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end


@testset "CUDA complex broadcasting" begin
# Issue 961 and 1121 and 1215
x = rand(Float32, 50)
y = complex(rand(Float32, 50))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why define x here at all?

Also, this y has zero imaginary part. rand(ComplexF64, 50) would be a stronger test.

julia> complex(rand(Float32, 50))
50-element Vector{ComplexF32}:
  0.89825445f0 + 0.0f0im
  0.40070343f0 + 0.0f0im
  0.29411656f0 + 0.0f0im
  0.44503874f0 + 0.0f0im

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops! That x was for a test I was doing on my machine. I think overall that the testing could be a bit better though so I've added another test that uses both real and complex arguments. I probably need to add some additional tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. I think x.^2 .*y .+ y uses only functions which have special rules, and ought to work without this PR. I think even broadcasting trivial functions like add(x,y) = x+y will change the path it takes. But messy examples (e.g. with trig, conj/real/imag, in all sorts of ways) are much more likely to expose mistakes like a conj missing somewhere.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to invent some functions, did not try them on GPU:

r3 = Float32.(inv.(2:4))
c3 = ComplexF32.(inv.(5:7) .+ im ./ (8:10))

@test gradient(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)[1] ≈ [0.2077734, 0.15268978, 0.11885023]
@test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im]
@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]

But locally, with this branch, I expected them to use the new code... but adding printing doesn't seem to work?

(jl_S8DfLf) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_S8DfLf/Project.toml`
  [e88e6eb3] Zygote v0.6.49 `https://github.com/ptiede/Zygote.jl#pt-complexbroadcast`

julia> @eval Zygote function dual(x::Complex, i, N)  # from PR, with printing
            @show x
            re_dual = Dual(real(x), ntuple(==(i), 2N))
            im_dual = Dual(imag(x), ntuple(==(N+i), 2N))
            return Complex(re_dual, im_dual)
        end;

julia> Zygote.refresh()

julia> @test gradient(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)[1] ≈ [0.2077734, 0.15268978, 0.11885023]
Test Passed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I looked into this and this occurred because I hadn't added a Complex method for _dual_safearg. When I added this some issues started to appear. One of them was because the partials for the real and complex parts had different lengths.

However, that is not the big issue. The big issue is that certain functions seem to be causing some type instabilities during the evaluation of the dual numbers. For instance,

x = rand(Complex{Float32}, 100)
f(x) = sum(abs2, log.(y))
@code_warntype Zygote.dual_function(f).(x)

MethodInstance for (::var"##dotfunction#314#7")(::Vector{ComplexF32})
  from (::var"##dotfunction#314#7")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#314#7"())
  x1::Vector{ComplexF32}
Body::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
1%1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│   %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│   %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF32}}}%4 = Base.materialize(%3)::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
└──      return %4```

Has a problem where the broadcast can't seem to figure out that eltype of the partial field in Dual should be a Float32. What is really annoying is that this problem does not occur for Float64 where I get

x64 = Complex{Float64}.(x)
@code_warntype Zygote.dual_function(f)(x64)

MethodInstance for (::var"##dotfunction#313#6")(::Vector{ComplexF64})
  from (::var"##dotfunction#313#6")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#313#6"())
  x1::Vector{ComplexF64}
Body::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
1%1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│   %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│   %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF64}}}%4 = Base.materialize(%3)::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
└──      return %4


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok looking into this more. It appears the log with Complex{Dual{Float32}} arguments is type unstable.
My guess is that this occurs because there isn't using the specific forward rule for a complex number for log, or likely any common functions.

Copy link
Member

@mcabbott mcabbott Nov 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is weird, @code_warntype log(Dual(1f0, 1f0) + im) is bad. Inside Base.ssqs, it looks like ldexp(Dual(1f0, 2f0), 3) makes a Float64 dual, by a method from ForwardDiff.

Anyway not this PR's problem! Maybe make an issue on ForwardDiff (or DiffRules) and test inference etc. with other functions here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sounds good! I'll skip log for now and make tests for other functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright I was able to add the last test,

@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2]  [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]

and everything passes! The other two tests suggested both run into the ldexp problem with Float32. I have opened up an issue JuliaDiff/ForwardDiff.jl#604 detailing the problem. The good news is that when I fix the problem locally all the tests pass!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are a couple of updates on my end. First, I just realized I was running the previous test on the CPU. When I run it on the GPU, I get a scalar indexing error. The stack trace is

julia>     @test gradcheck_gpu((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)
Error During Test at /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186
  Test threw exception
  Expression: gradcheck_gpu(((r, c)->begin
            sum(abs2, #= /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186 =# @__dot__(sin(conj(c) / r' - im) - imag(c + tanh(r / c'))))
        end), r3, c3)
  Scalar indexing is disallowed.
  Invocation of getindex resulted in scalar indexing of a GPU array.
  This is typically caused by calling an iterating implementation of a method.
  Such implementations *do not* execute on the GPU, but very slowly on the CPU,
  and therefore are only permitted from the REPL for prototyping purposes.
  If you did intend to index this array, annotate the caller with @allowscalar.
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] assertscalar(op::String)
      @ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
    [3] getindex(::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
      @ GPUArrays ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
    [4] getindex
      @ ~/.julia/juliaup/julia-1.8.2+0.x64/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:180 [inlined]
    [5] _unsafe_getindex_rs
      @ ./reshapedarray.jl:250 [inlined]
    [6] _unsafe_getindex
      @ ./reshapedarray.jl:247 [inlined]
    [7] getindex
      @ ./reshapedarray.jl:235 [inlined]
    [8] iterate
      @ ./abstractarray.jl:1167 [inlined]
    [9] iterate
      @ ./abstractarray.jl:1165 [inlined]
   [10] iterate
      @ ./generator.jl:44 [inlined]
   [11] _collect(c::Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, itr::Base.Generator{Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
      @ Base ./array.jl:807
   [12] collect_similar
      @ ./array.jl:716 [inlined]
   [13] map
      @ ./abstractarray.jl:2933 [inlined]
   [14] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})
      @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:236
   [15] ProjectTo
      @ ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:414 [inlined]
   [16] _project
      @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:184 [inlined]
   [17] unbroadcast(x::LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, x̄::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:58
   [18] (::Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:97
   [19] (::Zygote.var"#3669#back#859"{Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
   [20] Pullback
      @ ./none:0 [inlined]
   [21] (::typeof((#13)))(Δ::Float32)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
   [22] (::Zygote.var"#60#61"{typeof((#13))})(Δ::Float32)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
   [23] gradient(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
   [24] gradcheck_gpu(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
      @ Main ~/.julia/dev/Zygote/test/cuda.jl:9
   [25] top-level scope

From the look of the stack trace, this isn't due to this pull request. In fact, if I change the function definition to

sin(conj(c)/$(transpose(r)) - im) - imag(c + tanh(r/c')))

then everything is fine, so my guess is that this is some funkiness related to the pullback of an adjoint of a real vector. I'll take a look into this, but I am not sure if that's part of this pull request.

Second, I have added some additional tests to ensure we hit every one of the _broadcast_forward branches.


xgpu = cu(x)
ygpu = cu(y)


g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1]
g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1]
g3 = Zygote.gradient(x->sum(abs2, x), y)[1]
@test g1 isa CUDA.CuArray{ComplexF32}
@test g2 isa CUDA.CuArray{ComplexF32}
@test collect(g1) ≈ collect(g2)
@test collect(g1) ≈ g3
end