diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 6e4b2dd962..77ba1ce869 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -25,7 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Comonicon = "1.0.8" ConcreteStructs = "0.2.3" DataAugmentation = "0.3" -Enzyme = "0.13.14" +Enzyme = "0.13.16" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" @@ -39,7 +39,7 @@ PreferenceTools = "0.1.2" Printf = "1.10" ProgressBars = "1.5.1" Random = "1.10" -Reactant = "0.2.5" +Reactant = "0.2.8" StableRNGs = "1.0.2" Statistics = "1.10" Zygote = "0.6.70" diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md index 560b2b1d37..f61bf1c4e3 100644 --- a/examples/ConvMixer/README.md +++ b/examples/ConvMixer/README.md @@ -69,6 +69,7 @@ Options --seed <42::Int> --epochs <25::Int> --lr-max <0.01::Float64> + --backend Flags --clip-norm diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index d4da2ca946..debce9c97c 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -76,7 +76,7 @@ end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="reactant") + backend::String="gpu_if_available") rng = StableRNG(seed) if backend == "gpu_if_available" @@ -111,13 +111,16 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: if backend == "reactant" x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device - model_compiled = @compile model(x_ra, ps, st) + @printf "[Info] Compiling model with Reactant.jl\n" + model_compiled = @compile model(x_ra, ps, Lux.testmode(st)) + @printf "[Info] Model compiled!\n" else model_compiled = model end loss = CrossEntropyLoss(; logits=Val(true)) + @printf "[Info] Training model\n" for epoch in 1:epochs stime = time() lr = 0 @@ -127,6 +130,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: (_, _, _, train_state) = Training.single_train_step!( adtype, loss, (x, y), train_state ) + @show i, time() - stime end ttime = time() - stime @@ -137,7 +141,8 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: model_compiled, train_state.parameters, train_state.states, testloader ) * 100 - @printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \ - Time: %.2f\n" epoch lr train_acc test_acc ttime + @printf "[Train] Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: \ + %.2f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime end + @printf "[Info] Finished training\n" end diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 0af6705e4c..56876d8e02 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x) function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber) return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y)) end + +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g