diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8a63b89bcc..ca275d4a16 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -313,8 +313,9 @@ end function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) # Calculate size of "input", from ∇conv_data()'s perspective... - combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end]) - I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad + calc_dim(xsz, wsz, stride, dilation, pad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2) + I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad) C_in = size(c.weight)[end-1] * c.groups batch_size = size(x)[end] # Create DenseConvDims() that looks like the corresponding conv()