diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 519d0ecfbe..0f0f7f1e69 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -17,6 +17,7 @@ steps: cuda: "*" env: GROUP: "CUDA" + JULIA_MPI_TEST_NPROCS: 2 # Needs to be same as number of GPUs for NCCL if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 240 matrix: diff --git a/Project.toml b/Project.toml index b7d3ffc6fd..b6da5a9a69 100644 --- a/Project.toml +++ b/Project.toml @@ -120,6 +120,9 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -135,4 +138,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] +test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] diff --git a/docs/src/manual/distributed_utils.md b/docs/src/manual/distributed_utils.md index da51ddbec5..58f7bbdee3 100644 --- 
a/docs/src/manual/distributed_utils.md +++ b/docs/src/manual/distributed_utils.md @@ -3,7 +3,7 @@ !!! tip For a fully functional example, see the - [ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet) + [ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet). DDP Training using `Lux.DistributedUtils` is a spiritual successor to [FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl), but has some key differences. diff --git a/examples/ImageNet/Project.toml b/examples/ImageNet/Project.toml index 20373162e9..175af36e03 100644 --- a/examples/ImageNet/Project.toml +++ b/examples/ImageNet/Project.toml @@ -5,7 +5,7 @@ Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" +Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" @@ -30,17 +30,24 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Augmentor = "0.6" Boltz = "0.1, 0.2, 0.3" Configurations = "0.17" -Formatting = "0.4" +FLoops = "0.2" +FileIO = "1.16" +Format = "1.3" Functors = "0.2, 0.3, 0.4" +Images = "0.26" +JLD2 = "0.4.46" JpegTurbo = "0.1" Lux = "0.4, 0.5" LuxAMDGPU = "0.1, 0.2" LuxCUDA = "0.2, 0.3" MLUtils = "0.2.10, 0.3, 0.4" +MPI = "0.20.19" Metalhead = "0.9" +NCCL = "0.1.1" OneHotArrays = "0.1, 0.2" Optimisers = "0.2, 0.3" -Setfield = "0.8.2, 1" +ParameterSchedulers = "0.4" +Setfield = "1" SimpleConfig = "0.1" Statistics = "1" Zygote = "0.6" diff --git a/examples/ImageNet/main.jl b/examples/ImageNet/main.jl index 25112ebbb5..51f7e7ceb1 100644 --- a/examples/ImageNet/main.jl +++ b/examples/ImageNet/main.jl @@ -6,8 +6,7 @@ import FLoops: ThreadedEx import Metalhead import MPI, NCCL using LuxAMDGPU, LuxCUDA - 
-using Formatting # TODO: Get rid of this +using Format # Distributed Training: NCCL for NVIDIA GPUs and MPI for anything else if LuxCUDA.functional() @@ -79,10 +78,12 @@ function construct(cfg::OptimizerConfig) end if cfg.scheduler.name == "cosine" - scheduler = CosineAnnealSchedule(cfg.learning_rate, cfg.learning_rate / 100, - cfg.scheduler.cycle_length; dampen=cfg.scheduler.damp_factor) + l0 = cfg.learning_rate + l1 = cfg.learning_rate / 100 + scheduler = ComposedSchedule(CosAnneal(l0, l1, cfg.scheduler.cycle_length), + Step(l0, cfg.scheduler.damp_factor, cfg.scheduler.cycle_length)) elseif cfg.scheduler.name == "constant" - scheduler = ConstantSchedule(cfg.learning_rate) + scheduler = Constant(cfg.learning_rate) elseif cfg.scheduler.name == "step" scheduler = Step( cfg.learning_rate, cfg.scheduler.lr_step_decay, cfg.scheduler.lr_step) diff --git a/examples/ImageNet/utils.jl b/examples/ImageNet/utils.jl index fd30a1c870..a2f7a0b74d 100644 --- a/examples/ImageNet/utils.jl +++ b/examples/ImageNet/utils.jl @@ -77,51 +77,6 @@ function load_checkpoint(fname::String) end end -# Parameter Scheduling -## Copied from ParameterSchedulers.jl due to its heavy dependencies -struct CosineAnnealSchedule{restart, T, S <: Integer} - range::T - offset::T - dampen::T - period::S - - function CosineAnnealSchedule( - lambda_0, lambda_1, period; restart::Bool=true, dampen=1.0f0) - range = abs(lambda_0 - lambda_1) - offset = min(lambda_0, lambda_1) - return new{restart, typeof(range), typeof(period)}(range, offset, dampen, period) - end -end - -function (s::CosineAnnealSchedule{true})(t) - d = s.dampen^div(t - 1, s.period) - return (s.range * (1 + cos(pi * mod(t - 1, s.period) / s.period)) / 2 + s.offset) / d -end - -function (s::CosineAnnealSchedule{false})(t) - return s.range * (1 + cos(pi * (t - 1) / s.period)) / 2 + s.offset -end - -struct Step{T, S} - start::T - decay::T - step_sizes::S - - function Step(start::T, decay::T, step_sizes::S) where {T, S} - _step_sizes = (S <: 
Integer) ? Iterators.repeated(step_sizes) : step_sizes - - return new{T, typeof(_step_sizes)}(start, decay, _step_sizes) - end -end - -(s::Step)(t) = s.start * s.decay^(searchsortedfirst(s.step_sizes, t - 1) - 1) - -struct ConstantSchedule{T} - val::T -end - -(s::ConstantSchedule)(t) = s.val - # Tracking @kwdef mutable struct AverageMeter fmtstr @@ -132,7 +87,7 @@ end end function AverageMeter(name::String, fmt::String) - fmtstr = Formatting.FormatExpr("$name {1:$fmt} ({2:$fmt})") + fmtstr = FormatExpr("$name {1:$fmt} ({2:$fmt})") return AverageMeter(; fmtstr=fmtstr) end @@ -160,7 +115,7 @@ function reset_meter!(meter::AverageMeter) end function print_meter(meter::AverageMeter) - return Formatting.printfmt(meter.fmtstr, meter.val, meter.average) + return printfmt(meter.fmtstr, meter.val, meter.average) end # ProgressMeter @@ -171,10 +126,9 @@ end function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} fmt = "%" * string(length(string(num_batches))) * "d" + fmt2 = "{1:" * string(length(string(num_batches))) * "d}" prefix = prefix != "" ? endswith(prefix, " ") ? 
prefix : prefix * " " : "" - batch_fmtstr = Formatting.generate_formatter("$prefix[$fmt/" * - Formatting.sprintf1(fmt, num_batches) * - "]") + batch_fmtstr = FormatExpr("$prefix[$fmt2/" * cfmt(fmt, num_batches) * "]") return ProgressMeter{N}(batch_fmtstr, meters) end @@ -184,8 +138,7 @@ function reset_meter!(meter::ProgressMeter) end function print_meter(meter::ProgressMeter, batch::Int) - base_str = meter.batch_fmtstr(batch) - print(base_str) + printfmt(meter.batch_fmtstr, batch) foreach(x -> (print("\t"); print_meter(x)), meter.meters[1:end]) println() return nothing diff --git a/ext/LuxComponentArraysExt.jl b/ext/LuxComponentArraysExt.jl index 95076dbe5e..4ca2e1c02b 100644 --- a/ext/LuxComponentArraysExt.jl +++ b/ext/LuxComponentArraysExt.jl @@ -14,8 +14,8 @@ Lux.__named_tuple(ca::ComponentArray) = NamedTuple(ca) # Distributed Functionality function DistributedUtils.synchronize!!( backend::Lux.AbstractLuxDistributedBackend, ps::ComponentArray; root::Int=0) - ps_synced = DistributedUtils.synchronize!!(backend, getdata(ps); root) - return ComponentArray(ps_synced, getaxes(ps)) + ps_synced = DistributedUtils.synchronize!!(backend, ComponentArrays.getdata(ps); root) + return ComponentArray(ps_synced, ComponentArrays.getaxes(ps)) end end diff --git a/ext/LuxMPIExt.jl b/ext/LuxMPIExt.jl index 457e73a170..73f09fe52b 100644 --- a/ext/LuxMPIExt.jl +++ b/ext/LuxMPIExt.jl @@ -39,7 +39,6 @@ DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm) DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm) # Broadcast - function DistributedUtils.__bcast!( backend::MPIBackend, sendrecvbuf, dev::AbstractLuxDevice; root=0) MPI.Bcast!(sendrecvbuf, backend.comm; root) @@ -78,11 +77,11 @@ for (aware, dType) in ((MPI_CUDA_AWARE, LuxCUDADevice), (MPI_ROCM_AWARE, LuxAMDG end # Allreduce - function DistributedUtils.__allreduce!( backend::MPIBackend, sendrecvbuf, op::F, dev::AbstractLuxDevice) where {F} - MPI.Allreduce!(sendrecvbuf, op, 
backend.comm) - if op === typeof(DistributedUtils.avg) + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm) + if op === DistributedUtils.avg sendrecvbuf ./= DistributedUtils.total_workers(backend) end return sendrecvbuf @@ -90,8 +89,9 @@ end function DistributedUtils.__allreduce!( backend::MPIBackend, sendbuf, recvbuf, op::F, dev::AbstractLuxDevice) where {F} - MPI.Allreduce!(sendbuf, recvbuf, op, backend.comm) - if op === typeof(DistributedUtils.avg) + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm) + if op === DistributedUtils.avg recvbuf ./= DistributedUtils.total_workers(backend) end return recvbuf @@ -123,11 +123,11 @@ for (aware, dType) in ((MPI_CUDA_AWARE, LuxCUDADevice), (MPI_ROCM_AWARE, LuxAMDG end # Reduce - function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F, dev::AbstractLuxDevice; root::Int) where {F} - MPI.Reduce!(sendrecvbuf, op, backend.comm; root) - if op === typeof(DistributedUtils.avg) + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root) + if op === DistributedUtils.avg sendrecvbuf ./= DistributedUtils.total_workers(backend) end return sendrecvbuf @@ -135,8 +135,9 @@ end function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F, dev::AbstractLuxDevice; root::Int) where {F} - MPI.Reduce!(sendbuf, recvbuf, op, backend.comm; root) - if op === typeof(DistributedUtils.avg) + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root) + if op === DistributedUtils.avg recvbuf ./= DistributedUtils.total_workers(backend) end return recvbuf diff --git a/ext/LuxMPINCCLExt.jl b/ext/LuxMPINCCLExt.jl index 2bff97cbe0..737a30a1b9 100644 --- a/ext/LuxMPINCCLExt.jl +++ b/ext/LuxMPINCCLExt.jl @@ -33,7 +33,6 @@ DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm) # For non-CUDA 
Arrays, fallback to MPI # Broadcast - function DistributedUtils.__bcast!( backend::NCCLBackend, sendrecvbuf, ::LuxCUDADevice; root=0) NCCL.Broadcast!(sendrecvbuf, backend.comm; root) @@ -57,7 +56,6 @@ function DistributedUtils.__bcast!( end # Allreduce - function DistributedUtils.__allreduce!( backend::NCCLBackend, sendrecvbuf, op::F, ::LuxCUDADevice) where {F} op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) @@ -83,7 +81,6 @@ function DistributedUtils.__allreduce!( end # Reduce - function DistributedUtils.__reduce!( backend::NCCLBackend, sendrecvbuf, op::F, ::LuxCUDADevice; root::Int) where {F} op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index 92990c0ac6..7a1275a52d 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -49,7 +49,7 @@ end end function Optimisers.apply!(opt::DistributedOptimizer, state, x, y) - y_avg = allreduce!(opt.backend, y, DistributedUtils.avg) + y_avg = DistributedUtils.allreduce!(opt.backend, y, DistributedUtils.avg) return Optimisers.apply!(opt.opt, state, x, y_avg) end diff --git a/test/distributed/common_distributedtest.jl b/test/distributed/common_distributedtest.jl new file mode 100644 index 0000000000..01d08a9336 --- /dev/null +++ b/test/distributed/common_distributedtest.jl @@ -0,0 +1,16 @@ +using Lux, MPI, NCCL, Test + +# ARGS is (device mode, backend name) as passed by runtests.jl; the script path is not part of ARGS +const input_args = length(ARGS) == 2 ? (ARGS[1], ARGS[2]) : + ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? LuxCPUDevice() : + (input_args[1] == "CUDA" ?
LuxCUDADevice() : LuxAMDGPUDevice()) + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +@test DistributedUtils.initialized(backend_type) + +# Should always hold true +@test DistributedUtils.local_rank(backend) < DistributedUtils.total_workers(backend) diff --git a/test/distributed/data_distributedtest.jl b/test/distributed/data_distributedtest.jl new file mode 100644 index 0000000000..22c3e0c6cf --- /dev/null +++ b/test/distributed/data_distributedtest.jl @@ -0,0 +1,29 @@ +using Lux, MLUtils, MPI, NCCL, Random, Test + +# ARGS is (device mode, backend name) as passed by runtests.jl; the script path is not part of ARGS +const input_args = length(ARGS) == 2 ? (ARGS[1], ARGS[2]) : + ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? LuxCPUDevice() : + (input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice()) + +rng = Xoshiro(1234) + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +data = randn(rng, Float32, 10) +dcontainer = DistributedUtils.DistributedDataContainer(backend, data) + +rank = DistributedUtils.local_rank(backend) +tworkers = DistributedUtils.total_workers(backend) + +if rank != tworkers - 1 + @test length(dcontainer) == ceil(length(data) / tworkers) +else + @test length(dcontainer) == + length(data) - (tworkers - 1) * ceil(length(data) / tworkers) +end + +dsum = sum(Base.Fix1(MLUtils.getobs, dcontainer), 1:MLUtils.numobs(dcontainer)) +@test DistributedUtils.allreduce!(backend, [dsum], +)[1] ≈ sum(data) diff --git a/test/distributed/optimizer_distributedtest.jl b/test/distributed/optimizer_distributedtest.jl new file mode 100644 index 0000000000..e41fb2919d --- /dev/null +++ b/test/distributed/optimizer_distributedtest.jl @@ -0,0 +1,31 @@ +using Lux, MPI, NCCL, Optimisers, Random, Test + +# ARGS is (device mode, backend name) as passed by runtests.jl; the script path is not part of ARGS +const input_args = length(ARGS) == 2 ? (ARGS[1], ARGS[2]) : +
("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? LuxCPUDevice() : + (input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice()) + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +opt = Adam(0.001f0) +ps = (a=zeros(4), b=zeros(4)) |> dev +st_opt = Optimisers.setup(opt, ps) + +dopt = DistributedUtils.DistributedOptimizer(backend, opt) +st_dopt = Optimisers.setup(dopt, ps) + +@test st_dopt.a.state == st_opt.a.state +@test st_dopt.b.state == st_opt.b.state + +@test_nowarn DistributedUtils.synchronize!!(backend, st_dopt) + +gs = (a=ones(4), b=ones(4)) |> dev + +_, ps_dopt = Optimisers.update(st_dopt, ps, gs) +_, ps_opt = Optimisers.update(st_opt, ps, gs) + +@test ps_dopt.a≈ps_opt.a atol=1.0e-5 rtol=1.0e-5 +@test ps_dopt.b≈ps_opt.b atol=1.0e-5 rtol=1.0e-5 diff --git a/test/distributed/synchronize_distributedtest.jl b/test/distributed/synchronize_distributedtest.jl new file mode 100644 index 0000000000..6e2618ff46 --- /dev/null +++ b/test/distributed/synchronize_distributedtest.jl @@ -0,0 +1,95 @@ +using ComponentArrays, Lux, MPI, NCCL, Optimisers, Random, Test + +# ARGS is (device mode, backend name) as passed by runtests.jl; the script path is not part of ARGS +const input_args = length(ARGS) == 2 ? (ARGS[1], ARGS[2]) : + ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? LuxCPUDevice() : + (input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice()) + +function __get_array_based_on_rank(backend, dims; root) + DistributedUtils.local_rank(backend) == root && return ones(dims...) + return zeros(dims...)
+end + +root = 0 + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +# Named Tuple +gs = ( + a=(b=__get_array_based_on_rank(backend, (2, 3); root), + c=__get_array_based_on_rank(backend, (2, 3); root)), + d=__get_array_based_on_rank(backend, (2, 3); root)) |> dev + +gs_ = DistributedUtils.synchronize!!(backend, gs; root) + +@test all(gs_.a.b .== 1) +@test all(gs_.a.c .== 1) +@test all(gs_.d .== 1) + +## optimisers +opt = Adam(0.001f0) +st_opt = Optimisers.setup(opt, gs) + +if DistributedUtils.local_rank(backend) == root + st_opt.a.b.state[1] .= 1 + st_opt.a.b.state[2] .= 1 + st_opt.a.c.state[1] .= 1 + st_opt.a.c.state[2] .= 1 + st_opt.d.state[1] .= 1 + st_opt.d.state[2] .= 1 +end + +st_opt = DistributedUtils.synchronize!!(backend, st_opt; root) + +@test all(st_opt.a.b.state[1] .== 1) +@test all(st_opt.a.b.state[2] .== 1) +@test all(st_opt.a.c.state[1] .== 1) +@test all(st_opt.a.c.state[2] .== 1) +@test all(st_opt.d.state[1] .== 1) +@test all(st_opt.d.state[2] .== 1) + +# Has no state +opt = Descent(0.001f0) +st_opt = Optimisers.setup(opt, gs) + +@test_nowarn DistributedUtils.synchronize!!(backend, st_opt; root) + +## ComponentArrays +gs = ( + a=(b=__get_array_based_on_rank(backend, (2, 3); root), + c=__get_array_based_on_rank(backend, (2, 3); root)), + d=__get_array_based_on_rank(backend, (2, 3); root)) +cgs = ComponentArray(gs) |> dev +cgs_ = DistributedUtils.synchronize!!(backend, cgs; root) + +@test all(cgs_.a.b .== 1) +@test all(cgs_.a.c .== 1) +@test all(cgs_.d .== 1) + +# Tuple +gs = ( + (__get_array_based_on_rank(backend, (2, 3); root), + __get_array_based_on_rank(backend, (2, 3); root)), + __get_array_based_on_rank(backend, (2, 3); root)) |> dev + +gs = DistributedUtils.synchronize!!(backend, gs; root) + +@test all(gs[1][1] .== 1) +@test all(gs[1][2] .== 1) +@test all(gs[2] .== 1) + +# Miscellaneous +x = nothing +x = DistributedUtils.synchronize!!(backend, x; root) +@test x === nothing + +x =
ifelse(root == DistributedUtils.local_rank(backend), :x, :y) +x_ = DistributedUtils.synchronize!!(backend, x; root) +# Symbol should not change +@test x_ == x + +x = DistributedUtils.synchronize!!(backend, DistributedUtils.local_rank(backend); root) +@test x == root diff --git a/test/runtests.jl b/test/runtests.jl index 8ba7978a23..79a7e4fda9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,40 @@ using ReTestItems ReTestItems.runtests(@__DIR__) + +# Distributed Tests +using MPI, Pkg, Test + +nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "") +nprocs = nprocs_str == "" ? clamp(Sys.CPU_THREADS, 2, 4) : parse(Int, nprocs_str) +testdir = @__DIR__ +isdistributedtest(f) = endswith(f, "_distributedtest.jl") +distributedtestfiles = String[] +for (root, dirs, files) in walkdir(testdir) + for file in files + if isdistributedtest(file) + push!(distributedtestfiles, joinpath(root, file)) + end + end +end + +@info "Running Distributed Tests with $nprocs processes" + +cur_proj = dirname(Pkg.project().path) + +include("setup_modes.jl") + +@testset "MODE: $(mode)" for (mode, aType, dev, ongpu) in MODES + backends = mode == "CUDA" ? 
("mpi", "nccl") : ("mpi",) + for backend_type in backends + @testset "Backend: $(backend_type)" begin + @testset "$(basename(file))" for file in distributedtestfiles + @info "Running $file with $backend_type backend on $mode device" + run(`$(MPI.mpiexec()) -n $(nprocs) $(Base.julia_cmd()) --color=yes \ + --code-coverage=user --project=$(cur_proj) --startup-file=no $(file) \ + $(mode) $(backend_type)`) + Test.@test true + end + end + end +end diff --git a/test/setup_modes.jl b/test/setup_modes.jl new file mode 100644 index 0000000000..7a55620286 --- /dev/null +++ b/test/setup_modes.jl @@ -0,0 +1,23 @@ +using Lux, LuxCUDA, LuxAMDGPU + +CUDA.allowscalar(false) + +const GROUP = get(ENV, "GROUP", "All") + +cpu_testing() = GROUP == "All" || GROUP == "CPU" +cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() +amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() + +const MODES = begin + # Mode, Array Type, Device Function, GPU? + cpu_mode = ("CPU", Array, LuxCPUDevice(), false) + cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true) + amdgpu_mode = ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true) + + modes = [] + cpu_testing() && push!(modes, cpu_mode) + cuda_testing() && push!(modes, cuda_mode) + amdgpu_testing() && push!(modes, amdgpu_mode) + + modes +end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 90ccada606..f8cd80b1c4 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,32 +1,12 @@ @testsetup module SharedTestSetup import Reexport: @reexport -using Lux, LuxCUDA, LuxAMDGPU +using Lux @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote, Statistics -import LuxTestUtils: @jet, @test_gradients, check_approx +using LuxTestUtils: @jet, @test_gradients, check_approx -const GROUP = get(ENV, "GROUP", "All") - -CUDA.allowscalar(false) - -cpu_testing() = GROUP == "All" || GROUP == "CPU" -cuda_testing() = (GROUP == "All" || GROUP == 
"CUDA") && LuxCUDA.functional() -amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") && LuxAMDGPU.functional() - -const MODES = begin - # Mode, Array Type, Device Function, GPU? - cpu_mode = ("CPU", Array, LuxCPUDevice(), false) - cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true) - amdgpu_mode = ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true) - - modes = [] - cpu_testing() && push!(modes, cpu_mode) - cuda_testing() && push!(modes, cuda_mode) - amdgpu_testing() && push!(modes, amdgpu_mode) - - modes -end +include("setup_modes.jl") # Some Helper Functions function get_default_rng(mode::String)