Inbuilt-Distributed Setup #500
Here's a full working script for distributed training:

using Lux, Optimisers, Random, Zygote
using MPI, NCCL, LuxCUDA

# Initialize the NCCL backend and query this process's rank
DistributedUtils.initialize(Val(:NCCL))
backend = DistributedUtils.get_distributed_backend(Val(:NCCL))
rank = DistributedUtils.local_rank(backend)

CUDA.allowscalar(false)
gdev = gpu_device()

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh),
    Dense(256 => 1))

# Seed each rank differently (so each rank draws different data), then
# broadcast the root's parameters and states so all ranks start identical
rng = Random.default_rng()
Random.seed!(rng, 100 * rank)

ps, st = Lux.setup(rng, model) .|> gdev
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)

# Per-rank training batch; Float32 to match the parameter element type
x = rand(rng, Float32, 1, 256) |> gdev
y = x .^ 2

# Wrap the optimizer so gradients are reduced across ranks before each update
opt = DistributedUtils.DistributedOptimizer(backend, Adam(0.001f0))
st_opt = Optimisers.setup(opt, ps)
st_opt = DistributedUtils.synchronize!!(backend, st_opt)

loss(model, p, st, x, y) = sum(abs2, first(model(x, p, st)) .- y)

t1 = time()
for epoch in 1:1024
    global ps, st_opt
    l, back = Zygote.pullback(loss, model, ps, st, x, y)
    if rank == 0
        println("Epoch $epoch: Loss $l")
    end
    gs = back(one(l))[2]  # gradient w.r.t. ps (2nd argument of loss)
    st_opt, ps = Optimisers.update!(st_opt, ps, gs)
end

if rank == 0
    t2 = time()
    println("Time: ", t2 - t1)
end
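The script relies on two collective steps: `synchronize!!` so every rank starts from the root's parameters, and `DistributedOptimizer` so gradients computed on different data are combined before the update. Below is a hypothetical single-process Julia sketch of that data-parallel pattern, assuming the combination is a plain average (an all-reduce) and using a plain SGD step on a one-weight model. It is not Lux's implementation: "ranks" are just vector indices and no MPI or NCCL is involved.

```julia
using Random

# Hypothetical single-process simulation of data-parallel training.
# "Ranks" are vector indices; no MPI or NCCL is involved.

# Gradient of the squared-error loss sum((w .* x .- y) .^ 2) w.r.t. scalar w.
lossgrad(w, x, y) = sum(2 .* (w .* x .- y) .* x)

function simulate(nranks::Int; epochs::Int = 100, lr::Float64 = 0.01)
    # As in the script, rank r seeds its RNG with 100 * r, so the
    # initial weights differ from rank to rank.
    ws = [rand(MersenneTwister(100 * r)) for r in 0:(nranks - 1)]

    # Analogue of synchronize!!: every rank adopts the root's (rank 0) weight.
    ws .= ws[1]

    # Each rank owns a different data shard, fitting y = x^2.
    data = map(0:(nranks - 1)) do r
        x = rand(MersenneTwister(100 * r + 1), 32)
        (x, x .^ 2)
    end

    for _ in 1:epochs
        # Every rank computes a gradient on its own shard only.
        gs = [lossgrad(ws[i], data[i]...) for i in 1:nranks]
        # Assumed all-reduce with averaging: all ranks see the same gradient.
        g = sum(gs) / nranks
        # Identical plain SGD step on every rank (the script uses Adam instead).
        ws .-= lr * g
    end
    return ws
end

ws = simulate(4)
println("weights after training: ", ws)
```

Because every rank starts from the same weight and applies the same averaged gradient, the weights stay bitwise identical across ranks even though each rank only ever sees its own shard.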
@vpuri3 this might be of interest to you. It will still take me some time to finish the tests and merge it, but right now it seems fully usable.
Codecov Report
Attention: Patch coverage is
@@            Coverage Diff             @@
##             main     #500      +/-   ##
==========================================
+ Coverage   87.33%   87.56%   +0.22%
==========================================
  Files          33       38       +5
  Lines        1729     1962     +233
==========================================
+ Hits         1510     1718     +208
- Misses        219      244      +25
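The summary is plain arithmetic on the raw counts: coverage is Hits / Lines, and Misses is Lines - Hits (1729 - 1510 = 219, 1962 - 1718 = 244). A small Julia sketch recomputing it; the helper names are hypothetical. The exact change comes out at about 0.2300 percentage points, so the +0.22% shown appears to be truncated rather than rounded; that is an inference from the numbers, not documented Codecov behavior.

```julia
using Printf

# Hypothetical re-computation of the Codecov summary from its raw counts.
coverage(hits, lines) = 100 * hits / lines   # percent of lines hit
misses(hits, lines) = lines - hits

base = (hits = 1510, lines = 1729)   # main
head = (hits = 1718, lines = 1962)   # #500

cov_base = coverage(base.hits, base.lines)
cov_head = coverage(head.hits, head.lines)

@printf("main: %.2f%%, #500: %.2f%%, misses %d -> %d\n",
        cov_base, cov_head, misses(base...), misses(head...))
@printf("exact change: %+.4f percentage points\n", cov_head - cov_base)
```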
Benchmark Results
| Benchmark suite | Current: f9fe7b6 | Previous: 992c448 | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3670.625 ns | 3209.75 ns | 1.14 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 8166.833333333333 ns | 7926.700000000001 ns | 1.03 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 15163 ns | 19366 ns | 0.78 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9952.6 ns | 9810.2 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8996.666666666666 ns | 8866.5 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4176.666666666667 ns | 4243.444444444444 ns | 0.98 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 2052.8 ns | 1990.7 ns | 1.03 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1675.4142857142858 ns | 1652.1608391608393 ns | 1.01 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1812.0566037735848 ns | 1795.509090909091 ns | 1.01 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 180.7260083449235 ns | 180.27871148459383 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17763 ns | 17533 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 18554 ns | 18454 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 36097 ns | 35486 ns | 1.02 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28853 ns | 28794 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19757 ns | 19636 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 16280 ns | 16290 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 4933.428571428572 ns | 4879.142857142857 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 4852.642857142857 ns | 4920.571428571428 ns | 0.99 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4937.857142857143 ns | 4869 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1657.1 ns | 1667.1 ns | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 49279538 ns | 40481995 ns | 1.22 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 79129345 ns | 74749324 ns | 1.06 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 111856614.5 ns | 81624105.5 ns | 1.37 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 105769295.5 ns | 92663753.5 ns | 1.14 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 78788834 ns | 78428882 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11708664 ns | 11787949 ns | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 18417012.5 ns | 11402177 ns | 1.62 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 12582160 ns | 11738847.5 ns | 1.07 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 18720703 ns | 11450503 ns | 1.63 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6391111 ns | 6435044 ns | 0.99 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) | 106606642 ns | 103518139 ns | 1.03 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 787082804 ns | 745269744 ns | 1.06 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2872643380 ns | 3135411455 ns | 0.92 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) | 172875852 ns | 154961766 ns | 1.12 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 1116210555.5 ns | 1170393896.5 ns | 0.95 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 3827495527 ns | 3517146717 ns | 1.09 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 1) | 85830018 ns | 85073249 ns | 1.01 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 767039358.5 ns | 674594478.5 ns | 1.14 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 3200463157 ns | 2973836453 ns | 1.08 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) | 26367187 ns | 25306347 ns | 1.04 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 238396112 ns | 213635745 ns | 1.12 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 849048292 ns | 948216426 ns | 0.90 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) | 27268225.5 ns | 27397376 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 227771788 ns | 236254961 ns | 0.96 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 938007837.5 ns | 865698401 ns | 1.08 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 1) | 25071384 ns | 23453859.5 ns | 1.07 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 188774128 ns | 199823887 ns | 0.94 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 714525175.5 ns | 710083361 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1040326036.5 ns | 1035420079 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1875024925.5 ns | 1867349244 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2327458517 ns | 2228923342.5 ns | 1.04 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2366544195 ns | 2547118079 ns | 0.93 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1924084397.5 ns | 1864452179 ns | 1.03 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 385275893 ns | 371021402.5 ns | 1.04 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 459059587 ns | 383674219 ns | 1.20 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 365241028.5 ns | 457325522 ns | 0.80 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 12335916 ns | 12002780 ns | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 18686861 ns | 18115906 ns | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 20034556.5 ns | 19350790.5 ns | 1.04 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 24943545 ns | 23961969 ns | 1.04 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 18608073 ns | 18121673 ns | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1167399 ns | 1162830 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2176372 ns | 2086677 ns | 1.04 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2188766 ns | 2117214 ns | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2181562 ns | 2080675.5 ns | 1.05 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 215273 ns | 219999 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 311211 ns | 299366 ns | 1.04 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 283079 ns | 274841 ns | 1.03 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 379910 ns | 361647 ns | 1.05 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 422485.5 ns | 411114 ns | 1.03 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 281556 ns | 273298 ns | 1.03 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 402713 ns | 396156 ns | 1.02 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 93534 ns | 89937 ns | 1.04 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 95167.5 ns | 89126 ns | 1.07 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 89488 ns | 87172 ns | 1.03 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104776 ns | 104404 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 206529053 ns | 191817225 ns | 1.08 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 375240316 ns | 373766933.5 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 402169557 ns | 398777811 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 457008787.5 ns | 461407979 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 374904103 ns | 371779027 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 363619515 ns | 357643985 ns | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 53523625 ns | 49838086.5 ns | 1.07 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 49915096 ns | 52974011 ns | 0.94 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 49692758 ns | 58869319 ns | 0.84 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 28914148 ns | 28370869 ns | 1.02 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 20356164.5 ns | 19711465.5 ns | 1.03 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19869739.5 ns | 19683784 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23945565 ns | 23236285 ns | 1.03 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24592166 ns | 24064172.5 ns | 1.02 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19983233 ns | 19579572 ns | 1.02 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6669023.5 ns | 6567220 ns | 1.02 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6667977 ns | 6568408 ns | 1.02 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6535538 ns | 6518469 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
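In the table above, Ratio is Current divided by Previous (for example, 3670.625 / 3209.75 ≈ 1.14 in the first row), so values above 1 mean the PR's commit ran slower. A minimal Julia sketch that recomputes a few rows copied from the table and flags large ratios; the `ratio` helper is hypothetical, and the 1.5 cutoff is an arbitrary example value, not the workflow's configured alert threshold.

```julia
# Hypothetical helper: recompute the Ratio column (Current / Previous) for a
# few rows copied from the table and flag large slowdowns.
# The 1.5 cutoff is an arbitrary example, not the workflow's alert threshold.
const THRESHOLD = 1.5

rows = [
    ("Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128)", 3670.625, 3209.75),
    ("Dense(2 => 2)/cpu/reverse/Tracker/(2, 128)", 15163.0, 19366.0),
    ("Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128)", 18417012.5, 11402177.0),
    ("Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128)", 18720703.0, 11450503.0),
]

# > 1 means the current commit is slower than the previous one
ratio(current, previous) = current / previous

for (name, current, previous) in rows
    r = ratio(current, previous)
    flag = r > THRESHOLD ? "  (possible regression)" : ""
    println(name, " => ", round(r; digits = 2), flag)
end
```

Single-run CPU timings are noisy, so a ratio like 1.6 is a prompt to re-run the benchmark rather than proof of a regression.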
Fixes #494
NCCL.avg
correctly JuliaGPU/NCCL.jl#54)