# Patch summary (three files touched):
#   * examples/ConvMixer/main.jl: the default of the `backend` keyword of `main` changes
#     from "reactant" to "gpu_if_available".
#   * ext/LuxReactantExt/patches.jl: adds `(g::Lux.GlobalPoolMode)(::TracedRArray) = g`, so
#     calling the mode on a `TracedRArray` returns the mode object itself. The XXX note
#     (switch to PoolDims once EnzymeJAX supports the stablehlo.reduce_window adjoint) is
#     added here alongside that method.
#   * src/layers/pooling.jl: removes the zero-argument `(::GlobalPoolMode)()`, which returned
#     `GlobalPoolMode()`, together with its XXX note, and adds `(::GlobalPoolMode)(x)`, which
#     returns `PoolDims(x, size(x)[1:(end - 2)])`, i.e. it passes every dimension of `x`
#     except the last two.
# NOTE(review): the zero-argument call form of GlobalPoolMode no longer exists after this
#   patch; presumably every caller now passes the input array -- verify call sites.
# NOTE(review): the patch text below sits on a single line (line breaks appear to have been
#   collapsed); the original line structure likely needs restoring before it can be applied
#   -- TODO confirm against the upstream commit.
diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index d4da2ca94..602ede83c 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -76,7 +76,7 @@ end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="reactant") + backend::String="gpu_if_available") rng = StableRNG(seed) if backend == "gpu_if_available" diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 0af6705e4..f89c5acb0 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x) function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber) return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y)) end + +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +(g::Lux.GlobalPoolMode)(::TracedRArray) = g diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 331bef7e5..da729191c 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -13,8 +13,7 @@ end struct GlobalPoolMode <: AbstractPoolMode end -# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint -(::GlobalPoolMode)() = GlobalPoolMode() +(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)]) @concrete struct AdaptivePoolMode <: AbstractPoolMode out_size <: Tuple{Vararg{IntegerType}}