Skip to content

Commit

Permalink
fix warning
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jan 8, 2021
1 parent d7b088a commit 6d03787
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,8 @@ end

@functor Dense

function (a::Dense)(x::AbstractArray)
if eltype(a.W) != eltype(x)
@warn "Element types of input and weights differ." W=eltype(a.W) x=eltype(x) maxlog=1
end
function (a::Dense)(x::AbstractVecOrMat)
eltype(a.W) == eltype(x) || _dense_typewarn(a, x)
W, b, σ = a.W, a.b, a.σ
# reshape to handle dims > 1 as batch dimensions
sz = size(x)
Expand All @@ -131,6 +129,9 @@ function (a::Dense)(x::AbstractArray)
return reshape(x, :, sz[2:end]...)
end

_dense_typewarn(d, x) = @warn "Element types don't match for layer $d, this will be slow." typeof(d.W) typeof(x) maxlog=1
Zygote.@nograd _dense_typewarn

function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
l.σ == identity || print(io, ", ", l.σ)
Expand Down

0 comments on commit 6d03787

Please sign in to comment.