Skip to content

Commit

Permalink
refactor: use direct reduction for global pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 15, 2024
1 parent c3f1503 commit 5cfc240
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ end

struct GlobalPoolMode <: AbstractPoolMode end

(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])
# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
(::GlobalPoolMode)() = GlobalPoolMode()

@concrete struct AdaptivePoolMode <: AbstractPoolMode
out_size <: Tuple{Vararg{IntegerType}}
Expand All @@ -33,15 +34,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode
abstract type AbstractPoolOp end

struct MaxPoolOp <: AbstractPoolOp end

(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)
function (m::MaxPoolOp)(x, ::GlobalPoolMode)
return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf))
end

struct MeanPoolOp <: AbstractPoolOp end

(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)
(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2))

@concrete struct LpPoolOp <: AbstractPoolOp
p
end

(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)
(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p)

symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()
Expand Down

0 comments on commit 5cfc240

Please sign in to comment.