diff --git a/Project.toml b/Project.toml index 140fb7cf3..ad3b107ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.2.3" +version = "1.2.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index f29bc8db4..331bef7e5 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -13,7 +13,8 @@ end struct GlobalPoolMode <: AbstractPoolMode end -(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)]) +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +(::GlobalPoolMode)() = GlobalPoolMode() @concrete struct AdaptivePoolMode <: AbstractPoolMode out_size <: Tuple{Vararg{IntegerType}} @@ -33,15 +34,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode abstract type AbstractPoolOp end struct MaxPoolOp <: AbstractPoolOp end + (m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) +function (m::MaxPoolOp)(x, ::GlobalPoolMode) + return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf)) +end struct MeanPoolOp <: AbstractPoolOp end + (m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) +(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2)) @concrete struct LpPoolOp <: AbstractPoolOp p end + (m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) +(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p) symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()