[DIPU] Implement all_to_all_single (with unequal splits) and all_to_all (#842)

* implement all_to_all_single with unequal splits

* update a comment for diclAllToAllUnequalSplit

* make clang-tidy happy

* implement all_to_all

* make clang-tidy happy

* remove functions copied from c10d

* add examples to the comments for alltoall tests
jfxu-st authored Jun 21, 2024
1 parent ffd1d97 commit 7eb759c
Showing 8 changed files with 297 additions and 8 deletions.
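
As context for the diff below: after this commit, both collectives sketched here run on DIPU devices through ProcessGroupDICL instead of failing with "DICL doesn't support unequal splits". This is a minimal usage sketch, not code from the commit; it mirrors the world_size = 2 example in the new tests and assumes torch_dipu is installed and that the process group has already been initialized for two ranks, with rank being this process's rank.

    import torch
    import torch.distributed as dist
    import torch_dipu  # noqa: F401  # registers the DIPU device and backend

    def alltoall_unequal_sketch(rank: int) -> None:
        # rank 0 sends 1 element to rank 0 and 2 elements to rank 1;
        # rank 1 sends 3 elements to rank 0 and 4 elements to rank 1.
        input_split_sizes = [[1, 2], [3, 4]][rank]
        output_split_sizes = [[1, 3], [2, 4]][rank]

        src = (torch.arange(sum(input_split_sizes)) + [0, 3][rank]).to(rank)
        dst = torch.empty(sum(output_split_sizes), dtype=torch.int64).to(rank)

        # Single-tensor variant with unequal splits (newly supported by DICL).
        dist.all_to_all_single(dst, src, output_split_sizes, input_split_sizes)
        # rank 0 now holds [0, 3, 4, 5]; rank 1 holds [1, 2, 6, 7, 8, 9].

        # List-of-tensors variant: input tensor i is sent to rank i.
        src_list = [t.contiguous() for t in src.split(input_split_sizes)]
        dst_list = [torch.empty(n, dtype=torch.int64).to(rank)
                    for n in output_split_sizes]
        dist.all_to_all(dst_list, src_list)
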
82 changes: 81 additions & 1 deletion dipu/tests/python/individual_scripts/test_rt_ddp.py
@@ -358,7 +358,7 @@ def demo_alltoall_base_equal_split(rank, world_size, port):

expected = torch.cat(
[
-            (torch.arange(split_size) + i * tensor_size + rank * split_size)
+            torch.arange(split_size) + i * tensor_size + rank * split_size
for i in range(world_size)
]
)
@@ -369,6 +369,84 @@ def demo_alltoall_base_equal_split(rank, world_size, port):
cleanup()


def demo_alltoall_base_unequal_split(rank, world_size, port):
import torch_dipu

setup(rank, world_size, port)

# Example: For world_size = 2,
# input_split_sizes: [1,2] (rank 0)
# [3,4] (rank 1)
# output_split_sizes: [1,3] (rank 0)
# [2,4] (rank 1)
# src: [0, 1, 2] (rank 0)
# [3, 4, 5, 6, 7, 8, 9] (rank 1)
# expected / dst: [0, 3, 4, 5] (rank 0)
# [1, 2, 6, 7, 8, 9] (rank 1)

input_split_sizes = torch.arange(world_size) + 1 + rank * world_size
output_split_sizes = torch.arange(0, world_size * world_size, world_size) + 1 + rank
src = (
torch.arange(input_split_sizes.sum().item())
+ torch.arange(input_split_sizes[0]).sum().item()
).to(rank)
dst = torch.empty(output_split_sizes.sum().item(), dtype=torch.int64).to(rank)

expected = torch.cat(
[
torch.arange(output_split_sizes[i])
+ torch.arange(output_split_sizes[i]).sum().item()
for i in range(world_size)
]
)

dist.all_to_all_single(
dst, src, output_split_sizes.tolist(), input_split_sizes.tolist()
)
dist.barrier()
assert torch.allclose(expected, dst.cpu())
cleanup()


def demo_alltoall(rank, world_size, port):
import torch_dipu

setup(rank, world_size, port)

# Example: For world_size = 2,
# src: [[0], [1, 2]] (rank 0)
# [[3, 4, 5], [6, 7, 8, 9]] (rank 1)
# expected / dst: [[0], [3, 4, 5]] (rank 0)
# [[1, 2], [6, 7, 8, 9]] (rank 1)

input_split_sizes = torch.arange(world_size) + 1 + rank * world_size
output_split_sizes = torch.arange(0, world_size * world_size, world_size) + 1 + rank
src = list(
(
torch.arange(input_split_sizes.sum().item())
+ torch.arange(input_split_sizes[0]).sum().item()
).split(input_split_sizes.tolist())
)
src = [tensor.to(rank) for tensor in src]
dst = list(
torch.empty(output_split_sizes.sum().item(), dtype=torch.int64).split(
output_split_sizes.tolist()
)
)
dst = [tensor.to(rank) for tensor in dst]

expected = [
torch.arange(output_split_sizes[i])
+ torch.arange(output_split_sizes[i]).sum().item()
for i in range(world_size)
]
dist.all_to_all(dst, src)
dist.barrier()
for i in range(world_size):
assert torch.allclose(expected[i], dst[i].cpu())
cleanup()


def demo_model_parallel(rank, world_size, port):
print(f"Running DDP with model parallel example on rank {rank}.")
backend = "nccl"
@@ -477,6 +555,8 @@ def test_get_comm_name(rank, world_size, port):
run_demo(demo_reducescatter, world_size, port)
run_demo(demo_reducescatter_base, world_size, port)
run_demo(demo_alltoall_base_equal_split, world_size, port)
run_demo(demo_alltoall_base_unequal_split, world_size, port)
run_demo(demo_alltoall, world_size, port)
run_demo(demo_gather, world_size, port)
run_demo(demo_scatter, world_size, port)

6 changes: 6 additions & 0 deletions dipu/torch_dipu/csrc_dipu/runtime/device/diclapis.h
@@ -65,6 +65,12 @@ DIPU_WEAK diclResult_t diclAllToAllEqualSplit(const void* sendBuf,
diclComm_t comm,
deviceStream_t stream);

DIPU_WEAK diclResult_t diclAllToAllUnequalSplit(
const void* sendBuf, const size_t* sendCounts,
const size_t* sendDisplacements, void* recvBuf, const size_t* recvCounts,
const size_t* recvDisplacements, at::ScalarType dataType, diclComm_t comm,
deviceStream_t stream);

DIPU_API diclResult_t diclSend(const void* sendBuf, size_t count,
at::ScalarType datatype, int peer,
diclComm_t comm, deviceStream_t stream);
61 changes: 57 additions & 4 deletions dipu/torch_dipu/csrc_dipu/runtime/devproxy/diclproxy.cpp
@@ -127,7 +127,7 @@ devapis::diclResult_t diclAllToAllEqualSplit(
comm, stream);
}

-  // TODO(jfxu-st): For CUDA, use NCCL Group Calls for higher performance
+  // TODO(jfxu-st): For CUDA, use NCCL group calls for higher performance
// Ref:
// https://github.com/pytorch/pytorch/blob/f2d7f235a684c593f5a1ff2ca0b47b47274bfe85/torch/csrc/cuda/nccl.cpp#L828-L838
// Ref:
@@ -137,9 +137,9 @@ devapis::diclResult_t diclAllToAllEqualSplit(
"implementation based on devproxy::diclScatter will be used")
const size_t numBytesPerRank = count * c10::elementSize(dataType);
std::vector<const void*> sendBuf2d(commSize);
-  for (const auto peer : c10::irange(commSize)) {
-    sendBuf2d[peer] =
-        reinterpret_cast<const char*>(sendBuf) + peer * numBytesPerRank;
+  for (const auto scatterRootRank : c10::irange(commSize)) {
+    sendBuf2d[scatterRootRank] = reinterpret_cast<const char*>(sendBuf) +
+                                 scatterRootRank * numBytesPerRank;
}
for (const auto peer : c10::irange(commSize)) {
diclScatter(sendBuf2d.data(),
@@ -149,6 +149,59 @@ devapis::diclResult_t diclAllToAllEqualSplit(
return devapis::DICL_SUCCESS;
}

DIPU_API devapis::diclResult_t diclAllToAllUnequalSplit(
const void* sendBuf, const size_t* sendCounts,
const size_t* sendDisplacements, void* recvBuf, const size_t* recvCounts,
const size_t* recvDisplacements, at::ScalarType dataType, diclComm_t comm,
deviceStream_t stream, int currRank, int commSize) {
if (devapis::diclAllToAllUnequalSplit) {
return devapis::diclAllToAllUnequalSplit(
sendBuf, sendCounts, sendDisplacements, recvBuf, recvCounts,
recvDisplacements, dataType, comm, stream);
}

// TODO(jfxu-st): For CUDA, use NCCL group calls for higher performance
// Ref:
// https://github.com/pytorch/pytorch/blob/f2d7f235a684c593f5a1ff2ca0b47b47274bfe85/torch/csrc/cuda/nccl.cpp#L871-L893

TORCH_WARN_ONCE(
"devapis::diclAllToAllUnequalSplit is not implemented, so a fallback "
"implementation based on devproxy::diclSend and devproxy::diclRecv will "
"be used")

size_t elementSize = c10::elementSize(dataType);
for (const auto scatterRootRank : c10::irange(commSize)) {
if (currRank != scatterRootRank) {
DIPU_CALL_DICLAPIS(
diclRecv(reinterpret_cast<char*>(recvBuf) +
recvDisplacements[scatterRootRank] * elementSize,
recvCounts[scatterRootRank], dataType, scatterRootRank, comm,
stream));
continue;
}

for (const auto dstRank : c10::irange(commSize)) {
if (dstRank == scatterRootRank) {
continue;
}
DIPU_CALL_DICLAPIS(diclSend(reinterpret_cast<const char*>(sendBuf) +
sendDisplacements[dstRank] * elementSize,
sendCounts[dstRank], dataType, dstRank, comm,
stream));
}

auto deviceId = static_cast<devapis::deviceId_t>(currRank);
devproxy::memCopyD2DAsync(stream, sendCounts[currRank] * elementSize,
deviceId,
reinterpret_cast<char*>(recvBuf) +
recvDisplacements[currRank] * elementSize,
deviceId,
reinterpret_cast<const char*>(sendBuf) +
sendDisplacements[currRank] * elementSize);
}
return devapis::DICL_SUCCESS;
}

devapis::diclResult_t diclSend(const void* sendbuff, size_t count,
at::ScalarType datatype, int peer,
diclComm_t comm, deviceStream_t stream) {
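
When a vendor does not provide devapis::diclAllToAllUnequalSplit, the fallback above decomposes the unequal-split all-to-all into point-to-point rounds: in round r every other rank posts a receive from rank r, while rank r sends each peer its slice and copies its own slice locally with memCopyD2DAsync. The Python sketch below only illustrates that round structure against the public torch.distributed send/recv API; it is not the DIPU code and assumes an initialized process group and contiguous 1-D buffers.

    import torch.distributed as dist

    def all_to_all_unequal_via_send_recv(dst, src, out_splits, in_splits,
                                         rank, world_size):
        # Element offsets of each rank's slice inside the flat buffers.
        in_off = [sum(in_splits[:i]) for i in range(world_size)]
        out_off = [sum(out_splits[:i]) for i in range(world_size)]

        for root in range(world_size):  # round r: rank `root` scatters its data
            if rank != root:
                # Receive the slice that `root` addressed to this rank.
                dist.recv(dst[out_off[root]:out_off[root] + out_splits[root]],
                          src=root)
                continue
            for peer in range(world_size):
                if peer == root:
                    continue
                dist.send(src[in_off[peer]:in_off[peer] + in_splits[peer]],
                          dst=peer)
            # The slice addressed to ourselves never crosses the wire.
            dst[out_off[root]:out_off[root] + out_splits[root]].copy_(
                src[in_off[root]:in_off[root] + in_splits[root]])
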
9 changes: 9 additions & 0 deletions dipu/torch_dipu/csrc_dipu/runtime/devproxy/diclproxy.h
@@ -74,6 +74,15 @@ DIPU_API devapis::diclResult_t diclAllToAllEqualSplit(
devapis::diclAllToAllEqualSplit is not implemented */
int currRank, int commSize);

DIPU_API devapis::diclResult_t diclAllToAllUnequalSplit(
const void* sendBuf, const size_t* sendCounts,
const size_t* sendDisplacements, void* recvBuf, const size_t* recvCounts,
const size_t* recvDisplacements, at::ScalarType dataType, diclComm_t comm,
deviceStream_t stream,
/* The following arguments are only used for a fallback implementation when
devapis::diclAllToAllUnequalSplit is not implemented */
int currRank, int commSize);

DIPU_API devapis::diclResult_t diclSend(const void* sendbuff, size_t count,
at::ScalarType datatype, int peer,
diclComm_t comm, deviceStream_t stream);
118 changes: 115 additions & 3 deletions dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
@@ -5,12 +5,19 @@
#include <vector>

#include <ATen/core/TensorBody.h>
#include <ATen/ops/cat.h>
#include <ATen/record_function.h>
#include <c10/core/Device.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <c10/util/typeid.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/torch.h>

#include "csrc_dipu/aten/ops/NodispatchUtils.hpp"
#include "csrc_dipu/profiler/profiler.h"
#include "csrc_dipu/runtime/core/DIPUGuard.h"
#include "csrc_dipu/runtime/core/DIPUStream.h"
@@ -865,7 +872,8 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::reduce_scatter(
c10::intrusive_ptr<Work> ProcessGroupDICL::alltoall_base(
at::Tensor& outputTensor, at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
-    std::vector<int64_t>& inputSplitSizes, const AllToAllOptions& opts) {
+    std::vector<int64_t>& inputSplitSizes,
+    const AllToAllOptions& opts /* unused */) {
check_device_single_tensor(outputTensor, true);
check_device_single_tensor(inputTensor, true);
TORCH_CHECK(outputTensor.scalar_type() == inputTensor.scalar_type(),
Expand Down Expand Up @@ -896,8 +904,112 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::alltoall_base(
},
OpType::ALLTOALL_BASE);
}
-  // TODO(jfxu-st): support unequal splits
-  TORCH_CHECK(false, "DICL doesn't support unequal splits")

c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
auto outputs = std::vector<at::Tensor>{outputTensor};
auto inputs = std::vector<at::Tensor>{inputTensor};
return collective(
inputs, outputs,
[&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
DIPUStream& stream) {
std::vector<size_t> outputCounts(size_);
std::vector<size_t> inputCounts(size_);
std::vector<size_t> outputDisplacements(size_);
std::vector<size_t> inputDisplacements(size_);
c10d::computeLengthsAndOffsets(outputSplitSizes, output, &outputCounts,
&outputDisplacements);
c10d::computeLengthsAndOffsets(inputSplitSizes, input, &inputCounts,
&inputDisplacements);
RECORD_FUNCTION("DiclAlltoAllUnequalSplit",
std::vector<c10::IValue>({input}));
profile::RecordBlockCreator _("DiclAlltoAllUnequalSplit",
stream.rawstream(),
static_cast<int>(stream.id()));
return devproxy::diclAllToAllUnequalSplit(
input.data_ptr(), inputCounts.data(), inputDisplacements.data(),
output.data_ptr(), outputCounts.data(), outputDisplacements.data(),
output.scalar_type(), comm, stream.rawstream(), rank_, size_);
},
OpType::ALLTOALL_BASE);
}

c10::intrusive_ptr<Work> ProcessGroupDICL::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& /* unused */) {
size_t numTensors = outputTensors.size();
TORCH_CHECK(numTensors == inputTensors.size(),
"Tensor lists must have identical length")
c10::Device device = outputTensors[0].device();
at::ScalarType dataType = outputTensors[0].scalar_type();
for (const auto i : c10::irange(numTensors)) {
check_device_single_tensor(outputTensors[i], true);
check_device_single_tensor(inputTensors[i], true);
TORCH_CHECK(device == outputTensors[i].device() &&
device == inputTensors[i].device(),
"Tensors must be on the same device")
TORCH_CHECK(dataType == outputTensors[i].scalar_type() &&
dataType == inputTensors[i].scalar_type(),
"Tensors must have identical data type")
}

// TODO(jfxu-st): For CUDA, use NCCL Group Calls for higher performance
// Ref:
// https://github.com/pytorch/pytorch/blob/f2d7f235a684c593f5a1ff2ca0b47b47274bfe85/torch/csrc/cuda/nccl.cpp#L916-L941

// TODO(jfxu-st): For the vendors that don't implement
// devapis::diclAllToAllUnequalSplit, including CUDA, we need a more
// performant fallback without using a flattened tensor for relay

std::vector<int64_t> outputSplitSizes(numTensors);
std::vector<int64_t> inputSplitSizes(numTensors);
int64_t outputFlattenedTensorSize = 0;
for (const auto i : c10::irange(numTensors)) {
outputSplitSizes[i] = outputTensors[i].numel();
inputSplitSizes[i] = inputTensors[i].numel();
outputFlattenedTensorSize += outputTensors[i].numel();
}
at::Tensor outputFlattenedTensor = native::nodispatch::empty(
{outputFlattenedTensorSize},
at::TensorOptions().device(dipu::DIPU_DEVICE_TYPE).dtype(dataType));
at::Tensor inputFlattenedTensor = at::cat(inputTensors);

auto outputs = std::vector<at::Tensor>{outputFlattenedTensor};
auto inputs = std::vector<at::Tensor>{inputFlattenedTensor};
return collective(
inputs, outputs,
[&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
DIPUStream& stream) {
std::vector<size_t> outputCounts(size_);
std::vector<size_t> inputCounts(size_);
std::vector<size_t> outputDisplacements(size_);
std::vector<size_t> inputDisplacements(size_);
c10d::computeLengthsAndOffsets(outputSplitSizes, output, &outputCounts,
&outputDisplacements);
c10d::computeLengthsAndOffsets(inputSplitSizes, input, &inputCounts,
&inputDisplacements);
RECORD_FUNCTION("DiclAlltoAllUnequalSplit",
std::vector<c10::IValue>({input}));
profile::RecordBlockCreator _("DiclAlltoAllUnequalSplit",
stream.rawstream(),
static_cast<int>(stream.id()));
return devproxy::diclAllToAllUnequalSplit(
input.data_ptr(), inputCounts.data(), inputDisplacements.data(),
output.data_ptr(), outputCounts.data(), outputDisplacements.data(),
output.scalar_type(), comm, stream.rawstream(), rank_, size_);
},
[&](std::vector<std::shared_ptr<DICLComm>>&) {},
[&](std::vector<std::shared_ptr<DICLComm>>& comms) {
DIPUStreamGuard _(comms[0]->diclStream_.unwrap());
size_t offset = 0;
for (const auto i : c10::irange(numTensors)) {
outputTensors[i].copy_(
outputs[0].slice(0, offset, offset + outputSplitSizes[i]));
offset += outputSplitSizes[i];
}
},
OpType::ALLTOALL);
}

c10::intrusive_ptr<Work> ProcessGroupDICL::send(
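
The new ProcessGroupDICL::alltoall above relays through flattened buffers: it concatenates the input list, runs the unequal-split single-tensor path on the flat tensors, and then copies slices of the flat result back into the caller's output tensors on the communication stream. The sketch below shows the same strategy in Python against the public torch.distributed API; it is an illustration of the approach, not the backend code.

    import torch
    import torch.distributed as dist

    def all_to_all_via_flattened_single(output_list, input_list):
        out_splits = [t.numel() for t in output_list]
        in_splits = [t.numel() for t in input_list]

        # Relay buffers: one flat input, one flat output.
        flat_in = torch.cat([t.reshape(-1) for t in input_list])
        flat_out = flat_in.new_empty(sum(out_splits))

        # Reuse the unequal-split single-tensor collective.
        dist.all_to_all_single(flat_out, flat_in, out_splits, in_splits)

        # Scatter the flat result back into the caller's tensors.
        offset = 0
        for out, n in zip(output_list, out_splits):
            out.copy_(flat_out[offset:offset + n].view_as(out))
            offset += n
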
dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h
@@ -219,6 +219,10 @@ class DIPU_API ProcessGroupDICL : public Backend {
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts) override;

c10::intrusive_ptr<Work> alltoall(std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) override;

c10::intrusive_ptr<Work> send(std::vector<at::Tensor>& tensors, int dstRank,
int tag) override;

13 changes: 13 additions & 0 deletions dipu/torch_dipu/csrc_dipu/runtime/distributed/c10dOps.cpp
@@ -187,6 +187,18 @@ c10::intrusive_ptr<Work> alltoall_base_dipu_(
AllToAllOptions{std::chrono::milliseconds(timeout)});
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_dipu_(
const at::TensorList& output_tensors, const at::TensorList& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group, int64_t timeout) {
auto output_tensors_vec = output_tensors.vec();
auto input_tensors_vec = input_tensors.vec();
auto work =
process_group->getBackend(dipu::DIPU_DEVICE_TYPE)
->alltoall(output_tensors_vec, input_tensors_vec,
AllToAllOptions{std::chrono::milliseconds(timeout)});
return {std::move(output_tensors_vec), work};
}

c10::intrusive_ptr<Work> barrier_dipu(
at::Tensor /* unused */, // NOLINT(performance-unnecessary-value-param)
const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -216,6 +228,7 @@ TORCH_LIBRARY_IMPL(c10d, DIPU_DEVICE_TYPE_MACRO, m) {
m.impl("reduce_scatter_", reduce_scatter_dipu_);
m.impl("_reduce_scatter_base_", _reduce_scatter_base_dipu_);
m.impl("alltoall_base_", alltoall_base_dipu_);
m.impl("alltoall_", alltoall_dipu_);
m.impl("barrier", barrier_dipu);

// not implement