From 847efd28cb3cce715c6d97f1592561e2cb6bb16a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 2 Dec 2020 09:57:46 +0100 Subject: [PATCH] fix merge --- src/layers/basic.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 4acddf5243..0d620a132f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -119,11 +119,10 @@ end @functor Dense -function (a::Dense)(x::AbstractVecOrMat) +function (a::Dense)(x::AbstractArray) W, b, σ = a.W, a.b, a.σ - # reshape to handle dims > 1 as batch dimensions sz = size(x) - x = reshape(x, sz[1], :) + x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions x = σ.(W*x .+ b) return reshape(x, :, sz[2:end]...) end