[DIPU]add diff size allgather #941

Merged: 3 commits, Sep 25, 2024
Changes from 1 commit
35 changes: 34 additions & 1 deletion dipu/tests/python/individual_scripts/test_rt_ddp.py
@@ -10,6 +10,7 @@
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.random_shape import ShapeGenerator


def debugat(rank=0):
@@ -473,12 +474,13 @@ def demo_model_parallel(rank, world_size, port):
    cleanup()


def run_demo(demo_fn, world_size, port):
def run_demo(demo_fn, world_size, port, *args):
    mp.spawn(
        demo_fn,
        args=(
            world_size,
            port,
            *args,
        ),
        nprocs=world_size,
        join=True,
@@ -582,6 +584,35 @@ def test_get_comm_name(rank, world_size, port):
    cleanup()


def test_allgather(rank, world_size, port, seed):
    random.seed(seed)
    torch.manual_seed(seed)
    import torch_dipu

    print(f"test allgather on rank {rank} ws: {world_size}")
    setup(rank, world_size, port)
    shape_gen = ShapeGenerator(seed=seed)

    gathered_tensors = []
    expected_tensors = []
    for i in range(world_size):
        shape = shape_gen.random_shape(
            (1, 100000),
            (1, 4),
        )
        device = torch.device(f"cuda:{rank}")
        tensor = torch.rand(size=shape, dtype=torch.float16)
        tensor = tensor.to(device)
        gathered_tensors.append(torch.empty_like(tensor))
        expected_tensors.append(tensor)
    tensor_to_gather = expected_tensors[rank]
    dist.all_gather(gathered_tensors, tensor_to_gather)

    for i in range(len(gathered_tensors)):
        assert torch.allclose(gathered_tensors[i], expected_tensors[i])
    cleanup()


if __name__ == "__main__":
    port = random.randint(10000, 60000)
    # get device_count without "import torch_dipu"
@@ -619,6 +650,8 @@ def test_get_comm_name(rank, world_size, port):

    run_demo(test_special_group_stuck, world_size, port)

    run_demo(test_allgather, world_size, port, random.randint(0, 10000))

    # need 4 cards to run
    if world_size >= 4:
        run_demo(test_new_group, world_size, port)
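
The updated `run_demo` above forwards arbitrary trailing positional arguments through `mp.spawn`, which is how `test_allgather` receives its shared random seed. Below is a minimal, self-contained sketch of that forwarding pattern; `_worker` and its `seed` parameter are hypothetical stand-ins, and only the argument packing mirrors the diff.

```python
# Sketch of the *args forwarding used by run_demo above. mp.spawn always
# prepends the process index (rank) as the first argument of the target.
import torch.multiprocessing as mp


def _worker(rank, world_size, port, seed):
    # Every rank receives the same trailing arguments, so seeding from them
    # keeps random state consistent across the spawned processes.
    print(f"rank={rank} world_size={world_size} port={port} seed={seed}")


def run_demo_sketch(demo_fn, world_size, port, *args):
    mp.spawn(
        demo_fn,
        args=(world_size, port, *args),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    run_demo_sketch(_worker, 2, 29500, 1234)
```
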
47 changes: 47 additions & 0 deletions dipu/tests/python/utils/random_shape.py
@@ -0,0 +1,47 @@
from typing import List, Tuple
import numpy as np


__all__ = ["ShapeGenerator"]


class ShapeGenerator:
    def __init__(self, seed=None):
        self.rng = np.random.default_rng(seed)

    def random_shape(
        self,
        numel_range: Tuple[int, int],
        rank_range: Tuple[int, int],
        retry: int = 10,
    ) -> List[int]:
        """
        Generate a random shape. Ranges are inclusive.
        """
        assert 0 < numel_range[0] <= numel_range[1]
        assert 0 < rank_range[0] <= rank_range[1]
        assert retry > 0
        rank = self.rng.integers(rank_range[0], rank_range[1], endpoint=True)
        while True:
            retry -= 1
            shape = self._try_random_shape(numel_range, rank)
            if retry <= 0 or shape.prod() in range(numel_range[0], numel_range[1] + 1):
                return shape.tolist()

    def _try_random_shape(self, numel_range: Tuple[int, int], rank: int) -> np.ndarray:
        assert 0 < numel_range[0] <= numel_range[1]
        lognumel_range = np.log(numel_range)
        lognumel = self.rng.uniform(*lognumel_range)
        logshape = self._random_partition(lognumel, rank)
        shape = np.exp(logshape).round().astype(int)
        return shape

    def _random_partition(self, total: float, part: int) -> np.ndarray:
        """
        Randomly partition a total into `part` parts.
        """
        assert total > 0
        assert part > 0
        parts = self.rng.random(part)
        parts /= parts.sum()
        return total * parts
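
For reference, a small usage sketch of the `ShapeGenerator` added above; the import path is assumed from the test's `from utils.random_shape import ShapeGenerator`. Because `_try_random_shape` rounds each dimension after exponentiating, the product can drift slightly outside `numel_range`; `random_shape` retries up to `retry` times and otherwise returns the last candidate.

```python
# Usage sketch for ShapeGenerator (import path assumed, see note above).
import math

from utils.random_shape import ShapeGenerator

gen = ShapeGenerator(seed=0)
for _ in range(3):
    # Request numel in [1, 100000] (best effort) and rank in [1, 4].
    shape = gen.random_shape((1, 100000), (1, 4))
    print(shape, math.prod(shape))
```
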
104 changes: 73 additions & 31 deletions dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
@@ -401,6 +401,15 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
  return flattened;
}

bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
  for (const auto& input_tensor : input_tensors) {
    if (!input_tensors[0].is_same_size(input_tensor)) {
      return false;
    }
  }
  return true;
}

template <bool RecordDest, typename Dest, typename Src>
void copyInCommStream(std::shared_ptr<DICLComm>& diclComm, const Dest& dest,
                      const Src& src, int nums) {
@@ -684,38 +693,71 @@ std::string_view ProcessGroupDICL::getCommName(
c10::intrusive_ptr<Work> ProcessGroupDICL::allgather(
    std::vector<std::vector<at::Tensor>>& outputs,
    std::vector<at::Tensor>& inputs, const AllgatherOptions& opts) {
  checkDeviceTensors(inputs);
  // output = input * ranks, no inplace. Every rank uses both in & out.
  auto outputFlattened =
      flatten_for_scatter_gather(outputs, inputs, this->size_);

  auto work = collective(
      inputs, outputFlattened,
      [&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
          DIPUStream& stream) {
        RECORD_FUNCTION("DiclAllgather", std::vector<c10::IValue>({input}));
        profile::RecordBlockCreator _("DiclAllgather", stream.rawstream(),
                                      static_cast<int>(stream.id()));
  auto inputTensor = inputs.back();
  auto outputs_ = outputs.back();
  bool same_size = check_same_size(outputs_);
  if (same_size) {
    checkDeviceTensors(inputs);
    // output = input * ranks, no inplace. Every rank uses both in & out.
    auto outputFlattened =
        flatten_for_scatter_gather(outputs, inputs, this->size_);

    auto work = collective(
        inputs, outputFlattened,
        [&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
            DIPUStream& stream) {
          RECORD_FUNCTION("DiclAllgather", std::vector<c10::IValue>({input}));
          profile::RecordBlockCreator _("DiclAllgather", stream.rawstream(),
                                        static_cast<int>(stream.id()));

        return devproxy::diclAllGather(input.data_ptr(), output.data_ptr(),
                                       static_cast<size_t>(input.numel()),
                                       input.scalar_type(), comm,
                                       stream.rawstream());
      },
      [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {},
      [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {
        // Copy the flattened output tensors to the outputs.
        for (size_t i = 0; i < outputs.size(); ++i) {
          // warning & todo: copy in comm stream,
          // record dest tensor outputs, because src tensor outputFlattened
          // already recorded in collective.
          copyInCommStream<true>(diclComms[i], outputs[i], outputFlattened[i],
                                 static_cast<int>(outputs[i].size()));
          // copyInCurrentStream(diclComms[i], outputs[i], outputFlattened[i]);
        }
      },
      OpType::ALLGATHER);
  return work;
          return devproxy::diclAllGather(input.data_ptr(), output.data_ptr(),
                                         static_cast<size_t>(input.numel()),
                                         input.scalar_type(), comm,
                                         stream.rawstream());
        },
        [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {},
        [&](std::vector<std::shared_ptr<DICLComm>>& diclComms) {
          // Copy the flattened output tensors to the outputs.
          for (size_t i = 0; i < outputs.size(); ++i) {
            // warning & todo: copy in comm stream,
            // record dest tensor outputs, because src tensor outputFlattened
            // already recorded in collective.
            copyInCommStream<true>(diclComms[i], outputs[i], outputFlattened[i],
                                   static_cast<int>(outputs[i].size()));
            // copyInCurrentStream(diclComms[i], outputs[i],
            // outputFlattened[i]);
          }
        },
        OpType::ALLGATHER);
    return work;
  } else {
    const auto num_reduces = outputs_.size();
    return collective(
        inputs, outputs_,
        [&](at::Tensor& input, at::Tensor& output, diclComm_t comm,
            DIPUStream& stream) {
          for (const int i : c10::irange(num_reduces)) {
            auto& opt = outputs_[i];
            auto& ipt = (i == this->rank_) ? inputTensor : opt;
            auto broadcastOpts = BroadcastOptions{
                static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
            if (opt.numel() != ipt.numel()) {
              throw std::runtime_error(
                  "Tensor input and output of _broadcast_oop must have the "
                  "same number of elements ");
            }
            RECORD_FUNCTION("DiclBroadcast", std::vector<c10::IValue>({ipt}));
            profile::RecordBlockCreator _("DiclBroadcast", stream.rawstream(),
                                          static_cast<int>(stream.id()));
            const auto root = broadcastOpts.rootRank + broadcastOpts.rootTensor;
            devproxy::diclBroadcast(ipt.data_ptr(), opt.data_ptr(),
                                    static_cast<size_t>(ipt.numel()),
                                    ipt.scalar_type(), static_cast<int>(root),
                                    comm, stream.rawstream());
          }
        },
        OpType::BROADCAST);
  }
}

c10::intrusive_ptr<Work> ProcessGroupDICL::_allgather_base(
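
To summarize the `ProcessGroupDICL::allgather` change: when all output tensors in a rank's list share one size, the original flatten-then-`diclAllGather` path is kept; when sizes differ, the collective falls back to one broadcast per rank, with rank i acting as the source of the i-th output slot. The sketch below illustrates that fallback semantics with plain `torch.distributed` calls; it assumes a process group is already initialized and is not the DICL code path itself.

```python
# Illustrative sketch: an all_gather over unevenly sized tensors behaves like
# world_size broadcasts, each rooted at a different rank.
# Assumes dist.init_process_group(...) has already been called.
import torch.distributed as dist


def allgather_uneven_sketch(outputs, input_tensor):
    rank = dist.get_rank()
    for i, out in enumerate(outputs):
        if i == rank:
            # Mirrors the numel check guarding the broadcast rooted at this rank.
            if out.numel() != input_tensor.numel():
                raise RuntimeError("root rank's output slot must match its input numel")
            out.copy_(input_tensor)
        # Every rank participates; rank i is the source of outputs[i].
        dist.broadcast(out, src=i)
```

The error string in the C++ branch ("_broadcast_oop must have the same number of elements") suggests this mirrors the out-of-place-broadcast fallback used by PyTorch's NCCL process group for unevenly sized allgather outputs.
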