diff --git a/dipu/tests/python/individual_scripts/test_rt_ddp.py b/dipu/tests/python/individual_scripts/test_rt_ddp.py
index f8dbbdc80..cc31df266 100644
--- a/dipu/tests/python/individual_scripts/test_rt_ddp.py
+++ b/dipu/tests/python/individual_scripts/test_rt_ddp.py
@@ -10,6 +10,7 @@
 import torch.optim as optim
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
+from utils.random_shape import ShapeGenerator
 
 
 def debugat(rank=0):
@@ -473,12 +474,13 @@ def demo_model_parallel(rank, world_size, port):
     cleanup()
 
 
-def run_demo(demo_fn, world_size, port):
+def run_demo(demo_fn, world_size, port, *args):
     mp.spawn(
         demo_fn,
         args=(
             world_size,
             port,
+            *args,
         ),
         nprocs=world_size,
         join=True,
@@ -582,6 +584,57 @@ def test_get_comm_name(rank, world_size, port):
     cleanup()
 
 
+def test_allgather(rank, world_size, port, seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+    import torch_dipu
+
+    print(f"test allgather on rank {rank} ws: {world_size}")
+    setup(rank, world_size, port)
+    shape_gen = ShapeGenerator(seed=seed)
+
+    gathered_tensors = []
+    expected_tensors = []
+    for i in range(world_size):
+        shape = shape_gen.random_shape(
+            (1, 100000),
+            (1, 4),
+        )
+        device = torch.device(f"cuda:{rank}")
+        tensor = torch.rand(size=shape, dtype=torch.float16)
+        tensor = tensor.to(device)
+        gathered_tensors.append(torch.empty_like(tensor))
+        expected_tensors.append(tensor)
+    tensor_to_gather = expected_tensors[rank]
+    dist.all_gather(gathered_tensors, tensor_to_gather)
+
+    for i in range(len(gathered_tensors)):
+        assert torch.allclose(gathered_tensors[i], expected_tensors[i])
+
+    device = torch.device(f"cuda:{rank}")
+    t = torch.rand((rank + 1, rank + 1), dtype=torch.float16, device=device)
+    shapes = []
+    for i in range(world_size):
+        shapes.append((i + 2, i + 2))
+    gathered_t = [
+        torch.empty(shapes[i], dtype=torch.float16, device=device)
+        for i in range(world_size)
+    ]
+
+    try:
+        dist.all_gather(gathered_t, t)
+    except RuntimeError as e:
+        expected_error_message = "Tensor input and output of broadcast must have the same number of elements "
+        if str(e) == expected_error_message:
+            print(
+                "Correct exception raised with expected error message in test_allgather."
+            )
+        else:
+            print(f"Incorrect error message: {str(e)}")
+
+    cleanup()
+
+
 if __name__ == "__main__":
     port = random.randint(10000, 60000)
     # get device_count without "import torch_dipu"
@@ -619,6 +672,8 @@ def test_get_comm_name(rank, world_size, port):
 
     run_demo(test_special_group_stuck, world_size, port)
 
+    run_demo(test_allgather, world_size, port, random.randint(0, 10000))
+
     # need 4 card to run
     if world_size >= 4:
         run_demo(test_new_group, world_size, port)
diff --git a/dipu/tests/python/utils/random_shape.py b/dipu/tests/python/utils/random_shape.py
new file mode 100644
index 000000000..e130c80f6
--- /dev/null
+++ b/dipu/tests/python/utils/random_shape.py
@@ -0,0 +1,47 @@
+from typing import List, Tuple
+import numpy as np
+
+
+__all__ = ["ShapeGenerator"]
+
+
+class ShapeGenerator:
+    def __init__(self, seed=None):
+        self.rng = np.random.default_rng(seed)
+
+    def random_shape(
+        self,
+        numel_range: Tuple[int, int],
+        rank_range: Tuple[int, int],
+        retry: int = 10,
+    ) -> List[int]:
+        """
+        Generate a random shape. Ranges are inclusive.
+        """
+        assert 0 < numel_range[0] <= numel_range[1]
+        assert 0 < rank_range[0] <= rank_range[1]
+        assert retry > 0
+        rank = self.rng.integers(rank_range[0], rank_range[1], endpoint=True)
+        while True:
+            retry -= 1
+            shape = self._try_random_shape(numel_range, rank)
+            if retry <= 0 or shape.prod() in range(numel_range[0], numel_range[1] + 1):
+                return shape.tolist()
+
+    def _try_random_shape(self, numel_range: Tuple[int, int], rank: int) -> np.ndarray:
+        assert 0 < numel_range[0] <= numel_range[1]
+        lognumel_range = np.log(numel_range)
+        lognumel = self.rng.uniform(*lognumel_range)
+        logshape = self._random_partition(lognumel, rank)
+        shape = np.exp(logshape).round().astype(int)
+        return shape
+
+    def _random_partition(self, total: float, part: int) -> np.ndarray:
+        """
+        Randomly partition a total into part parts.
+        """
+        assert total > 0
+        assert part > 0
+        parts = self.rng.random(part)
+        parts /= parts.sum()
+        return total * parts
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
index ea958b12e..edfb032aa 100644
--- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
+++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
@@ -325,7 +325,8 @@ void check_device_single_tensor(
 }
 
 // Check that all `tensors'
-void checkDeviceTensors(const std::vector<at::Tensor>& tensors) {
+void checkDeviceTensors(const std::vector<at::Tensor>& tensors,
+                        bool check_sizes_and_strides) {
   if (tensors.empty()) {
     TORCH_CHECK(false, "Tensor list must be nonempty");
   }
@@ -348,11 +349,13 @@ void checkDeviceTensors(const std::vector<at::Tensor>& tensors) {
     if (tensor.scalar_type() != first.scalar_type()) {
       TORCH_CHECK(false, "Tensors must have identical type");
     }
-    if (tensor.sizes() != first.sizes()) {
-      TORCH_CHECK(false, "Tensors must have identical size");
-    }
-    if (tensor.strides() != first.strides()) {
-      TORCH_CHECK(false, "Tensors must have identical strides");
+    if (check_sizes_and_strides) {
+      if (tensor.sizes() != first.sizes()) {
+        TORCH_CHECK(false, "Tensors must have identical size");
+      }
+      if (tensor.strides() != first.strides()) {
+        TORCH_CHECK(false, "Tensors must have identical strides");
+      }
     }
     const auto inserted = usedDevices.insert(tensor.get_device()).second;
     if (!inserted) {
@@ -536,7 +539,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::pointToPoint(
 c10::intrusive_ptr<Work> ProcessGroupDICL::allreduce(
     std::vector<at::Tensor>& tensors, const AllreduceOptions& opts) {
   // inplace in = out, every rank use both in&out.
-  checkDeviceTensors(tensors);
+  checkDeviceTensors(tensors, true);
   std::vector<at::Tensor> tensors_cp{tensors};
   return collective(
       tensors_cp, tensors_cp,
@@ -565,7 +568,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::allreduce(
 
 c10::intrusive_ptr<Work> ProcessGroupDICL::broadcast(
     std::vector<at::Tensor>& tensors, const BroadcastOptions& opts) {
-  checkDeviceTensors(tensors);
+  checkDeviceTensors(tensors, true);
   // inplace in = out, only rootRank use in.
   return collective(
       tensors, tensors,
@@ -587,7 +590,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::broadcast(
 c10::intrusive_ptr<Work> ProcessGroupDICL::reduce(
     std::vector<at::Tensor>& tensors, const ReduceOptions& opts) {
   // inplace in = out, only rootRank use out.
-  checkDeviceTensors(tensors);
+  checkDeviceTensors(tensors, true);
 
   auto tensor = tensors.back();
   int dev_in_group = 0;
@@ -628,7 +631,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::gather(
   int curRank = getRank();
   int rootRank = static_cast<int>(opts.rootRank);
   c10d::assertRootRank(raise_invalid_arg_func, rootRank, numRanks);
-  checkDeviceTensors(inputs);
+  checkDeviceTensors(inputs, true);
   c10d::assertSingleElementInput(raise_invalid_arg_func, inputs);
   auto input = inputs.back();
   std::vector<at::Tensor> outputTensors;
@@ -684,38 +687,90 @@ std::string_view ProcessGroupDICL::getCommName(
 c10::intrusive_ptr<Work> ProcessGroupDICL::allgather(
     std::vector<std::vector<at::Tensor>>& outputs,
     std::vector<at::Tensor>& inputs, const AllgatherOptions& opts) {
-  checkDeviceTensors(inputs);
-  // output = input * ranks, no inplace. every ranks use both in&out.
-  auto outputFlattened =
-      flatten_for_scatter_gather(outputs, inputs, this->size_);
+  auto inputTensor = inputs.back();
+  auto outputs_ = outputs.back();
+  const at::Tensor& first_tensor = outputs_.front();
+  bool same_size = std::all_of(outputs_.begin(), outputs_.end(),
+                               [&first_tensor](const at::Tensor& t) {
+                                 return first_tensor.is_same_size(t);
+                               });
+  checkDeviceTensors(inputs, same_size);
+  if (same_size) {
+    // output = input * ranks, no inplace. every ranks use both in&out.
+    auto outputFlattened =
+        flatten_for_scatter_gather(outputs, inputs, this->size_);
+
+    auto work = collective(
+        inputs, outputFlattened,
+        [&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
+            DIPUStream& stream) {
+          RECORD_FUNCTION("DiclAllgather", std::vector<c10::IValue>({input}));
+          profile::RecordBlockCreator _("DiclAllgather", stream.rawstream(),
+                                        static_cast<int>(stream.id()));
 
-  auto work = collective(
-      inputs, outputFlattened,
-      [&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
+          return devproxy::diclAllGather(input.data_ptr(), output.data_ptr(),
+                                         static_cast<size_t>(input.numel()),
+                                         input.scalar_type(), comm,
+                                         stream.rawstream());
+        },
+        [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {},
+        [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {
+          // Copy the flattened output tensors to the outputs.
+          for (size_t i = 0; i < outputs.size(); ++i) {
+            // warnning & todo:: copy in comm stream,
+            // record dest tensor outputs, because src tensor outputFlattened
+            // already recorded in collective.
+            copyInCommStream(diclComms[i], outputs[i], outputFlattened[i],
+                             static_cast<int>(outputs[i].size()));
+            // copyInCurrentStream(diclComms[i], outputs[i],
+            // outputFlattened[i]);
+          }
+        },
+        OpType::ALLGATHER);
+    return work;
+  }
+  const auto num_broadcasts = outputs_.size();
+  for (const auto i : c10::irange(num_broadcasts)) {
+    auto& outCheck = outputs_[i];
+    auto& inCheck = (i == this->rank_) ? inputTensor : outCheck;
+    if (outCheck.numel() != inCheck.numel()) {
+      throw std::runtime_error(
+          "Tensor input and output of broadcast must have the "
+          "same number of elements ");
+    }
+  }
+  return collective(
+      inputs, outputs_,
+      [&](at::Tensor& /* unused */, at::Tensor& /* unused */, diclComm_t comm,
           DIPUStream& stream) {
-        RECORD_FUNCTION("DiclAllgather", std::vector<c10::IValue>({input}));
-        profile::RecordBlockCreator _("DiclAllgather", stream.rawstream(),
-                                      static_cast<int>(stream.id()));
-
-        return devproxy::diclAllGather(input.data_ptr(), output.data_ptr(),
-                                       static_cast<size_t>(input.numel()),
-                                       input.scalar_type(), comm,
-                                       stream.rawstream());
-      },
-      [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {},
-      [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {
-        // Copy the flattened output tensors to the outputs.
-        for (size_t i = 0; i < outputs.size(); ++i) {
-          // warnning & todo:: copy in comm stream,
-          // record dest tensor outputs, because src tensor outputFlattened
-          // already recorded in collective.
-          copyInCommStream(diclComms[i], outputs[i], outputFlattened[i],
-                           static_cast<int>(outputs[i].size()));
-          // copyInCurrentStream(diclComms[i], outputs[i], outputFlattened[i]);
+        for (const auto i : c10::irange(num_broadcasts)) {
+          auto& outTensor = outputs_[i];
+          const auto root = static_cast<int64_t>(i);
+          // Just for the convenience of calling collective, it is necessary to
+          // record the output elements of different devices, and the work logic
+          // is correct.
+          dipu::recordStream(outTensor, stream);
+          auto& inTensor = i == this->rank_ ? inputTensor : outTensor;
+          RECORD_FUNCTION("DiclBroadcast",
+                          std::vector<c10::IValue>({inTensor}));
+          profile::RecordBlockCreator _("DiclBroadcast", stream.rawstream(),
+                                        static_cast<int>(stream.id()));
+          devproxy::diclBroadcast(
+              inTensor.data_ptr(), outTensor.data_ptr(),
+              static_cast<size_t>(inTensor.numel()), inTensor.scalar_type(),
+              static_cast<int>(root), comm, stream.rawstream());
+#if DIPU_VENDOR_NAME_ASCEND
+          if (i == this->rank_) {
+            devproxy::memCopyD2DAsync(
+                stream.rawstream(),
+                static_cast<size_t>(inTensor.numel()) *
+                    static_cast<size_t>(inTensor.element_size()),
+                i, outTensor.data_ptr(), i, inTensor.data_ptr());
+          }
+#endif
         }
       },
-      OpType::ALLGATHER);
-  return work;
+      OpType::BROADCAST);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupDICL::_allgather_base(
@@ -787,7 +842,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::scatter(
   int curRank = getRank();
   int rootRank = static_cast<int>(opts.rootRank);
   c10d::assertRootRank(raise_invalid_arg_func, rootRank, numRanks);
-  checkDeviceTensors(outputs);
+  checkDeviceTensors(outputs, true);
   c10d::assertSingleElementOutput(raise_invalid_arg_func, outputs);
   auto output = outputs.back();
   std::vector<at::Tensor> inputTensors;
@@ -838,10 +893,10 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::reduce_scatter(
     std::vector<std::vector<at::Tensor>>& inputs,
     const ReduceScatterOptions& opts) {
   // input = output * ranks, no inplace, output = reduced(input)[rank]
-  checkDeviceTensors(outputs);
+  checkDeviceTensors(outputs, true);
   auto inputFlattened =
       flatten_for_scatter_gather(inputs, outputs, this->size_);
-  checkDeviceTensors(inputFlattened);
+  checkDeviceTensors(inputFlattened, true);
 
   auto work = collective(
       inputFlattened, outputs,
@@ -1014,7 +1069,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::alltoall(
 
 c10::intrusive_ptr<Work> ProcessGroupDICL::send(
     std::vector<at::Tensor>& tensors, int dstRank, int tag) {
-  checkDeviceTensors(tensors);
+  checkDeviceTensors(tensors, true);
   auto p2pPair = mapPGRank2P2P(rank_, dstRank);
   return pointToPoint(
       tensors, tensors, dstRank,
@@ -1033,7 +1088,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::send(
 
 c10::intrusive_ptr<Work> ProcessGroupDICL::recv(
     std::vector<at::Tensor>& tensors, int srcRank, int tag) {
-  checkDeviceTensors(tensors);
+  checkDeviceTensors(tensors, true);
  auto p2pPair = mapPGRank2P2P(rank_, srcRank);
   return pointToPoint(
       tensors, tensors, srcRank,
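
For reviewers who want to exercise the new path outside of test_rt_ddp.py, below is a minimal, hypothetical driver (not part of the patch). It assumes torch_dipu is importable, at least two devices are visible, and that the process group can be initialized with the "nccl" backend name (which DIPU is expected to remap to its DICL backend); the worker function, port, and backend string are illustrative only and should be adjusted to match how setup() does it in this repo.

# Hypothetical standalone sketch: exercises the broadcast-based fallback that
# ProcessGroupDICL::allgather now takes when per-rank tensors differ in shape.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size, port):
    import torch_dipu  # noqa: F401  # assumed to map "cuda"/"nccl" onto DIPU/DICL

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    device = torch.device(f"cuda:{rank}")
    # Rank r contributes r + 1 elements, so shapes differ across ranks and
    # the same_size check in ProcessGroupDICL::allgather evaluates to false.
    t = torch.full((rank + 1,), float(rank), dtype=torch.float16, device=device)
    outs = [
        torch.empty(r + 1, dtype=torch.float16, device=device)
        for r in range(world_size)
    ]
    dist.all_gather(outs, t)
    for r, out in enumerate(outs):
        assert torch.all(out == r), f"rank {rank}: bad gather from rank {r}"
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2, 29501), nprocs=2, join=True)

If the output tensors are instead sized so that some rank's numel() does not match the corresponding input (as in the second half of test_allgather), the call is expected to raise the RuntimeError carrying the "same number of elements" message introduced in ProcessGroupDICL.cpp above.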