[DIPU]add diff size allgather #941

Merged (3 commits) on Sep 25, 2024
Changes from all commits
57 changes: 56 additions & 1 deletion dipu/tests/python/individual_scripts/test_rt_ddp.py
@@ -10,6 +10,7 @@
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.random_shape import ShapeGenerator


def debugat(rank=0):
@@ -473,12 +474,13 @@ def demo_model_parallel(rank, world_size, port):
    cleanup()


def run_demo(demo_fn, world_size, port):
def run_demo(demo_fn, world_size, port, *args):
    mp.spawn(
        demo_fn,
        args=(
            world_size,
            port,
            *args,
        ),
        nprocs=world_size,
        join=True,
@@ -582,6 +584,57 @@ def test_get_comm_name(rank, world_size, port):
    cleanup()


def test_allgather(rank, world_size, port, seed):
    random.seed(seed)
    torch.manual_seed(seed)
    import torch_dipu

    print(f"test allgather on rank {rank} ws: {world_size}")
    setup(rank, world_size, port)
    shape_gen = ShapeGenerator(seed=seed)

    gathered_tensors = []
    expected_tensors = []
    for i in range(world_size):
        shape = shape_gen.random_shape(
            (1, 100000),
            (1, 4),
        )
        device = torch.device(f"cuda:{rank}")
        tensor = torch.rand(size=shape, dtype=torch.float16)
        tensor = tensor.to(device)
        gathered_tensors.append(torch.empty_like(tensor))
        expected_tensors.append(tensor)
    tensor_to_gather = expected_tensors[rank]
    dist.all_gather(gathered_tensors, tensor_to_gather)

    for i in range(len(gathered_tensors)):
        assert torch.allclose(gathered_tensors[i], expected_tensors[i])

    device = torch.device(f"cuda:{rank}")
    t = torch.rand((rank + 1, rank + 1), dtype=torch.float16, device=device)
    shapes = []
    for i in range(world_size):
        shapes.append((i + 2, i + 2))
    gathered_t = [
        torch.empty(shapes[i], dtype=torch.float16, device=device)
        for i in range(world_size)
    ]

    try:
        dist.all_gather(gathered_t, t)
    except RuntimeError as e:
        expected_error_message = "Tensor input and output of broadcast must have the same number of elements "
        if str(e) == expected_error_message:
            print(
                "Correct exception raised with expected error message in test_allgather."
            )
        else:
            print(f"Incorrect error message: {str(e)}")

    cleanup()


if __name__ == "__main__":
    port = random.randint(10000, 60000)
    # get device_count without "import torch_dipu"
@@ -619,6 +672,8 @@ def test_get_comm_name(rank, world_size, port):

    run_demo(test_special_group_stuck, world_size, port)

    run_demo(test_allgather, world_size, port, random.randint(0, 10000))

    # need 4 cards to run
    if world_size >= 4:
        run_demo(test_new_group, world_size, port)
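
The updated `run_demo` above just forwards any extra positional arguments through `mp.spawn` to the per-rank function; `test_allgather` uses this to receive a shared RNG seed. A minimal standalone sketch of the same pattern (illustrative only; `worker` and the literal values are not from the PR):

```python
import torch.multiprocessing as mp


def worker(rank, world_size, port, seed):
    # mp.spawn prepends the process index (rank); the rest comes from `args`.
    print(f"rank={rank} world_size={world_size} port={port} seed={seed}")


def run_demo(demo_fn, world_size, port, *args):
    # Extra positional args (e.g. a seed) are appended after world_size and port.
    mp.spawn(demo_fn, args=(world_size, port, *args), nprocs=world_size, join=True)


if __name__ == "__main__":
    run_demo(worker, 2, 29500, 1234)
```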
47 changes: 47 additions & 0 deletions dipu/tests/python/utils/random_shape.py
@@ -0,0 +1,47 @@
from typing import List, Tuple
import numpy as np


__all__ = ["ShapeGenerator"]


class ShapeGenerator:
    def __init__(self, seed=None):
        self.rng = np.random.default_rng(seed)

    def random_shape(
        self,
        numel_range: Tuple[int, int],
        rank_range: Tuple[int, int],
        retry: int = 10,
    ) -> List[int]:
        """
        Generate a random shape. Ranges are inclusive.
        """
        assert 0 < numel_range[0] <= numel_range[1]
        assert 0 < rank_range[0] <= rank_range[1]
        assert retry > 0
        rank = self.rng.integers(rank_range[0], rank_range[1], endpoint=True)
        while True:
            retry -= 1
            shape = self._try_random_shape(numel_range, rank)
            if retry <= 0 or shape.prod() in range(numel_range[0], numel_range[1] + 1):
                return shape.tolist()

    def _try_random_shape(self, numel_range: Tuple[int, int], rank: int) -> np.ndarray:
        assert 0 < numel_range[0] <= numel_range[1]
        lognumel_range = np.log(numel_range)
        lognumel = self.rng.uniform(*lognumel_range)
        logshape = self._random_partition(lognumel, rank)
        shape = np.exp(logshape).round().astype(int)
        return shape

    def _random_partition(self, total: float, part: int) -> np.ndarray:
        """
        Randomly partition a total into part parts.
        """
        assert total > 0
        assert part > 0
        parts = self.rng.random(part)
        parts /= parts.sum()
        return total * parts
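
For context, a small usage sketch of the `ShapeGenerator` helper added above (assuming `utils.random_shape` is on the import path, as in the test script): it draws a random rank within `rank_range`, splits a log-uniform element count across that many dimensions, and retries a few times to keep the product near `numel_range`.

```python
import math

from utils.random_shape import ShapeGenerator

gen = ShapeGenerator(seed=42)

# Draw a few shapes with 1 to 4 dimensions and roughly 1 to 100000 elements;
# the retry loop in random_shape keeps the element count close to that range.
for _ in range(3):
    shape = gen.random_shape(numel_range=(1, 100000), rank_range=(1, 4))
    print(shape, math.prod(shape))
```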
141 changes: 98 additions & 43 deletions dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
@@ -325,7 +325,8 @@ void check_device_single_tensor(
}

// Check that all `tensors'
void checkDeviceTensors(const std::vector<at::Tensor>& tensors) {
void checkDeviceTensors(const std::vector<at::Tensor>& tensors,
bool check_sizes_and_strides) {
if (tensors.empty()) {
TORCH_CHECK(false, "Tensor list must be nonempty");
}
@@ -348,11 +349,13 @@ void checkDeviceTensors(const std::vector<at::Tensor>& tensors) {
if (tensor.scalar_type() != first.scalar_type()) {
TORCH_CHECK(false, "Tensors must have identical type");
}
if (tensor.sizes() != first.sizes()) {
TORCH_CHECK(false, "Tensors must have identical size");
}
if (tensor.strides() != first.strides()) {
TORCH_CHECK(false, "Tensors must have identical strides");
if (check_sizes_and_strides) {
if (tensor.sizes() != first.sizes()) {
TORCH_CHECK(false, "Tensors must have identical size");
}
if (tensor.strides() != first.strides()) {
TORCH_CHECK(false, "Tensors must have identical strides");
}
}
const auto inserted = usedDevices.insert(tensor.get_device()).second;
if (!inserted) {
@@ -536,7 +539,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::pointToPoint(
c10::intrusive_ptr<Work> ProcessGroupDICL::allreduce(
std::vector<at::Tensor>& tensors, const AllreduceOptions& opts) {
// inplace in = out, every rank use both in&out.
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
std::vector<at::Tensor> tensors_cp{tensors};
return collective(
tensors_cp, tensors_cp,
@@ -565,7 +568,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::allreduce(

c10::intrusive_ptr<Work> ProcessGroupDICL::broadcast(
std::vector<at::Tensor>& tensors, const BroadcastOptions& opts) {
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
// inplace in = out, only rootRank use in.
return collective(
tensors, tensors,
@@ -587,7 +590,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::broadcast(
c10::intrusive_ptr<Work> ProcessGroupDICL::reduce(
std::vector<at::Tensor>& tensors, const ReduceOptions& opts) {
// inplace in = out, only rootRank use out.
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);

auto tensor = tensors.back();
int dev_in_group = 0;
@@ -628,7 +631,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::reduce(
int curRank = getRank();
int rootRank = static_cast<int>(opts.rootRank);
c10d::assertRootRank(raise_invalid_arg_func, rootRank, numRanks);
checkDeviceTensors(inputs);
checkDeviceTensors(inputs, true);
c10d::assertSingleElementInput(raise_invalid_arg_func, inputs);
auto input = inputs.back();
std::vector<at::Tensor> outputTensors;
@@ -684,38 +687,90 @@ std::string_view ProcessGroupDICL::getCommName(
c10::intrusive_ptr<Work> ProcessGroupDICL::allgather(
std::vector<std::vector<at::Tensor>>& outputs,
std::vector<at::Tensor>& inputs, const AllgatherOptions& opts) {
checkDeviceTensors(inputs);
// output = input * ranks, no inplace. every ranks use both in&out.
auto outputFlattened =
flatten_for_scatter_gather(outputs, inputs, this->size_);
auto inputTensor = inputs.back();
auto outputs_ = outputs.back();
const at::Tensor& first_tensor = outputs_.front();
bool same_size = std::all_of(outputs_.begin(), outputs_.end(),
[&first_tensor](const at::Tensor& t) {
return first_tensor.is_same_size(t);
});
checkDeviceTensors(inputs, same_size);
if (same_size) {
// output = input * ranks, no inplace. every ranks use both in&out.
auto outputFlattened =
flatten_for_scatter_gather(outputs, inputs, this->size_);

auto work = collective(
inputs, outputFlattened,
[&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
DIPUStream& stream) {
RECORD_FUNCTION("DiclAllgather", std::vector<c10::IValue>({input}));
profile::RecordBlockCreator _("DiclAllgather", stream.rawstream(),
static_cast<int>(stream.id()));

auto work = collective(
inputs, outputFlattened,
[&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
return devproxy::diclAllGather(input.data_ptr(), output.data_ptr(),
static_cast<size_t>(input.numel()),
input.scalar_type(), comm,
stream.rawstream());
},
[&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {},
[&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {
// Copy the flattened output tensors to the outputs.
for (size_t i = 0; i < outputs.size(); ++i) {
// warnning & todo:: copy in comm stream,
// record dest tensor outputs, because src tensor outputFlattened
// already recorded in collective.
copyInCommStream<true>(diclComms[i], outputs[i], outputFlattened[i],
static_cast<int>(outputs[i].size()));
// copyInCurrentStream(diclComms[i], outputs[i],
// outputFlattened[i]);
}
},
OpType::ALLGATHER);
return work;
}
const auto num_broadcasts = outputs_.size();
for (const auto i : c10::irange(num_broadcasts)) {
auto& outCheck = outputs_[i];
auto& inCheck = (i == this->rank_) ? inputTensor : outCheck;
if (outCheck.numel() != inCheck.numel()) {
throw std::runtime_error(
"Tensor input and output of broadcast must have the "
"same number of elements ");
}
}
return collective(
inputs, outputs_,
[&](at::Tensor& /* unused */, at::Tensor& /* unused */, diclComm_t comm,
DIPUStream& stream) {
RECORD_FUNCTION("DiclAllgather", std::vector<c10::IValue>({input}));
profile::RecordBlockCreator _("DiclAllgather", stream.rawstream(),
static_cast<int>(stream.id()));

return devproxy::diclAllGather(input.data_ptr(), output.data_ptr(),
static_cast<size_t>(input.numel()),
input.scalar_type(), comm,
stream.rawstream());
},
[&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {},
[&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {
// Copy the flattened output tensors to the outputs.
for (size_t i = 0; i < outputs.size(); ++i) {
// warnning & todo:: copy in comm stream,
// record dest tensor outputs, because src tensor outputFlattened
// already recorded in collective.
copyInCommStream<true>(diclComms[i], outputs[i], outputFlattened[i],
static_cast<int>(outputs[i].size()));
// copyInCurrentStream(diclComms[i], outputs[i], outputFlattened[i]);
for (const auto i : c10::irange(num_broadcasts)) {
auto& outTensor = outputs_[i];
const auto root = static_cast<int64_t>(i);
// Recording the per-rank output tensors here is needed only because we
// reuse the generic collective() call; the work logic remains correct.
dipu::recordStream(outTensor, stream);
auto& inTensor = i == this->rank_ ? inputTensor : outTensor;
RECORD_FUNCTION("DiclBroadcast",
std::vector<c10::IValue>({inTensor}));
profile::RecordBlockCreator _("DiclBroadcast", stream.rawstream(),
static_cast<int>(stream.id()));
devproxy::diclBroadcast(
inTensor.data_ptr(), outTensor.data_ptr(),
static_cast<size_t>(inTensor.numel()), inTensor.scalar_type(),
static_cast<int>(root), comm, stream.rawstream());
#if DIPU_VENDOR_NAME_ASCEND
if (i == this->rank_) {
devproxy::memCopyD2DAsync(
stream.rawstream(),
static_cast<size_t>(inTensor.numel()) *
static_cast<size_t>(inTensor.element_size()),
i, outTensor.data_ptr(), i, inTensor.data_ptr());
}
#endif
}
},
OpType::ALLGATHER);
return work;
OpType::BROADCAST);
}

c10::intrusive_ptr<Work> ProcessGroupDICL::_allgather_base(
@@ -787,7 +842,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::scatter(
int curRank = getRank();
int rootRank = static_cast<int>(opts.rootRank);
c10d::assertRootRank(raise_invalid_arg_func, rootRank, numRanks);
checkDeviceTensors(outputs);
checkDeviceTensors(outputs, true);
c10d::assertSingleElementOutput(raise_invalid_arg_func, outputs);
auto output = outputs.back();
std::vector<at::Tensor> inputTensors;
@@ -838,10 +893,10 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::reduce_scatter(
std::vector<std::vector<at::Tensor>>& inputs,
const ReduceScatterOptions& opts) {
// input = output * ranks, no inplace, output = reduced(input)[rank]
checkDeviceTensors(outputs);
checkDeviceTensors(outputs, true);
auto inputFlattened =
flatten_for_scatter_gather(inputs, outputs, this->size_);
checkDeviceTensors(inputFlattened);
checkDeviceTensors(inputFlattened, true);

auto work = collective(
inputFlattened, outputs,
@@ -1014,7 +1069,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::alltoall(

c10::intrusive_ptr<Work> ProcessGroupDICL::send(
std::vector<at::Tensor>& tensors, int dstRank, int tag) {
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
auto p2pPair = mapPGRank2P2P(rank_, dstRank);
return pointToPoint(
tensors, tensors, dstRank,
@@ -1033,7 +1088,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::send(

c10::intrusive_ptr<Work> ProcessGroupDICL::recv(
std::vector<at::Tensor>& tensors, int srcRank, int tag) {
checkDeviceTensors(tensors);
checkDeviceTensors(tensors, true);
auto p2pPair = mapPGRank2P2P(rank_, srcRank);
return pointToPoint(
tensors, tensors, srcRank,
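
Taken together, the uneven-size path added to `ProcessGroupDICL::allgather` behaves like one broadcast per rank, with rank `i` acting as root for the `i`-th output slot. A minimal user-level sketch of the equivalent behavior using plain `torch.distributed` calls (illustrative only, not the DIPU implementation; it assumes a process group is already initialized and that `output_list[i]` already has rank `i`'s shape):

```python
import torch.distributed as dist


def uneven_all_gather_as_broadcasts(output_list, input_tensor, rank):
    """Gather differently shaped tensors by issuing one broadcast per rank."""
    world_size = dist.get_world_size()
    for i in range(world_size):
        if i == rank:
            # Fill this rank's own slot, then broadcast it to everyone else.
            output_list[i].copy_(input_tensor)
        dist.broadcast(output_list[i], src=i)
```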