diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 0cd4d4341487..b4474f6746bc 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -85,5 +85,6 @@ List of Contributors
 * [Andrew Thia](https://github.com/BlueTea88)
   - Andrew Thia implemented feature interaction constraints
 * [Wei Tian](https://github.com/weitian)
-* [Chen Qin] (https://github.com/chenqin)
+* [Chen Qin](https://github.com/chenqin)
 * [Sam Wilkinson](https://samwilkinson.io)
+* [Matthew Jones](https://github.com/mt-jones)
diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst
index 94fadc469173..071d66c23e48 100644
--- a/doc/gpu/index.rst
+++ b/doc/gpu/index.rst
@@ -208,6 +208,7 @@ Many thanks to the following contributors (alphabetical order):
 * Andrey Adinets
 * Jiaming Yuan
 * Jonathan C. McKinney
+* Matthew Jones
 * Philip Cho
 * Rory Mitchell
 * Shankara Rao Thejaswi Nanditale
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 1d6355c54f60..4e28daccb214 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -23,6 +23,7 @@
 
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
+#include "../common/io.h"
 #endif
 
 // Uncomment to enable
@@ -853,6 +854,8 @@ class AllReducer {
   std::vector<ncclComm_t> comms;
   std::vector<cudaStream_t> streams;
   std::vector<int> device_ordinals;  // device id from CUDA
+  std::vector<int> device_counts;  // device count from CUDA
+  ncclUniqueId id;
 #endif
 
  public:
@@ -872,14 +875,41 @@
 #ifdef XGBOOST_USE_NCCL
     /** \brief this >monitor . init. */
     this->device_ordinals = device_ordinals;
-    comms.resize(device_ordinals.size());
-    dh::safe_nccl(ncclCommInitAll(comms.data(),
-                                  static_cast<int>(device_ordinals.size()),
-                                  device_ordinals.data()));
-    streams.resize(device_ordinals.size());
+    this->device_counts.resize(rabit::GetWorldSize());
+    this->comms.resize(device_ordinals.size());
+    this->streams.resize(device_ordinals.size());
+    this->id = GetUniqueId();
+
+    device_counts.at(rabit::GetRank()) = device_ordinals.size();
+    for (size_t i = 0; i < device_counts.size(); i++) {
+      int dev_count = device_counts.at(i);
+      rabit::Allreduce<rabit::op::Sum>(&dev_count, 1);
+      device_counts.at(i) = dev_count;
+    }
+
+    int nccl_rank = 0;
+    int nccl_rank_offset = std::accumulate(device_counts.begin(),
+      device_counts.begin() + rabit::GetRank(), 0);
+    int nccl_nranks = std::accumulate(device_counts.begin(),
+      device_counts.end(), 0);
+    nccl_rank += nccl_rank_offset;
+
+    GroupStart();
     for (size_t i = 0; i < device_ordinals.size(); i++) {
-      safe_cuda(cudaSetDevice(device_ordinals[i]));
-      safe_cuda(cudaStreamCreate(&streams[i]));
+      int dev = device_ordinals.at(i);
+      dh::safe_cuda(cudaSetDevice(dev));
+      dh::safe_nccl(ncclCommInitRank(
+        &comms.at(i),
+        nccl_nranks, id,
+        nccl_rank));
+
+      nccl_rank++;
+    }
+    GroupEnd();
+
+    for (size_t i = 0; i < device_ordinals.size(); i++) {
+      safe_cuda(cudaSetDevice(device_ordinals.at(i)));
+      safe_cuda(cudaStreamCreate(&streams.at(i)));
     }
     initialised_ = true;
 #else
@@ -1010,7 +1040,30 @@ class AllReducer {
       dh::safe_cuda(cudaStreamSynchronize(streams[i]));
     }
 #endif
+  };
+
+#ifdef XGBOOST_USE_NCCL
+  /**
+   * \fn  ncclUniqueId GetUniqueId()
+   *
+   * \brief Gets the Unique ID from NCCL to be used in setting up interprocess
+   * communication
+   *
+   * \return the Unique ID
+   */
+  ncclUniqueId GetUniqueId() {
+    static const int RootRank = 0;
+    ncclUniqueId id;
+    if (rabit::GetRank() == RootRank) {
+      dh::safe_nccl(ncclGetUniqueId(&id));
+    }
+    rabit::Broadcast(
+      (void*)&id,
+      (size_t)sizeof(ncclUniqueId),
+      (int)RootRank);
+    return id;
   }
+#endif
 };
 
 class SaveCudaContext {
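The AllReducer change above replaces the single-process ncclCommInitAll with a per-device ncclCommInitRank: each worker reports how many GPUs it drives, the counts are exchanged through rabit, and every local device then joins one global NCCL clique under a unique ID broadcast from rank 0. The standalone Python sketch below mirrors that rank arithmetic; it is illustrative only (no NCCL or rabit calls, helper name invented) and is not part of the patch.

def nccl_rank_layout(device_counts):
    """Map (worker, local_device) pairs to global NCCL ranks.

    device_counts[r] is the number of GPUs worker r drives, i.e. the value
    each worker contributes through the rabit all-reduce in Init() above.
    """
    nccl_nranks = sum(device_counts)          # size of the global NCCL clique
    layout = {}
    for worker, count in enumerate(device_counts):
        offset = sum(device_counts[:worker])  # ranks owned by earlier workers
        for local_dev in range(count):
            layout[(worker, local_dev)] = offset + local_dev
    return nccl_nranks, layout

# Example: the asymmetric test below (worker 0: 1 GPU, worker 1: 3 GPUs).
nranks, layout = nccl_rank_layout([1, 3])
assert nranks == 4
assert layout == {(0, 0): 0, (1, 0): 1, (1, 1): 2, (1, 2): 3}

Worker r therefore owns the contiguous block of ranks starting at the prefix sum of the earlier workers' GPU counts, which is what the std::accumulate over device_counts computes before the GroupStart()/GroupEnd() pair.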
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 13a24cdfa270..408e89e6189b 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -628,10 +628,12 @@ struct DeviceShard {
     dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
                              split_candidates.size() * sizeof(DeviceSplitCandidate),
                              cudaMemcpyDeviceToHost));
+
     DeviceSplitCandidate best_split;
     for (auto candidate : split_candidates) {
       best_split.Update(candidate, param);
     }
+
     return best_split;
   }
 
@@ -1049,7 +1051,8 @@ class GPUHistMakerSpecialised{
   }
 
   void AllReduceHist(int nidx) {
-    if (shards_.size() == 1) return;
+    if (shards_.size() == 1 && !rabit::IsDistributed())
+      return;
 
     monitor_.Start("AllReduce");
     reducer_.GroupStart();
@@ -1080,6 +1083,9 @@ class GPUHistMakerSpecialised{
           right_node_max_elements, shard->ridx_segments[nidx_right].Size());
     }
 
+    rabit::Allreduce<rabit::op::Max>(&left_node_max_elements, 1);
+    rabit::Allreduce<rabit::op::Max>(&right_node_max_elements, 1);
+
     auto build_hist_nidx = nidx_left;
     auto subtraction_trick_nidx = nidx_right;
 
@@ -1142,9 +1148,12 @@ class GPUHistMakerSpecialised{
       tmp_sums[i] = dh::SumReduction(
          shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
     });
+
     GradientPair sum_gradient =
         std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
+    rabit::Allreduce<rabit::op::Sum>((GradientPair::ValueT*)&sum_gradient, 2);
+
     // Generate root histogram
     dh::ExecuteIndexShards(
         &shards_,
diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu
index 1376b5fe8d3b..15a72c09a598 100644
--- a/tests/ci_build/Dockerfile.gpu
+++ b/tests/ci_build/Dockerfile.gpu
@@ -35,7 +35,7 @@ ENV CPP=/opt/rh/devtoolset-2/root/usr/bin/cpp
 
 # Install Python packages
 RUN \
-    pip install numpy pytest scipy scikit-learn wheel
+    pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22
 
 ENV GOSU_VERSION 1.10
diff --git a/tests/ci_build/test_mgpu.sh b/tests/ci_build/test_mgpu.sh
index 5eef3e7081e1..2dfafcc2ee54 100755
--- a/tests/ci_build/test_mgpu.sh
+++ b/tests/ci_build/test_mgpu.sh
@@ -6,3 +6,6 @@ python setup.py install --user
 cd ..
 pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
 ./testxgboost --gtest_filter=*.MGPU_*
+
+cd tests/distributed-gpu
+./runtests-gpu.sh
\ No newline at end of file
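In updater_gpu_hist.cu the root statistics are now reduced across workers as well: each worker sums the (grad, hess) pairs of its own data shard on the GPU, and an all-reduce over those two floats gives every worker the global sum used to initialise the root node. A minimal standalone Python sketch of that semantics follows (numbers invented, a plain sum standing in for the rabit call); it is not part of the patch.

# Each worker sees only its shard; the global root statistics are the
# element-wise sum of the per-worker partial sums, which is what the
# two-float all-reduce over (grad, hess) computes.
local_sums = {
    0: (1.5, 4.0),   # worker 0: (sum of grad, sum of hess) over its shard
    1: (-0.5, 3.0),  # worker 1
}

def allreduce_sum(partials):
    """Stand-in for a sum all-reduce: every worker ends up with the total."""
    return tuple(sum(p[i] for p in partials) for i in range(2))

global_sum = allreduce_sum(local_sums.values())
assert global_sum == (1.0, 7.0)

The same pattern is applied to left_node_max_elements and right_node_max_elements, but with a max reduction, so that all workers agree on which child histogram to build directly and which to derive via the subtraction trick.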
diff --git a/tests/distributed-gpu/runtests-gpu.sh b/tests/distributed-gpu/runtests-gpu.sh
new file mode 100755
index 000000000000..e3fa8a0d3ec7
--- /dev/null
+++ b/tests/distributed-gpu/runtests-gpu.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+rm -f *.model*
+
+echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
+PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=4 \
+    python test_gpu_basic_1x4.py
+
+echo -e "\n ====== 2. Basic distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
+PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
+    python test_gpu_basic_2x2.py
+
+echo -e "\n ====== 3. Basic distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
+PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
+    python test_gpu_basic_asym.py
+
+echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
+PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=1 \
+    python test_gpu_basic_4x1.py
\ No newline at end of file
diff --git a/tests/distributed-gpu/test_gpu_basic_1x4.py b/tests/distributed-gpu/test_gpu_basic_1x4.py
new file mode 100644
index 000000000000..d325a167a4c8
--- /dev/null
+++ b/tests/distributed-gpu/test_gpu_basic_1x4.py
@@ -0,0 +1,51 @@
+#!/usr/bin/python
+import xgboost as xgb
+import time
+from collections import OrderedDict
+
+# Always call this before using the distributed module
+xgb.rabit.init()
+rank = xgb.rabit.get_rank()
+world = xgb.rabit.get_world_size()
+
+# Load the files; data will be automatically sharded in distributed mode.
+dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+# Specify parameters via map; definitions are the same as for the C++ version
+param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
+
+# Specify validation sets to watch performance
+watchlist = [(dtest,'eval'), (dtrain,'train')]
+num_round = 20
+
+# Run training; all the features of the training API are available.
+# Currently, this script only supports calling train once, for fault-recovery purposes.
+bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
+
+# Have each worker save its model
+model_name = "test.model.1x4." + str(rank)
+bst.dump_model(model_name, with_stats=True); time.sleep(2)
+xgb.rabit.tracker_print("Finished training\n")
+
+fail = False
+if (rank == 0):
+    for i in range(0, world):
+        model_name_root = "test.model.1x4." + str(i)
+        for j in range(0, world):
+            if i != j:
+                with open(model_name_root, 'r') as model_root:
+                    model_name_rank = "test.model.1x4." + str(j)
+                    with open(model_name_rank, 'r') as model_rank:
+                        diff = set(model_root).difference(model_rank)
+                        if len(diff) != 0:
+                            fail = True
+                            xgb.rabit.finalize()
+                            raise Exception('Worker models diverged: test.model.1x4.{} differs from test.model.1x4.{}'.format(i, j))
+
+if (rank != 0) and (fail):
+    xgb.rabit.finalize()
+
+# Notify the tracker that all training has been successful
+# This is only needed in distributed training.
+xgb.rabit.finalize()
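test_gpu_basic_1x4.py above (and the 2x2 and asymmetric variants that follow) verifies the distributed run by having every worker dump its model and letting rank 0 compare the dumps pairwise. The same check can be factored into a small helper that compares whole file contents; the sketch below is an illustrative alternative, not part of the patch, and only reuses the file names the scripts produce.

import itertools

def models_identical(paths):
    """True if all dumped model files have identical contents."""
    contents = []
    for path in paths:
        with open(path, 'r') as fh:
            contents.append(fh.read())
    return all(a == b for a, b in itertools.combinations(contents, 2))

# After the 1x4 test has run on 4 workers:
# assert models_identical(["test.model.1x4.%d" % r for r in range(4)])

Comparing full contents is slightly stricter than the set difference used in the scripts, since set() ignores line order and duplicate lines.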
diff --git a/tests/distributed-gpu/test_gpu_basic_2x2.py b/tests/distributed-gpu/test_gpu_basic_2x2.py
new file mode 100644
index 000000000000..b1669560ce31
--- /dev/null
+++ b/tests/distributed-gpu/test_gpu_basic_2x2.py
@@ -0,0 +1,51 @@
+#!/usr/bin/python
+import xgboost as xgb
+import time
+from collections import OrderedDict
+
+# Always call this before using the distributed module
+xgb.rabit.init()
+rank = xgb.rabit.get_rank()
+world = xgb.rabit.get_world_size()
+
+# Load the files; data will be automatically sharded in distributed mode.
+dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+# Specify parameters via map; definitions are the same as for the C++ version
+param = {'n_gpus': 2, 'gpu_id': 2*rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
+
+# Specify validation sets to watch performance
+watchlist = [(dtest,'eval'), (dtrain,'train')]
+num_round = 20
+
+# Run training; all the features of the training API are available.
+# Currently, this script only supports calling train once, for fault-recovery purposes.
+bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
+
+# Have each worker save its model
+model_name = "test.model.2x2." + str(rank)
+bst.dump_model(model_name, with_stats=True); time.sleep(2)
+xgb.rabit.tracker_print("Finished training\n")
+
+fail = False
+if (rank == 0):
+    for i in range(0, world):
+        model_name_root = "test.model.2x2." + str(i)
+        for j in range(0, world):
+            if i != j:
+                with open(model_name_root, 'r') as model_root:
+                    model_name_rank = "test.model.2x2." + str(j)
+                    with open(model_name_rank, 'r') as model_rank:
+                        diff = set(model_root).difference(model_rank)
+                        if len(diff) != 0:
+                            fail = True
+                            xgb.rabit.finalize()
+                            raise Exception('Worker models diverged: test.model.2x2.{} differs from test.model.2x2.{}'.format(i, j))
+
+if (rank != 0) and (fail):
+    xgb.rabit.finalize()
+
+# Notify the tracker that all training has been successful
+# This is only needed in distributed training.
+xgb.rabit.finalize()
diff --git a/tests/distributed-gpu/test_gpu_basic_4x1.py b/tests/distributed-gpu/test_gpu_basic_4x1.py
new file mode 100644
index 000000000000..6662a3ac6345
--- /dev/null
+++ b/tests/distributed-gpu/test_gpu_basic_4x1.py
@@ -0,0 +1,34 @@
+#!/usr/bin/python
+import xgboost as xgb
+import time
+from collections import OrderedDict
+
+# Always call this before using the distributed module
+xgb.rabit.init()
+rank = xgb.rabit.get_rank()
+world = xgb.rabit.get_world_size()
+
+# Load the files; data will be automatically sharded in distributed mode.
+dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+# Specify parameters via map; definitions are the same as for the C++ version
+param = {'n_gpus': 4, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
+
+# Specify validation sets to watch performance
+watchlist = [(dtest,'eval'), (dtrain,'train')]
+num_round = 20
+
+# Run training; all the features of the training API are available.
+# Currently, this script only supports calling train once, for fault-recovery purposes.
+bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
+
+# Have root save its model
+if (rank == 0):
+    model_name = "test.model.4x1." + str(rank)
+    bst.dump_model(model_name, with_stats=True)
+xgb.rabit.tracker_print("Finished training\n")
+
+# Notify the tracker that all training has been successful
+# This is only needed in distributed training.
+xgb.rabit.finalize()
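These tests split the visible GPUs between workers purely through the gpu_id and n_gpus parameters; assuming xgboost's contiguous device selection, a worker uses devices gpu_id through gpu_id + n_gpus - 1, so the 2x2 test sets gpu_id = 2*rank and the asymmetric test below gives rank 0 one device and rank 1 the remaining three. A quick sketch of the implied device ranges (parameter values taken from the scripts in this patch; the helper is invented for illustration):

def device_range(gpu_id, n_gpus):
    """Devices a worker uses for the given xgboost gpu_id/n_gpus parameters."""
    return list(range(gpu_id, gpu_id + n_gpus))

# 2x2 test: two workers, two GPUs each, gpu_id = 2 * rank
assert device_range(2 * 0, 2) == [0, 1]
assert device_range(2 * 1, 2) == [2, 3]

# Asymmetric test: rank 0 takes one GPU, rank 1 takes the other three
assert device_range(0, 1) == [0]
assert device_range(1, 3) == [1, 2, 3]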
diff --git a/tests/distributed-gpu/test_gpu_basic_asym.py b/tests/distributed-gpu/test_gpu_basic_asym.py
new file mode 100644
index 000000000000..e20304ef2ac4
--- /dev/null
+++ b/tests/distributed-gpu/test_gpu_basic_asym.py
@@ -0,0 +1,54 @@
+#!/usr/bin/python
+import xgboost as xgb
+import time
+from collections import OrderedDict
+
+# Always call this before using the distributed module
+xgb.rabit.init()
+rank = xgb.rabit.get_rank()
+world = xgb.rabit.get_world_size()
+
+# Load the files; data will be automatically sharded in distributed mode.
+dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+# Specify parameters via map; definitions are the same as for the C++ version
+if rank == 0:
+    param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
+else:
+    param = {'n_gpus': 3, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
+
+# Specify validation sets to watch performance
+watchlist = [(dtest,'eval'), (dtrain,'train')]
+num_round = 20
+
+# Run training; all the features of the training API are available.
+# Currently, this script only supports calling train once, for fault-recovery purposes.
+bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
+
+# Have each worker save its model
+model_name = "test.model.asym." + str(rank)
+bst.dump_model(model_name, with_stats=True); time.sleep(2)
+xgb.rabit.tracker_print("Finished training\n")
+
+fail = False
+if (rank == 0):
+    for i in range(0, world):
+        model_name_root = "test.model.asym." + str(i)
+        for j in range(0, world):
+            if i != j:
+                with open(model_name_root, 'r') as model_root:
+                    model_name_rank = "test.model.asym." + str(j)
+                    with open(model_name_rank, 'r') as model_rank:
+                        diff = set(model_root).difference(model_rank)
+                        if len(diff) != 0:
+                            fail = True
+                            xgb.rabit.finalize()
+                            raise Exception('Worker models diverged: test.model.asym.{} differs from test.model.asym.{}'.format(i, j))
+
+if (rank != 0) and (fail):
+    xgb.rabit.finalize()
+
+# Notify the tracker that all training has been successful
+# This is only needed in distributed training.
+xgb.rabit.finalize()