Inbuilt-Distributed Setup #500

Merged (14 commits), Apr 7, 2024

@avik-pal (Member) commented Feb 25, 2024

Fixes #494

@avik-pal force-pushed the ap/nccl branch 2 times, most recently from cad7299 to 5aba87a on February 25, 2024 06:34

@avik-pal (Member Author) commented:

Here's a full working script for distributed training:

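```julia
# Launch one process per GPU, e.g. with MPI.jl's wrapper:
#   mpiexecjl -n <num_gpus> julia --project train.jl
using Lux, Optimisers, Random, Zygote
using MPI, NCCL, LuxCUDA   # loading these enables Lux's distributed extensions

# Set up the NCCL communicator and query this process's rank
DistributedUtils.initialize(Val(:NCCL))
backend = DistributedUtils.get_distributed_backend(Val(:NCCL))
rank = DistributedUtils.local_rank(backend)

CUDA.allowscalar(false)

gdev = gpu_device()

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh),
    Dense(256 => 1))
rng = Random.default_rng()
Random.seed!(rng, 100 * rank)   # different seed per rank, so each rank draws different data
ps, st = Lux.setup(rng, model) .|> gdev

# Broadcast parameters and states from the root so all ranks start identical
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)

# Each rank generates its own data shard (Float32 to match the parameters)
x = rand(rng, Float32, 1, 256) |> gdev
y = x .^ 2

# Wrap the optimizer so gradients are averaged across ranks before each update
opt = DistributedUtils.DistributedOptimizer(backend, Adam(0.001f0))
st_opt = Optimisers.setup(opt, ps)

loss(model, p, st, x, y) = sum(abs2, first(model(x, p, st)) .- y)

st_opt = DistributedUtils.synchronize!!(backend, st_opt)

t1 = time()

for epoch in 1:1024
    global ps, st_opt   # top-level script: rebind the globals inside the loop
    l, back = Zygote.pullback(loss, model, ps, st, x, y)
    if rank == 0
        println("Epoch $epoch: Loss $l")
    end
    gs = back(one(l))[2]   # gradient with respect to ps
    st_opt, ps = Optimisers.update!(st_opt, ps, gs)
end

if rank == 0
    t2 = time()
    println("Time: ", t2 - t1)
end
```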
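The two collective steps the script relies on can be illustrated without GPUs or MPI. The sketch below is a hypothetical single-process simulation, not Lux's implementation: `broadcast_from_root!`, `allreduce_mean`, and `simulate` are made-up helpers standing in for what `DistributedUtils.synchronize!!` and `DistributedUtils.DistributedOptimizer` are assumed to do (copy rank 0's state to every rank, then average gradients across ranks before each update). It fits `y = w * x` with plain gradient descent on per-worker data shards.

```julia
using Random

# Hypothetical stand-ins for the collectives; NOT the Lux DistributedUtils API.

# Copy the root's parameters (rank 0, stored at index 1) to every worker.
function broadcast_from_root!(params::Vector{Vector{Float64}})
    for i in 2:length(params)
        params[i] .= params[1]
    end
    return params
end

# Average a per-worker quantity across workers (what an allreduce-mean computes).
allreduce_mean(grads::Vector{Vector{Float64}}) = sum(grads) ./ length(grads)

# Gradient of sum((w .* x .- y) .^ 2) with respect to the scalar weight w[1].
grad(w, x, y) = [sum(2 .* (w[1] .* x .- y) .* x)]

function simulate(; nworkers = 4, steps = 500, lr = 1e-2)
    # One RNG per worker, seeded with 100 * rank as in the script.
    rngs = [MersenneTwister(100 * r) for r in 0:nworkers-1]
    params = [randn(rngs[r], 1) for r in 1:nworkers]   # replicas start out different
    broadcast_from_root!(params)                        # now identical
    # Each worker holds a different data shard; the target is y = 3x.
    xs = [rand(rngs[r], 16) for r in 1:nworkers]
    ys = [3 .* x for x in xs]
    for _ in 1:steps
        local_grads = [grad(params[r], xs[r], ys[r]) for r in 1:nworkers]
        g = allreduce_mean(local_grads)                 # every worker sees the same g
        for r in 1:nworkers
            params[r] .-= lr .* g                       # identical update on each replica
        end
    end
    return params
end

params = simulate()
```

Because every replica starts from the root's weights and applies the same averaged gradient, the replicas stay identical without re-synchronizing inside the loop, which matches the script synchronizing only once up front (assuming the wrapped optimizer averages gradients as described).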

@avik-pal (Member Author) commented:

@vpuri3 this might be of interest to you. It will still take me some time to finish the tests and merge it, but right now it seems fully usable.

@avik-pal force-pushed the ap/nccl branch 2 times, most recently from 7b50b2d to 62f8d39 on March 2, 2024 16:03
Review thread on ext/LuxMPIExt.jl (outdated, resolved)
Review thread on ext/LuxMPIExt.jl (outdated, resolved)
Review thread on src/distributed/public_api.jl (outdated, resolved)
@avik-pal force-pushed the ap/nccl branch 9 times, most recently from cc5c883 to 5b7279d on March 24, 2024 22:45
codecov bot commented Mar 25, 2024

Codecov Report

Attention: Patch coverage is 89.30041% with 26 lines in your changes missing coverage. Please review.

Project coverage is 87.56%. Comparing base (992c448) to head (f9fe7b6).

| File | Patch % | Lines |
|---|---|---|
| src/distributed/public_api.jl | 77.27% | 15 Missing ⚠️ |
| ext/LuxMPIExt.jl | 93.18% | 6 Missing ⚠️ |
| ext/LuxOptimisersExt.jl | 77.77% | 2 Missing ⚠️ |
| src/distributed/backend.jl | 75.00% | 2 Missing ⚠️ |
| ext/LuxSimpleChainsExt.jl | 88.88% | 1 Missing ⚠️ |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #500      +/-   ##
==========================================
+ Coverage   87.33%   87.56%   +0.22%     
==========================================
  Files          33       38       +5     
  Lines        1729     1962     +233     
==========================================
+ Hits         1510     1718     +208     
- Misses        219      244      +25     
```
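The headline percentages follow from the line counts in the report. A quick check in plain Julia: project coverage is hits divided by lines, while the total number of patch lines is not printed in the report, so it is back-solved here from the 26 missing lines and the 89.30041% figure (an assumption about how the patch percentage is defined).

```julia
# Project coverage = hits / lines, on main (base) and on this PR (head).
base_cov = 1510 / 1729 * 100
head_cov = 1718 / 1962 * 100

# Patch coverage: back-solve the number of patch lines from the misses.
missing_patch = 15 + 6 + 2 + 2 + 1            # per-file misses from the table
patch_lines = round(Int, missing_patch / (1 - 0.8930041))
patch_cov = (patch_lines - missing_patch) / patch_lines * 100

println(round(base_cov; digits = 2), " -> ", round(head_cov; digits = 2))
println(patch_lines, " patch lines, ", round(patch_cov; digits = 5), "% covered")
```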


@avik-pal force-pushed the main branch 3 times, most recently from 661793f to 30fb4db on April 6, 2024 15:48
@avik-pal marked this pull request as ready for review on April 6, 2024 23:23
@avik-pal force-pushed the ap/nccl branch 4 times, most recently from b26ed78 to 6f52b80 on April 7, 2024 00:09
@github-actions bot (Contributor) left a comment:


Benchmark Results

| Benchmark suite | Current: f9fe7b6 (ns) | Previous: 992c448 (ns) | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3670.625 | 3209.75 | 1.14 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 8166.833333333333 | 7926.700000000001 | 1.03 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 15163 | 19366 | 0.78 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9952.6 | 9810.2 | 1.01 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8996.666666666666 | 8866.5 | 1.01 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4176.666666666667 | 4243.444444444444 | 0.98 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 2052.8 | 1990.7 | 1.03 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1675.4142857142858 | 1652.1608391608393 | 1.01 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1812.0566037735848 | 1795.509090909091 | 1.01 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 180.7260083449235 | 180.27871148459383 | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17763 | 17533 | 1.01 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 18554 | 18454 | 1.01 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 36097 | 35486 | 1.02 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28853 | 28794 | 1.00 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19757 | 19636 | 1.01 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 16280 | 16290 | 1.00 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 4933.428571428572 | 4879.142857142857 | 1.01 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 4852.642857142857 | 4920.571428571428 | 0.99 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4937.857142857143 | 4869 | 1.01 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1657.1 | 1667.1 | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 49279538 | 40481995 | 1.22 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 79129345 | 74749324 | 1.06 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 111856614.5 | 81624105.5 | 1.37 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 105769295.5 | 92663753.5 | 1.14 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 78788834 | 78428882 | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11708664 | 11787949 | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 18417012.5 | 11402177 | 1.62 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 12582160 | 11738847.5 | 1.07 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 18720703 | 11450503 | 1.63 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6391111 | 6435044 | 0.99 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) | 106606642 | 103518139 | 1.03 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 787082804 | 745269744 | 1.06 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2872643380 | 3135411455 | 0.92 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) | 172875852 | 154961766 | 1.12 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 1116210555.5 | 1170393896.5 | 0.95 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 3827495527 | 3517146717 | 1.09 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 1) | 85830018 | 85073249 | 1.01 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 767039358.5 | 674594478.5 | 1.14 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 3200463157 | 2973836453 | 1.08 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) | 26367187 | 25306347 | 1.04 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 238396112 | 213635745 | 1.12 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 849048292 | 948216426 | 0.90 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) | 27268225.5 | 27397376 | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 227771788 | 236254961 | 0.96 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 938007837.5 | 865698401 | 1.08 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 1) | 25071384 | 23453859.5 | 1.07 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 188774128 | 199823887 | 0.94 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 714525175.5 | 710083361 | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1040326036.5 | 1035420079 | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1875024925.5 | 1867349244 | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2327458517 | 2228923342.5 | 1.04 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2366544195 | 2547118079 | 0.93 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1924084397.5 | 1864452179 | 1.03 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 385275893 | 371021402.5 | 1.04 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 459059587 | 383674219 | 1.20 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 365241028.5 | 457325522 | 0.80 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 12335916 | 12002780 | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 18686861 | 18115906 | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 20034556.5 | 19350790.5 | 1.04 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 24943545 | 23961969 | 1.04 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 18608073 | 18121673 | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1167399 | 1162830 | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2176372 | 2086677 | 1.04 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2188766 | 2117214 | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2181562 | 2080675.5 | 1.05 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 215273 | 219999 | 0.98 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 311211 | 299366 | 1.04 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 283079 | 274841 | 1.03 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 379910 | 361647 | 1.05 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 422485.5 | 411114 | 1.03 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 281556 | 273298 | 1.03 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 402713 | 396156 | 1.02 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 93534 | 89937 | 1.04 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 95167.5 | 89126 | 1.07 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 89488 | 87172 | 1.03 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104776 | 104404 | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 206529053 | 191817225 | 1.08 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 375240316 | 373766933.5 | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 402169557 | 398777811 | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 457008787.5 | 461407979 | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 374904103 | 371779027 | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 363619515 | 357643985 | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 53523625 | 49838086.5 | 1.07 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 49915096 | 52974011 | 0.94 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 49692758 | 58869319 | 0.84 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 28914148 | 28370869 | 1.02 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 20356164.5 | 19711465.5 | 1.03 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19869739.5 | 19683784 | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23945565 | 23236285 | 1.03 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24592166 | 24064172.5 | 1.02 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19983233 | 19579572 | 1.02 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6669023.5 | 6567220 | 1.02 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6667977 | 6568408 | 1.02 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6535538 | 6518469 | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.
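The Ratio column is the current time divided by the previous time, so values above 1.0 mean this PR's run was slower. A small sketch recomputing three rows from the benchmark table (values copied verbatim); the 1.5 flag threshold is a hypothetical choice for illustration, not the alerting setting of the benchmark workflow.

```julia
using Printf

# (benchmark, current ns, previous ns) copied from the benchmark table.
rows = [
    ("Dense(2 => 2)/cpu/reverse/Tracker/(2, 128)", 15163.0, 19366.0),
    ("Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128)", 18417012.5, 11402177.0),
    ("Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128)", 6535538.0, 6518469.0),
]

threshold = 1.5  # hypothetical cut-off for flagging a slowdown

ratios = [cur / prev for (_, cur, prev) in rows]
for ((name, _, _), r) in zip(rows, ratios)
    flag = r > threshold ? "  <-- slower" : ""
    @printf("%-60s %.2f%s\n", name, r, flag)
end
```

Single CI runs on shared machines are noisy, so a lone ratio such as 1.62 on one forward-pass row is worth re-running before treating it as a real regression.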

@avik-pal force-pushed the ap/nccl branch 11 times, most recently from 2053641 to 6cbdc76 on April 7, 2024 17:37
@avik-pal force-pushed the ap/nccl branch 4 times, most recently from 7206255 to 93cedad on April 7, 2024 19:48
@avik-pal merged commit 60947c6 into main on Apr 7, 2024
22 checks passed
@avik-pal deleted the ap/nccl branch on April 7, 2024 21:12
Labels: none yet
Projects: none yet
Development: successfully merging this pull request may close the issue "Distributed Training".
2 participants