Skip to content

Commit

Permalink
refactor: use direct reduction for global pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 6, 2024
1 parent abe8011 commit 6098786
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode
abstract type AbstractPoolOp end

struct MaxPoolOp <: AbstractPoolOp end

(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)
function (m::MaxPoolOp)(x, ::GlobalPoolMode)
return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf))
end

struct MeanPoolOp <: AbstractPoolOp end

(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)
(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2))

@concrete struct LpPoolOp <: AbstractPoolOp
p
end

(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)
(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p)

symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()
Expand Down

0 comments on commit 6098786

Please sign in to comment.