# Patch summary. These `#` lines sit before the `diff --git` header and are not part of any hunk.
#
# src/lib/broadcast.jl
#   * `_dual_purefun`: the `::Type{F} where {F<:Function}` method and the `::Type` fallback
#     returning `false` are replaced by one method for any `F` that returns
#     `Base.issingletontype(F)`. The `typeof(^)` method (returns `false`) is unchanged.
#   * `_dual_safearg`: the always-`true` union changes from `Type, Val, Symbol` to
#     `Val, Symbol, Char, AbstractString`, and the `false` fallback becomes
#     `Base.issingletontype(T) || Base.issingletontype(eltype(T))`.
#   * `dual`: the generic `dual(x, p) = x` method is moved below the `Real` and `Bool`
#     methods and both pass-through methods gain comments; the right-hand sides are the same.
#   * `broadcast_forward`: `f` is now annotated `f::F`. Before the existing body it
#       - logs a `@warn` (with `maxlog=1 _id=hash(F)`) when `Base.issingletontype(F)` is false;
#       - loops over `args` and calls `error(...)` for the first one where
#         `_dual_safearg(a)` is false, reporting `typeof(a)`.
#
# test/cuda.jl
#   * Adds `@test_throws ErrorException` for `f1215` on complex arrays piped through `cu`
#     (labelled #1215), a CPU `@test` labelled #1018, and a `@test_skip` for the `cu` variant.
#
# NOTE(review): the hunks show only part of `broadcasted`, `dual_function`,
#   `broadcast_forward` and the enclosing test set; callers of these helpers are not visible here.
# NOTE(review): with `Type` removed from the `_dual_safearg` union, a type argument appears to
#   reach the generic fallback; since `broadcast_forward` now raises on unsafe arguments,
#   confirm that broadcasts taking a type argument are still meant to be accepted.
# NOTE(review): presumably the #1215 test errors because complex arrays do not match
#   `Numeric{<:Real}`; `Numeric` is defined outside this patch, so verify.
# NOTE(review): this copy looks like it has lost most line breaks and indentation; check it
#   against the `@@` hunk line counts before applying. TODO confirm.
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 6dbfdb829..3a322c5fe 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -170,14 +170,13 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...)) collapse_nothings(xs::AbstractArray{Nothing}) = nothing collapse_nothings(xs) = xs -_dual_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F) -_dual_purefun(::Type) = false +_dual_purefun(::Type{F}) where {F} = Base.issingletontype(F) _dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers _dual_safearg(x::Numeric{<:Real}) = true _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true -_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types -_dual_safearg(x) = false +_dual_safearg(x::Union{Val, Symbol, Char, AbstractString}) = true # non-differentiable types +_dual_safearg(x::T) where {T} = Base.issingletontype(T) || Base.issingletontype(eltype(T)) @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) @@ -226,9 +225,9 @@ end import ForwardDiff using ForwardDiff: Dual -dual(x, p) = x dual(x::Real, p) = Dual(x, p) -dual(x::Bool, p) = x +dual(x::Bool, p) = x # must ignore +dual(x, p) = x # safe to ignore: trust _dual_safearg() elsewhere function dual_function(f::F) where F function (args::Vararg{Any,N}) where N @@ -239,7 +238,14 @@ function dual_function(f::F) where F end end -@inline function broadcast_forward(f, args::Vararg{Any,N}) where N +@inline function broadcast_forward(f::F, args::Vararg{Any,N}) where {F,N} + Base.issingletontype(F) || @warn ("""Zygote's dual number broadcasting (as used on GPU arrays) cannot track gradients with respect to `f`, + and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`). 
+ typeof(f) = $(F)""") maxlog=1 _id=hash(F) + for a in args + _dual_safearg(a) || error("""Zygote's dual number broadcasting (as used on GPU arrays) cannot handle this argument. + typeof(a) = $(typeof(a))""") + end valN = Val(N) out = dual_function(f).(args...) eltype(out) <: Dual || return (out, _ -> nothing) diff --git a/test/cuda.jl b/test/cuda.jl index 5cb1c8cdc..ea2cbeb82 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -48,6 +48,19 @@ end @test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal # non-differentiables @test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing + + # Errors -- #1215 + y = complex.([4,1]) |> cu + x = complex.([3,2]) |> cu + function f1215(x, y) + x = 2 .* x + return sum(abs2.(x .- y)) + end + @test_throws ErrorException gradient(()-> f1215(x,y), Zygote.Params([x])) + + # From #1018 + @test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0]) + @test_skip gradient((x,y) -> sum((z->z^2+y[1]).(x)), cu([1,2,3]), cu([4,5])) # if not right, should ideally be an error end @testset "sum(f, x)" begin