diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 04130aacb26c..358e9ecddcc9 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -58,6 +58,19 @@ struct HostDeviceVectorImpl {
       perm_d_ = vec_->perm_h_.Complementary();
     }
 
+    void Init(HostDeviceVectorImpl* vec, const DeviceShard& other) {
+      if (vec_ == nullptr) { vec_ = vec; }
+      CHECK_EQ(vec, vec_);
+      device_ = other.device_;
+      index_ = other.index_;
+      cached_size_ = other.cached_size_;
+      start_ = other.start_;
+      proper_size_ = other.proper_size_;
+      SetDevice();
+      data_.resize(other.data_.size());
+      perm_d_ = other.perm_d_;
+    }
+
     void ScatterFrom(const T* begin) {
       // TODO(canonizer): avoid full copy of host data
       LazySyncDevice(GPUAccess::kWrite);
@@ -166,7 +179,12 @@ struct HostDeviceVectorImpl {
   // required, as a new std::mutex has to be created
   HostDeviceVectorImpl(const HostDeviceVectorImpl& other)
     : data_h_(other.data_h_), perm_h_(other.perm_h_), size_d_(other.size_d_),
-      distribution_(other.distribution_), mutex_(), shards_(other.shards_) {}
+      distribution_(other.distribution_), mutex_() {
+    shards_.resize(other.shards_.size());
+    dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
+      shard.Init(this, other.shards_[i]);
+    });
+  }
 
   // Init can be std::vector<T> or std::initializer_list<T>
   template <class Init>
diff --git a/tests/cpp/common/test_host_device_vector.cu b/tests/cpp/common/test_host_device_vector.cu
index e471e785425c..e76744bb3fa9 100644
--- a/tests/cpp/common/test_host_device_vector.cu
+++ b/tests/cpp/common/test_host_device_vector.cu
@@ -155,6 +155,29 @@ TEST(HostDeviceVector, TestExplicit) {
   TestHostDeviceVector(n, distribution, starts, sizes);
 }
 
+TEST(HostDeviceVector, TestCopy) {
+  size_t n = 1001;
+  int n_devices = 2;
+  auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices));
+  std::vector<size_t> starts{0, 501};
+  std::vector<size_t> sizes{501, 500};
+  SetCudaSetDeviceHandler(SetDevice);
+
+  HostDeviceVector<int> v;
+  {
+    // a separate scope to ensure that v1 is gone before further checks
+    HostDeviceVector<int> v1;
+    InitHostDeviceVector(n, distribution, &v1);
+    v = v1;
+  }
+  CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
+  PlusOne(&v);
+  CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
+  CheckHost(&v, GPUAccess::kRead);
+  CheckHost(&v, GPUAccess::kWrite);
+  SetCudaSetDeviceHandler(nullptr);
+}
+
 TEST(HostDeviceVector, Span) {
   HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
   vec.Reshard(GPUSet{0, 1});
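
Below is a minimal standalone sketch of the ownership issue this patch fixes. It is not the XGBoost API: the names Owner, Shard, Shard::Init, and the plain for-loop (standing in for dh::ExecuteIndexShards) are hypothetical. The point it illustrates is why the copy constructor can no longer copy shards_ wholesale: each DeviceShard keeps a back-pointer to the HostDeviceVectorImpl that owns it, so a verbatim copy would leave the new object's shards pointing at the source; the copy must instead create fresh shards and rebind each one to the new owner.

#include <cassert>
#include <cstddef>
#include <vector>

struct Owner;

// Per-device bookkeeping with a back-pointer to the vector that owns it,
// loosely analogous to DeviceShard holding vec_ -> HostDeviceVectorImpl.
struct Shard {
  Owner* owner;
  int device;
  std::size_t start;
  std::size_t size;

  // Rebind this shard to a new owner while copying the bookkeeping fields
  // from `other`; this mirrors the role of DeviceShard::Init(vec, other).
  void Init(Owner* new_owner, const Shard& other) {
    owner = new_owner;
    device = other.device;
    start = other.start;
    size = other.size;
  }
};

struct Owner {
  std::vector<Shard> shards;

  Owner() = default;

  // Copying `other.shards` directly (the old `shards_(other.shards_)`) would
  // leave every copied shard pointing back at `other`; instead, size the
  // container and rebind each shard to `this`.
  Owner(const Owner& other) {
    shards.resize(other.shards.size());
    for (std::size_t i = 0; i < shards.size(); ++i) {
      shards[i].Init(this, other.shards[i]);
    }
  }
};

int main() {
  Owner a;
  a.shards.push_back({&a, 0, 0, 501});
  a.shards.push_back({&a, 1, 501, 500});

  Owner b = a;  // deep copy: b's shards must point at b, not at a
  assert(b.shards.size() == 2);
  assert(b.shards[0].owner == &b && b.shards[1].owner == &b);
  return 0;
}

The added TestCopy test exercises the same property at the library level: the source vector v1 is destroyed before the copy is checked, so any shard still referring to the old owner would be caught by the subsequent device and host accesses.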