diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index d4da2ca94..602ede83c 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -76,7 +76,7 @@ end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="reactant") + backend::String="gpu_if_available") rng = StableRNG(seed) if backend == "gpu_if_available" diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 0af6705e4..56876d8e0 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x) function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber) return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y)) end + +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g