From 1b7ece598597f38b4ddb843ebb88cc0e02b466ec Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 19 Mar 2024 14:36:18 -0400 Subject: [PATCH 1/3] better size check for conv layers --- src/layers/basic.jl | 2 +- src/layers/conv.jl | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ac85827a41..ef81c30872 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -193,7 +193,7 @@ function _size_check(layer, x::AbstractArray, (d, n)::Pair) d > 0 || throw(DimensionMismatch(string("layer ", layer, " expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x)))) size(x, d) == n || throw(DimensionMismatch(string("layer ", layer, - " expects size(input, $d) == $n, but got ", summary(x)))) + lazy" expects size(input, $d) == $n, but got ", summary(x)))) end ChainRulesCore.@non_differentiable _size_check(::Any...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4e6044dcfb..45bb5b3198 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -195,7 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) = ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any) function (c::Conv)(x::AbstractArray) - _size_check(c, x, ndims(x)-1 => _channels_in(c)) + _conv_size_check(c, x) σ = NNlib.fast_act(c.σ, x) cdims = conv_dims(c, x) xT = _match_eltype(c, x) @@ -331,7 +331,7 @@ end ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any) function (c::ConvTranspose)(x::AbstractArray) - _size_check(c, x, ndims(x)-1 => _channels_in(c)) + _conv_size_check(c, x) σ = NNlib.fast_act(c.σ, x) cdims = conv_transpose_dims(c, x) xT = _match_eltype(c, x) @@ -473,7 +473,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) = ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any) function (c::CrossCor)(x::AbstractArray) - _size_check(c, x, ndims(x)-1 => _channels_in(c)) + _conv_size_check(c, x) σ = NNlib.fast_act(c.σ, x) cdims = crosscor_dims(c, x) 
xT = _match_eltype(c, x) @@ -487,6 +487,15 @@ function Base.show(io::IO, l::CrossCor) print(io, ")") end +function _conv_size_check(layer, x::AbstractArray) + ndims(x) == ndims(layer.weight) || throw(ArgumentError(LazyString("layer ", layer, + " expects ndims(input) == ", ndims(layer.weight), ", but got ", summary(x)))) + d = ndims(x)-1 + n = _channels_in(layer) + size(x,d) == n || throw(DimensionMismatch(LazyString("layer ", layer, + lazy" expects size(input, $d) == $n, but got ", summary(x)))) +end +ChainRulesCore.@non_differentiable _conv_size_check(::Any, ::Any) """ AdaptiveMaxPool(out::NTuple) From c9fde19188738d9c76157557a58dbc3d99bce767 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 19 Mar 2024 14:42:54 -0400 Subject: [PATCH 2/3] similar for pooling layers --- src/layers/conv.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 45bb5b3198..65268187ee 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -524,6 +524,7 @@ struct AdaptiveMaxPool{S, O} end function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T} + _pool_size_check(a, a.out, x) insize = size(x)[1:end-2] outsize = a.out stride = insize .÷ outsize @@ -565,6 +566,7 @@ struct AdaptiveMeanPool{S, O} end function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T} + _pool_size_check(a, a.out, x) insize = size(x)[1:end-2] outsize = a.out stride = insize .÷ outsize @@ -703,6 +705,7 @@ function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N end function (m::MaxPool)(x) + _pool_size_check(m, m.k, x) pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return maxpool(x, pdims) end @@ -762,6 +765,7 @@ function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N end function (m::MeanPool)(x) + _pool_size_check(m, m.k, x) pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return meanpool(x, pdims) end @@ -772,3 +776,11 @@ function 
Base.show(io::IO, m::MeanPool) m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) print(io, ")") end + +function _pool_size_check(layer, tup::Tuple, x::AbstractArray) + N = length(tup) + 2 + ndims(x) == N || throw(ArgumentError(LazyString("layer ", layer, + " expects ndims(input) == ", N, ", but got ", summary(x)))) +end +ChainRulesCore.@non_differentiable _pool_size_check(::Any, ::Any, ::Any) + From 55bbb70253aa8ae5591fe8999b6d81055d65a32a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 19 Mar 2024 15:08:02 -0400 Subject: [PATCH 3/3] change to DimensionMismatch --- src/layers/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 65268187ee..fdf3c756e9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -488,7 +488,7 @@ function Base.show(io::IO, l::CrossCor) end function _conv_size_check(layer, x::AbstractArray) - ndims(x) == ndims(layer.weight) || throw(ArgumentError(LazyString("layer ", layer, + ndims(x) == ndims(layer.weight) || throw(DimensionMismatch(LazyString("layer ", layer, " expects ndims(input) == ", ndims(layer.weight), ", but got ", summary(x)))) d = ndims(x)-1 n = _channels_in(layer) @@ -779,7 +779,7 @@ end function _pool_size_check(layer, tup::Tuple, x::AbstractArray) N = length(tup) + 2 - ndims(x) == N || throw(ArgumentError(LazyString("layer ", layer, + ndims(x) == N || throw(DimensionMismatch(LazyString("layer ", layer, " expects ndims(input) == ", N, ", but got ", summary(x)))) end ChainRulesCore.@non_differentiable _pool_size_check(::Any, ::Any, ::Any)