Skip to content

Commit

Permalink
docs: update ConvMixer to support reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 20, 2024
1 parent 1b5dace commit 0dbb5a4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
4 changes: 4 additions & 0 deletions examples/ConvMixer/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -15,6 +16,7 @@ PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -23,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Comonicon = "1.0.8"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3"
Enzyme = "0.13.14"
ImageCore = "0.10.2"
ImageShow = "0.3.8"
Interpolations = "0.15.1"
Expand All @@ -36,6 +39,7 @@ PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.5"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
3 changes: 3 additions & 0 deletions examples/ConvMixer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ for new experiments on small datasets.
You can get around **90.0%** accuracy in just **25 epochs** by running the script with the
following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2.

> [!NOTE]
> To train the model using Reactant.jl pass in `--backend=reactant` to the script.
```bash
julia --startup-file=no \
--project=. \
Expand Down
69 changes: 51 additions & 18 deletions examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA,
MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random,
StableRNGs, Statistics, Zygote
using Reactant, Enzyme

CUDA.allowscalar(false)

Expand All @@ -17,7 +18,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac
return stack(parent itemdata Base.Fix1(apply, ds.transform), img), y
end

function get_dataloaders(batchsize)
function get_dataloaders(batchsize; kwargs...)
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

Expand All @@ -29,10 +30,10 @@ function get_dataloaders(batchsize)
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

trainset = TensorDataset(CIFAR10(:train), train_transform)
trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true)
trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)

testset = TensorDataset(CIFAR10(:test), test_transform)
testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true)
testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)

return trainloader, testloader
end
Expand All @@ -43,10 +44,14 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
BatchNorm(dim),
[Chain(
SkipConnection(
Chain(Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim,
pad=SamePad()), BatchNorm(dim)), +),
Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
SkipConnection(
Chain(
Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
BatchNorm(dim)
),
+
),
Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
for _ in 1:depth]...,
GlobalMeanPool(),
FlattenLayer(),
Expand All @@ -57,10 +62,11 @@ end

function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(first(model(x, ps, st)))
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
Expand All @@ -69,22 +75,46 @@ end

Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01)
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
backend::String="reactant")
rng = StableRNG(seed)

gdev = gpu_device()
trainloader, testloader = get_dataloaders(batchsize) |> gdev
if backend == "gpu_if_available"
accelerator_device = gpu_device()
elseif backend == "gpu"
accelerator_device = gpu_device(; force=true)
elseif backend == "reactant"
accelerator_device = reactant_device(; force=true)
elseif backend == "cpu"
accelerator_device = cpu_device()
else
error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \
`reactant`, and `cpu`.")
end

kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device

model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)
ps, st = Lux.setup(rng, model) |> gdev
ps, st = Lux.setup(rng, model) |> accelerator_device

opt = AdamW(; eta=lr_max, lambda=weight_decay)
clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

train_state = Training.TrainState(model, ps, st, opt)

lr_schedule = linear_interpolation(
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)

adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

if backend == "reactant"
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
model_compiled = @compile model(x_ra, ps, st)
else
model_compiled = model
end

loss = CrossEntropyLoss(; logits=Val(true))

Expand All @@ -95,14 +125,17 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
train_state = Optimisers.adjust!(train_state, lr)
(_, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, (x, y), train_state)
adtype, loss, (x, y), train_state
)
end
ttime = time() - stime

train_acc = accuracy(
model, train_state.parameters, train_state.states, trainloader) * 100
test_acc = accuracy(model, train_state.parameters, train_state.states, testloader) *
100
model_compiled, train_state.parameters, train_state.states, trainloader
) * 100
test_acc = accuracy(
model_compiled, train_state.parameters, train_state.states, testloader
) * 100

@printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \
Time: %.2f\n" epoch lr train_acc test_acc ttime
Expand Down

0 comments on commit 0dbb5a4

Please sign in to comment.