[REVIEW] Enable Multi-Node Multi-GPU functionality (#4095)
* Initial commit to support multi-node multi-gpu xgboost using dask

* Fixed NCCL initialization by not ignoring the opg parameter.

- it now crashes on NCCL initialization, but at least we're attempting it properly

* At the root node, perform a rabit::Allreduce to get initial sum_gradient across workers

* Synchronizing in a couple more places.

- now the workers don't go down, but just hang
- no more "wild" values of gradients
- probably needs syncing in more places

* Added another missing max-allreduce operation inside BuildHistLeftRight

* Removed unnecessary collective operations.

* Simplified rabit::Allreduce() sync of gradient sums.

* Removed unnecessary rabit syncs around ncclAllReduce.

- this improves performance _significantly_ (7x faster for overall training,
  20x faster for xgboost proper)

* pulling in latest xgboost

* removing changes to updater_quantile_hist.cc

* changing use_nccl_opg initialization, removing unnecessary if statements

* added definition for opaque ncclUniqueId struct to properly encapsulate GetUniqueId

* placing struct definition in guard to avoid duplicate code errors

* addressing linting errors

* removing

* removing additional arguments to AllReduer initialization

* removing distributed flag

* making comm init symmetric

* removing distributed flag

* changing ncclCommInit to support multiple modalities

* fix indenting

* updating ncclCommInitRank block with necessary group calls

* fix indenting

* adding print statement, and updating accessor in vector

* improving print statement to end-line

* generalizing nccl_rank construction using rabit (see the worked example after this list)

* assume device_ordinals is the same for every node

* test, assume device_ordinals is identical for all nodes

* test, assume device_ordinals is unique for all nodes

* changing names of offset variable to be more descriptive, editing indenting

* wrapping ncclUniqueId GetUniqueId() and aesthetic changes

* adding synchronization, and tests for distributed

* adding  to tests

* fixing broken #endif

* fixing initialization of gpu histograms, correcting errors in tests

* adding to contributors list

* adding distributed tests to jenkins

* fixing bad path in distributed test

* debugging

* adding kubernetes for distributed tests

* adding proper import for OrderedDict

* adding urllib3==1.22 to address ordered_dict import error

* added sleep to allow workers to save their models for comparison

* adding name to GPU contributors under docs
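
The worked example below is not part of the commit; it is a standalone sketch, with hypothetical per-worker GPU counts, of the rank arithmetic described above: each rabit worker reports how many GPUs it drives, and every GPU is then given a globally unique NCCL rank by offsetting its local index by the total GPU count of all lower-ranked workers.

    // Standalone illustration of the NCCL rank assignment used in this commit.
    // The device counts below are hypothetical, not taken from the diff.
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      // Suppose rabit reports three workers driving 1, 3, and 2 GPUs.
      std::vector<int> device_counts = {1, 3, 2};
      int nccl_nranks =
          std::accumulate(device_counts.begin(), device_counts.end(), 0);

      for (int rabit_rank = 0; rabit_rank < static_cast<int>(device_counts.size()); ++rabit_rank) {
        // GPUs on lower-ranked workers occupy the lower NCCL ranks.
        int offset = std::accumulate(device_counts.begin(),
                                     device_counts.begin() + rabit_rank, 0);
        for (int local = 0; local < device_counts[rabit_rank]; ++local) {
          std::cout << "worker " << rabit_rank << ", local GPU " << local
                    << " -> NCCL rank " << offset + local
                    << " of " << nccl_nranks << "\n";
        }
      }
      return 0;
    }

With these hypothetical counts, worker 0 gets NCCL rank 0, worker 1 gets ranks 1-3, and worker 2 gets ranks 4-5, out of 6 ranks in total.
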
Matthew Jones authored and RAMitchell committed Mar 1, 2019
1 parent 9fefa21 commit 92b7577
Showing 11 changed files with 286 additions and 10 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
@@ -85,5 +85,6 @@ List of Contributors
* [Andrew Thia](https://github.com/BlueTea88)
- Andrew Thia implemented feature interaction constraints
* [Wei Tian](https://github.com/weitian)
* [Chen Qin] (https://github.com/chenqin)
* [Chen Qin](https://github.com/chenqin)
* [Sam Wilkinson](https://samwilkinson.io)
* [Matthew Jones](https://github.com/mt-jones)
1 change: 1 addition & 0 deletions doc/gpu/index.rst
@@ -208,6 +208,7 @@ Many thanks to the following contributors (alphabetical order):
* Andrey Adinets
* Jiaming Yuan
* Jonathan C. McKinney
* Matthew Jones
* Philip Cho
* Rory Mitchell
* Shankara Rao Thejaswi Nanditale
67 changes: 60 additions & 7 deletions src/common/device_helpers.cuh
@@ -23,6 +23,7 @@

#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#include "../common/io.h"
#endif

// Uncomment to enable
@@ -853,6 +854,8 @@ class AllReducer {
std::vector<ncclComm_t> comms;
std::vector<cudaStream_t> streams;
std::vector<int> device_ordinals; // device id from CUDA
std::vector<int> device_counts; // device count from CUDA
ncclUniqueId id;
#endif

public:
@@ -872,14 +875,41 @@
#ifdef XGBOOST_USE_NCCL
/** \brief Initialise with the desired device ordinals for this communication group. */
this->device_ordinals = device_ordinals;
comms.resize(device_ordinals.size());
dh::safe_nccl(ncclCommInitAll(comms.data(),
static_cast<int>(device_ordinals.size()),
device_ordinals.data()));
streams.resize(device_ordinals.size());
this->device_counts.resize(rabit::GetWorldSize());
this->comms.resize(device_ordinals.size());
this->streams.resize(device_ordinals.size());
this->id = GetUniqueId();

device_counts.at(rabit::GetRank()) = device_ordinals.size();
for (size_t i = 0; i < device_counts.size(); i++) {
int dev_count = device_counts.at(i);
rabit::Allreduce<rabit::op::Sum, int>(&dev_count, 1);
device_counts.at(i) = dev_count;
}

int nccl_rank = 0;
int nccl_rank_offset = std::accumulate(device_counts.begin(),
device_counts.begin() + rabit::GetRank(), 0);
int nccl_nranks = std::accumulate(device_counts.begin(),
device_counts.end(), 0);
nccl_rank += nccl_rank_offset;

GroupStart();
for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals[i]));
safe_cuda(cudaStreamCreate(&streams[i]));
int dev = device_ordinals.at(i);
dh::safe_cuda(cudaSetDevice(dev));
dh::safe_nccl(ncclCommInitRank(
&comms.at(i),
nccl_nranks, id,
nccl_rank));

nccl_rank++;
}
GroupEnd();

for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
safe_cuda(cudaStreamCreate(&streams.at(i)));
}
initialised_ = true;
#else
@@ -1010,7 +1040,30 @@ class AllReducer {
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
}
#endif
};

#ifdef XGBOOST_USE_NCCL
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int RootRank = 0;
ncclUniqueId id;
if (rabit::GetRank() == RootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
rabit::Broadcast(
(void*)&id,
(size_t)sizeof(ncclUniqueId),
(int)RootRank);
return id;
}
#endif
};

class SaveCudaContext {
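
For readers skimming the diff above, this is the initialisation pattern the new AllReducer::Init code implements, reduced to a minimal sketch with one GPU per rabit worker (so no ncclGroupStart/ncclGroupEnd is needed). It is not the commit's code: error handling (dh::safe_cuda / dh::safe_nccl) is omitted, and it assumes rabit has already been initialised and NCCL is available.

    #include <cuda_runtime.h>
    #include <nccl.h>
    #include <rabit/rabit.h>

    // Minimal sketch: build one NCCL communicator spanning all rabit workers,
    // assuming each worker drives exactly one GPU.
    ncclComm_t InitNcclAcrossWorkers(int device_ordinal) {
      // 1. Rank 0 asks NCCL for a unique id; rabit broadcasts it to everyone,
      //    mirroring GetUniqueId() in the diff.
      ncclUniqueId id;
      if (rabit::GetRank() == 0) {
        ncclGetUniqueId(&id);
      }
      rabit::Broadcast(&id, sizeof(ncclUniqueId), 0);

      // 2. With one GPU per worker the NCCL rank is simply the rabit rank and
      //    the communicator size is the rabit world size.
      int nccl_rank = rabit::GetRank();
      int nccl_nranks = rabit::GetWorldSize();

      // 3. Every process joins the communicator identified by `id`.
      ncclComm_t comm;
      cudaSetDevice(device_ordinal);
      ncclCommInitRank(&comm, nccl_nranks, id, nccl_rank);
      return comm;
    }

When a worker drives several GPUs, as in the diff, the per-device ncclCommInitRank calls are wrapped in GroupStart()/GroupEnd() and each rank is offset by the GPU counts of lower-ranked workers.
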
11 changes: 10 additions & 1 deletion src/tree/updater_gpu_hist.cu
@@ -628,10 +628,12 @@ struct DeviceShard {
dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
split_candidates.size() * sizeof(DeviceSplitCandidate),
cudaMemcpyDeviceToHost));

DeviceSplitCandidate best_split;
for (auto candidate : split_candidates) {
best_split.Update(candidate, param);
}

return best_split;
}

@@ -1049,7 +1051,8 @@ class GPUHistMakerSpecialised{
}

void AllReduceHist(int nidx) {
if (shards_.size() == 1) return;
if (shards_.size() == 1 && !rabit::IsDistributed())
return;
monitor_.Start("AllReduce");

reducer_.GroupStart();
@@ -1080,6 +1083,9 @@
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
}

rabit::Allreduce<rabit::op::Max, size_t>(&left_node_max_elements, 1);
rabit::Allreduce<rabit::op::Max, size_t>(&right_node_max_elements, 1);

auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;

@@ -1142,9 +1148,12 @@
tmp_sums[i] = dh::SumReduction(
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
});

GradientPair sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());

rabit::Allreduce<rabit::op::Sum>((GradientPair::ValueT*)&sum_gradient, 2);

// Generate root histogram
dh::ExecuteIndexShards(
&shards_,
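
The two rabit::Allreduce calls added above keep every worker's view consistent: the root gradient sum is reduced across workers before the root histogram is built, and the max-element counts are reduced with op::Max so all workers pick the same child for the subtraction trick. Below is a minimal sketch of the gradient-sum reduction, using a hypothetical stand-in for xgboost's GradientPair; it is not the commit's code, and rabit::Init is assumed to have been called already.

    #include <rabit/rabit.h>

    // Hypothetical stand-in for xgboost's GradientPair: two contiguous floats.
    struct GradientPairSketch {
      float grad;
      float hess;
    };

    // Sum a local (grad, hess) pair across all rabit workers in place.
    void SumGradientAcrossWorkers(GradientPairSketch *sum_gradient) {
      // The pair is reduced as an array of two floats, which is why the count
      // passed to Allreduce is 2; afterwards every worker holds the global sums.
      rabit::Allreduce<rabit::op::Sum>(
          reinterpret_cast<float *>(sum_gradient), 2);
    }
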
2 changes: 1 addition & 1 deletion tests/ci_build/Dockerfile.gpu
@@ -35,7 +35,7 @@ ENV CPP=/opt/rh/devtoolset-2/root/usr/bin/cpp

# Install Python packages
RUN \
pip install numpy pytest scipy scikit-learn wheel
pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22

ENV GOSU_VERSION 1.10

3 changes: 3 additions & 0 deletions tests/ci_build/test_mgpu.sh
@@ -6,3 +6,6 @@ python setup.py install --user
cd ..
pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
./testxgboost --gtest_filter=*.MGPU_*

cd tests/distributed-gpu
./runtests-gpu.sh
19 changes: 19 additions & 0 deletions tests/distributed-gpu/runtests-gpu.sh
@@ -0,0 +1,19 @@
#!/bin/bash

rm -f *.model*

echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=4 \
python test_gpu_basic_1x4.py

echo -e "\n ====== 2. Basic distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
python test_gpu_basic_2x2.py

echo -e "\n ====== 3. Basic distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
python test_gpu_basic_asym.py

echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=1 \
python test_gpu_basic_4x1.py
51 changes: 51 additions & 0 deletions tests/distributed-gpu/test_gpu_basic_1x4.py
@@ -0,0 +1,51 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map; definitions are the same as in the C++ version
param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify a validation set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training; all the features in the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.1x4." + str(rank)
bst.dump_model(model_name, with_stats=True); time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

fail = False
if (rank == 0):
    for i in range(0, world):
        model_name_root = "test.model.1x4." + str(i)
        for j in range(0, world):
            if i != j:
                with open(model_name_root, 'r') as model_root:
                    model_name_rank = "test.model.1x4." + str(j)
                    with open(model_name_rank, 'r') as model_rank:
                        diff = set(model_root).difference(model_rank)
                        if len(diff) != 0:
                            fail = True
                            xgb.rabit.finalize()
                            raise Exception('Worker models diverged: test.model.1x4.{} differs from test.model.1x4.{}'.format(i, j))

if (rank != 0) and (fail):
    xgb.rabit.finalize()

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
51 changes: 51 additions & 0 deletions tests/distributed-gpu/test_gpu_basic_2x2.py
@@ -0,0 +1,51 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map; definitions are the same as in the C++ version
param = {'n_gpus': 2, 'gpu_id': 2*rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify a validation set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training; all the features in the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.2x2." + str(rank)
bst.dump_model(model_name, with_stats=True); time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

fail = False
if (rank == 0):
    for i in range(0, world):
        model_name_root = "test.model.2x2." + str(i)
        for j in range(0, world):
            if i != j:
                with open(model_name_root, 'r') as model_root:
                    model_name_rank = "test.model.2x2." + str(j)
                    with open(model_name_rank, 'r') as model_rank:
                        diff = set(model_root).difference(model_rank)
                        if len(diff) != 0:
                            fail = True
                            xgb.rabit.finalize()
                            raise Exception('Worker models diverged: test.model.2x2.{} differs from test.model.2x2.{}'.format(i, j))

if (rank != 0) and (fail):
    xgb.rabit.finalize()

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
34 changes: 34 additions & 0 deletions tests/distributed-gpu/test_gpu_basic_4x1.py
@@ -0,0 +1,34 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map; definitions are the same as in the C++ version
param = {'n_gpus': 4, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify a validation set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training; all the features in the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have root save its model
if(rank == 0):
    model_name = "test.model.4x1." + str(rank)
    bst.dump_model(model_name, with_stats=True)
xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
54 changes: 54 additions & 0 deletions tests/distributed-gpu/test_gpu_basic_asym.py
@@ -0,0 +1,54 @@
#!/usr/bin/python
import xgboost as xgb
import time
from collections import OrderedDict

# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via map; definitions are the same as in the C++ version
if rank == 0:
    param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
else:
    param = {'n_gpus': 3, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

# Specify a validation set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20

# Run training; all the features in the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.asym." + str(rank)
bst.dump_model(model_name, with_stats=True); time.sleep(2)
xgb.rabit.tracker_print("Finished training\n")

fail = False
if (rank == 0):
    for i in range(0, world):
        model_name_root = "test.model.asym." + str(i)
        for j in range(0, world):
            if i != j:
                with open(model_name_root, 'r') as model_root:
                    model_name_rank = "test.model.asym." + str(j)
                    with open(model_name_rank, 'r') as model_rank:
                        diff = set(model_root).difference(model_rank)
                        if len(diff) != 0:
                            fail = True
                            xgb.rabit.finalize()
                            raise Exception('Worker models diverged: test.model.asym.{} differs from test.model.asym.{}'.format(i, j))

if (rank != 0) and (fail):
    xgb.rabit.finalize()

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
