From e9aeead3741201615a94135c88b095517705816e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Nov 2024 21:59:02 -0500 Subject: [PATCH] docs: keep the ConvMixer default backend as cuda.jl for now --- examples/ConvMixer/Project.toml | 2 +- examples/ConvMixer/main.jl | 2 +- ext/LuxReactantExt/patches.jl | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 6e4b2dd962..6f2fd8702d 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -25,7 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Comonicon = "1.0.8" ConcreteStructs = "0.2.3" DataAugmentation = "0.3" -Enzyme = "0.13.14" +Enzyme = "0.13.16" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index d4da2ca946..602ede83c7 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -76,7 +76,7 @@ end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="reactant") + backend::String="gpu_if_available") rng = StableRNG(seed) if backend == "gpu_if_available" diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 0af6705e4c..56876d8e02 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x) function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber) return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y)) end + +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g