Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update ConvMixer to support reactant #1063

Draft
wants to merge 7 commits into
base: ap/reactant_updates
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/ConvMixer/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -15,6 +16,7 @@ PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -23,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Comonicon = "1.0.8"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3"
Enzyme = "0.13.16"
ImageCore = "0.10.2"
ImageShow = "0.3.8"
Interpolations = "0.15.1"
Expand All @@ -36,6 +39,7 @@ PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.11"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
4 changes: 4 additions & 0 deletions examples/ConvMixer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ for new experiments on small datasets.
You can get around **90.0%** accuracy in just **25 epochs** by running the script with the
following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2.

> [!NOTE]
> To train the model using Reactant.jl pass in `--backend=reactant` to the script.
```bash
julia --startup-file=no \
--project=. \
Expand Down Expand Up @@ -66,6 +69,7 @@ Options
--seed <42::Int>
--epochs <25::Int>
--lr-max <0.01::Float64>
--backend <reactant::String>

Flags
--clip-norm
Expand Down
94 changes: 67 additions & 27 deletions examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA,
MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random,
StableRNGs, Statistics, Zygote
using Reactant, Enzyme

CUDA.allowscalar(false)

Expand All @@ -17,7 +18,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac
return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

function get_dataloaders(batchsize)
function get_dataloaders(batchsize; kwargs...)
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

Expand All @@ -29,10 +30,10 @@ function get_dataloaders(batchsize)
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

trainset = TensorDataset(CIFAR10(:train), train_transform)
trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true)
trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

testset = TensorDataset(CIFAR10(:test), test_transform)
testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true)
testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

return trainloader, testloader
end
Expand All @@ -42,12 +43,20 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
return Chain(
Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
BatchNorm(dim),
[Chain(
SkipConnection(
Chain(Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim,
pad=SamePad()), BatchNorm(dim)), +),
Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
for _ in 1:depth]...,
[
Chain(
SkipConnection(
Chain(
Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
BatchNorm(dim)
),
+
),
Conv((1, 1), dim => dim, gelu),
BatchNorm(dim)
)
for _ in 1:depth
]...,
GlobalMeanPool(),
FlattenLayer(),
Dense(dim => 10)
Expand All @@ -57,55 +66,86 @@ end

function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(first(model(x, ps, st)))
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end

Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01)
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
backend::String="reactant")
rng = StableRNG(seed)

gdev = gpu_device()
trainloader, testloader = get_dataloaders(batchsize) .|> gdev
if backend == "gpu_if_available"
accelerator_device = gpu_device()
elseif backend == "gpu"
accelerator_device = gpu_device(; force=true)
elseif backend == "reactant"
accelerator_device = reactant_device(; force=true)
elseif backend == "cpu"
accelerator_device = cpu_device()
else
error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \
`reactant`, and `cpu`.")
end

kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device

model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)
ps, st = Lux.setup(rng, model) |> gdev
ps, st = Lux.setup(rng, model) |> accelerator_device

opt = AdamW(; eta=lr_max, lambda=weight_decay)
clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

train_state = Training.TrainState(
model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay))
train_state = Training.TrainState(model, ps, st, opt)

lr_schedule = linear_interpolation(
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)

loss = CrossEntropyLoss(; logits=Val(true))
adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

if backend == "reactant"
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
@printf "[Info] Compiling model with Reactant.jl\n"
model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
@printf "[Info] Model compiled!\n"
else
model_compiled = model
end

loss_fn = CrossEntropyLoss(; logits=Val(true))

@printf "[Info] Training model\n"
for epoch in 1:epochs
stime = time()
lr = 0
for (i, (x, y)) in enumerate(trainloader)
lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
train_state = Optimisers.adjust!(train_state, lr)
(_, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, (x, y), train_state)
adtype, loss_fn, (x, y), train_state
)
end
ttime = time() - stime

train_acc = accuracy(
model, train_state.parameters, train_state.states, trainloader) * 100
test_acc = accuracy(model, train_state.parameters, train_state.states, testloader) *
100

@printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \
Time: %.2f\n" epoch lr train_acc test_acc ttime
model_compiled, train_state.parameters, train_state.states, trainloader
) * 100
test_acc = accuracy(
model_compiled, train_state.parameters, train_state.states, testloader
) * 100

@printf "[Train] Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: \
%.2f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime
end
@printf "[Info] Finished training\n"
end
Loading