From 6d03787f922534daae922a1d1837cbc8fc97a5c8 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 15 Nov 2020 13:15:52 +0100 Subject: [PATCH] fix warning --- src/layers/basic.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 52379c449e..8805935d10 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -119,10 +119,8 @@ end @functor Dense -function (a::Dense)(x::AbstractArray) - if eltype(a.W) != eltype(x) - @warn "Element types of input and weights differ." W=eltype(a.W) x=eltype(x) maxlog=1 - end +function (a::Dense)(x::AbstractVecOrMat) + eltype(a.W) == eltype(x) || _dense_typewarn(a, x) W, b, σ = a.W, a.b, a.σ # reshape to handle dims > 1 as batch dimensions sz = size(x) @@ -131,6 +129,9 @@ function (a::Dense)(x::AbstractArray) return reshape(x, :, sz[2:end]...) end +_dense_typewarn(d, x) = @warn "Element types don't match for layer $d, this will be slow." typeof(d.W) typeof(x) maxlog=1 +Zygote.@nograd _dense_typewarn + function Base.show(io::IO, l::Dense) print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1)) l.σ == identity || print(io, ", ", l.σ)