diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 4acddf5243..0d620a132f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -119,11 +119,10 @@ end @functor Dense -function (a::Dense)(x::AbstractVecOrMat) +function (a::Dense)(x::AbstractArray) W, b, σ = a.W, a.b, a.σ - # reshape to handle dims > 1 as batch dimensions sz = size(x) - x = reshape(x, sz[1], :) + x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions x = σ.(W*x .+ b) return reshape(x, :, sz[2:end]...) end