Skip to content

Commit

Permalink
refactor: use direct reduction for global pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 11, 2024
1 parent f95cbbd commit ef770a9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.3"
version = "1.2.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
11 changes: 10 additions & 1 deletion src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ end

struct GlobalPoolMode <: AbstractPoolMode end

(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])
# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
(::GlobalPoolMode)() = GlobalPoolMode()

@concrete struct AdaptivePoolMode <: AbstractPoolMode
out_size <: Tuple{Vararg{IntegerType}}
Expand All @@ -33,15 +34,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode
abstract type AbstractPoolOp end

struct MaxPoolOp <: AbstractPoolOp end

(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)
function (m::MaxPoolOp)(x, ::GlobalPoolMode)
return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf))
end

struct MeanPoolOp <: AbstractPoolOp end

(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)
(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2))

@concrete struct LpPoolOp <: AbstractPoolOp
p
end

(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)
(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p)

symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()
Expand Down

0 comments on commit ef770a9

Please sign in to comment.