Skip to content

Commit

Permalink
docs: keep the ConvMixer default backend as cuda.jl for now
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 6, 2024
1 parent c03be3a commit db9486a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/ConvMixer/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Comonicon = "1.0.8"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3"
Enzyme = "0.13.14"
Enzyme = "0.13.16"
ImageCore = "0.10.2"
ImageShow = "0.3.8"
Interpolations = "0.15.1"
Expand All @@ -39,7 +39,7 @@ PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.5"
Reactant = "0.2.8"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
1 change: 1 addition & 0 deletions examples/ConvMixer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Options
--seed <42::Int>
--epochs <25::Int>
--lr-max <0.01::Float64>
--backend <reactant::String>

Flags
--clip-norm
Expand Down
13 changes: 9 additions & 4 deletions examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end
Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
backend::String="reactant")
backend::String="gpu_if_available")
rng = StableRNG(seed)

if backend == "gpu_if_available"
Expand Down Expand Up @@ -111,13 +111,16 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::

if backend == "reactant"
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
model_compiled = @compile model(x_ra, ps, st)
@printf "[Info] Compiling model with Reactant.jl\n"
model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
@printf "[Info] Model compiled!\n"
else
model_compiled = model
end

loss = CrossEntropyLoss(; logits=Val(true))

@printf "[Info] Training model\n"
for epoch in 1:epochs
stime = time()
lr = 0
Expand All @@ -127,6 +130,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
(_, _, _, train_state) = Training.single_train_step!(
adtype, loss, (x, y), train_state
)
@show i, time() - stime
end
ttime = time() - stime

Expand All @@ -137,7 +141,8 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
model_compiled, train_state.parameters, train_state.states, testloader
) * 100

@printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \
Time: %.2f\n" epoch lr train_acc test_acc ttime
@printf "[Train] Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: \
%.2f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime
end
@printf "[Info] Finished training\n"
end
3 changes: 3 additions & 0 deletions ext/LuxReactantExt/patches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x)
function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber)
return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y))
end

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g

0 comments on commit db9486a

Please sign in to comment.