Skip to content

Commit

Permalink
add recordStream and fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ustclight-sls committed Sep 19, 2024
1 parent 45206b6 commit d34260e
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 53 deletions.
22 changes: 22 additions & 0 deletions dipu/tests/python/individual_scripts/test_rt_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,28 @@ def test_allgather(rank, world_size, port, seed):

for i in range(len(gathered_tensors)):
assert torch.allclose(gathered_tensors[i], expected_tensors[i])

device = torch.device(f"cuda:{rank}")
t = torch.rand((rank + 1, rank + 1), dtype=torch.float16, device=device)
shapes = []
for i in range(world_size):
shapes.append((i + 2, i + 2))
gathered_t = [
torch.empty(shapes[i], dtype=torch.float16, device=device)
for i in range(world_size)
]

try:
dist.all_gather(gathered_t, t)
except RuntimeError as e:
expected_error_message = "Tensor input and output of broadcast must have the same number of elements "
if str(e) == expected_error_message:
print(
"Correct exception raised with expected error message in test_allgather."
)
else:
print(f"Incorrect error message: {str(e)}")

cleanup()


Expand Down
115 changes: 62 additions & 53 deletions dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ void check_device_single_tensor(
}

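// Check that `tensors' is nonempty and that all of its tensors have an
// identical type (and, when requested, identical sizes and strides).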
void checkDeviceTensors(const std::vector<at::Tensor>& tensors) {
void checkDeviceTensors(const std::vector<at::Tensor>& tensors,
bool check_sizes_and_strides) {
if (tensors.empty()) {
TORCH_CHECK(false, "Tensor list must be nonempty");
}
Expand All @@ -348,11 +349,13 @@ void checkDeviceTensors(const std::vector<at::Tensor>& tensors) {
if (tensor.scalar_type() != first.scalar_type()) {
TORCH_CHECK(false, "Tensors must have identical type");
}
if (tensor.sizes() != first.sizes()) {
TORCH_CHECK(false, "Tensors must have identical size");
}
if (tensor.strides() != first.strides()) {
TORCH_CHECK(false, "Tensors must have identical strides");
if (check_sizes_and_strides) {
if (tensor.sizes() != first.sizes()) {
TORCH_CHECK(false, "Tensors must have identical size");
}
if (tensor.strides() != first.strides()) {
TORCH_CHECK(false, "Tensors must have identical strides");
}
}
const auto inserted = usedDevices.insert(tensor.get_device()).second;
if (!inserted) {
Expand Down Expand Up @@ -401,15 +404,6 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
return flattened;
}

// Returns true when every tensor in `input_tensors` is reported by
// is_same_size() as matching the first tensor of the list. An empty list
// never enters the loop and therefore yields true.
bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
auto current = input_tensors.begin();
const auto last = input_tensors.end();
while (current != last) {
// Compare against the first element; bail out on the first mismatch.
if (!input_tensors.front().is_same_size(*current)) {
return false;
}
++current;
}
return true;
}

template <bool RecordDest, typename Dest, typename Src>
void copyInCommStream(std::shared_ptr<DICLComm>& diclComm, const Dest& dest,
const Src& src, int nums) {
Expand Down Expand Up @@ -545,7 +539,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::pointToPoint(
c10::intrusive_ptr<Work> ProcessGroupDICL::allreduce(
std::vector<at::Tensor>& tensors, const AllreduceOptions& opts) {
// inplace in = out, every rank use both in&out.
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
std::vector<at::Tensor> tensors_cp{tensors};
return collective(
tensors_cp, tensors_cp,
Expand Down Expand Up @@ -574,7 +568,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::allreduce(

c10::intrusive_ptr<Work> ProcessGroupDICL::broadcast(
std::vector<at::Tensor>& tensors, const BroadcastOptions& opts) {
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
// inplace in = out, only rootRank use in.
return collective(
tensors, tensors,
Expand All @@ -596,7 +590,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::broadcast(
c10::intrusive_ptr<Work> ProcessGroupDICL::reduce(
std::vector<at::Tensor>& tensors, const ReduceOptions& opts) {
// inplace in = out, only rootRank use out.
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);

auto tensor = tensors.back();
int dev_in_group = 0;
Expand Down Expand Up @@ -637,7 +631,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::gather(
int curRank = getRank();
int rootRank = static_cast<int>(opts.rootRank);
c10d::assertRootRank(raise_invalid_arg_func, rootRank, numRanks);
checkDeviceTensors(inputs);
checkDeviceTensors(inputs, true);
c10d::assertSingleElementInput(raise_invalid_arg_func, inputs);
auto input = inputs.back();
std::vector<at::Tensor> outputTensors;
Expand Down Expand Up @@ -695,9 +689,13 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::allgather(
std::vector<at::Tensor>& inputs, const AllgatherOptions& opts) {
auto inputTensor = inputs.back();
auto outputs_ = outputs.back();
bool same_size = check_same_size(outputs_);
const at::Tensor& first_tensor = outputs_.front();
bool same_size = std::all_of(outputs_.begin(), outputs_.end(),
[&first_tensor](const at::Tensor& t) {
return first_tensor.is_same_size(t);
});
checkDeviceTensors(inputs, same_size);
if (same_size) {
checkDeviceTensors(inputs);
// output = input * ranks, not in-place. Every rank uses both in & out.
auto outputFlattened =
flatten_for_scatter_gather(outputs, inputs, this->size_);
Expand Down Expand Up @@ -730,34 +728,45 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::allgather(
},
OpType::ALLGATHER);
return work;
} else {
const auto num_reduces = outputs_.size();
return collective(
inputs, outputs_,
[&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
DIPUStream& stream) {
for (const int i : c10::irange(num_reduces)) {
auto& opt = outputs_[i];
auto& ipt = (i == this->rank_) ? inputTensor : opt;
auto broadcastOpts = BroadcastOptions{
static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
if (opt.numel() != ipt.numel()) {
throw std::runtime_error(
"Tensor input and output of _broadcast_oop must have the "
"same number of elements ");
}
RECORD_FUNCTION("DiclBroadcast", std::vector<c10::IValue>({ipt}));
profile::RecordBlockCreator _("DiclBroadcast", stream.rawstream(),
static_cast<int>(stream.id()));
const auto root = broadcastOpts.rootRank + broadcastOpts.rootTensor;
devproxy::diclBroadcast(ipt.data_ptr(), opt.data_ptr(),
static_cast<size_t>(ipt.numel()),
ipt.scalar_type(), static_cast<int>(root),
comm, stream.rawstream());
}
},
OpType::BROADCAST);
}
const auto num_broadcasts = outputs_.size();
for (const auto i : c10::irange(num_broadcasts)) {
auto& outCheck = outputs_[i];
auto& inCheck = (i == this->rank_) ? inputTensor : outCheck;
if (outCheck.numel() != inCheck.numel()) {
throw std::runtime_error(
"Tensor input and output of broadcast must have the "
"same number of elements ");
}
}
return collective(
inputs, outputs_,
[&](at::Tensor& /* unused */, at::Tensor& /* unused */, diclComm_t comm,
DIPUStream& stream) {
for (const auto i : c10::irange(num_broadcasts)) {
auto& outTensor = outputs_[i];
auto& inTensor = (i == this->rank_) ? inputTensor : outTensor;
RECORD_FUNCTION("DiclBroadcast",
std::vector<c10::IValue>({inTensor}));
profile::RecordBlockCreator _("DiclBroadcast", stream.rawstream(),
static_cast<int>(stream.id()));
const auto root = static_cast<int64_t>(i);
// collective() is used here only as a convenient wrapper, so an output
// tensor that does not share storage with the input tensor is recorded
// on `stream` explicitly before the broadcast is issued.
if (outTensor.has_storage() &&
(!inTensor.has_storage() ||
inTensor.storage().data_ptr().get() !=
outTensor.storage().data_ptr().get())) {
dipu::recordStream(outTensor, stream);
}
devproxy::diclBroadcast(
inTensor.data_ptr(), outTensor.data_ptr(),
static_cast<size_t>(inTensor.numel()), inTensor.scalar_type(),
static_cast<int>(root), comm, stream.rawstream());
}
},
OpType::BROADCAST);
}

c10::intrusive_ptr<Work> ProcessGroupDICL::_allgather_base(
Expand Down Expand Up @@ -829,7 +838,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::scatter(
int curRank = getRank();
int rootRank = static_cast<int>(opts.rootRank);
c10d::assertRootRank(raise_invalid_arg_func, rootRank, numRanks);
checkDeviceTensors(outputs);
checkDeviceTensors(outputs, true);
c10d::assertSingleElementOutput(raise_invalid_arg_func, outputs);
auto output = outputs.back();
std::vector<at::Tensor> inputTensors;
Expand Down Expand Up @@ -880,10 +889,10 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::reduce_scatter(
std::vector<std::vector<at::Tensor>>& inputs,
const ReduceScatterOptions& opts) {
// input = output * ranks, no inplace, output = reduced(input)[rank]
checkDeviceTensors(outputs);
checkDeviceTensors(outputs, true);
auto inputFlattened =
flatten_for_scatter_gather(inputs, outputs, this->size_);
checkDeviceTensors(inputFlattened);
checkDeviceTensors(inputFlattened, true);

auto work = collective(
inputFlattened, outputs,
Expand Down Expand Up @@ -1056,7 +1065,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::alltoall(

c10::intrusive_ptr<Work> ProcessGroupDICL::send(
std::vector<at::Tensor>& tensors, int dstRank, int tag) {
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
auto p2pPair = mapPGRank2P2P(rank_, dstRank);
return pointToPoint(
tensors, tensors, dstRank,
Expand All @@ -1075,7 +1084,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::send(

c10::intrusive_ptr<Work> ProcessGroupDICL::recv(
std::vector<at::Tensor>& tensors, int srcRank, int tag) {
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
auto p2pPair = mapPGRank2P2P(rank_, srcRank);
return pointToPoint(
tensors, tensors, srcRank,
Expand Down

0 comments on commit d34260e

Please sign in to comment.