diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 943eb947c..3be3a5e24 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode abstract type AbstractPoolOp end struct MaxPoolOp <: AbstractPoolOp end + (m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) +function (m::MaxPoolOp)(x, ::GlobalPoolMode) + return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf)) +end struct MeanPoolOp <: AbstractPoolOp end + (m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) +(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2)) @concrete struct LpPoolOp <: AbstractPoolOp p end + (m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) +(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p) symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()