From 6098786facc82d2de072b2e72147578c9e390514 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Nov 2024 21:54:01 -0500 Subject: [PATCH] refactor: use direct reduction for global pooling --- src/layers/pooling.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 943eb947c1..3be3a5e24e 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode abstract type AbstractPoolOp end struct MaxPoolOp <: AbstractPoolOp end + (m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) +function (m::MaxPoolOp)(x, ::GlobalPoolMode) + return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf)) +end struct MeanPoolOp <: AbstractPoolOp end + (m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) +(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2)) @concrete struct LpPoolOp <: AbstractPoolOp p end + (m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) +(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p) symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()