Post collectives from progress thread

Sergei-Lebedev committed Jul 7, 2021
1 parent d327b1e commit 26aa619
Showing 3 changed files with 133 additions and 28 deletions.
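In outline: enqueue_collective (and its CUDA counterpart) no longer call ucc_collective_init/ucc_collective_post on the caller's thread. They queue a WorkUCC in UCC_OPERATION_INITIALIZED state onto the new post_queue_cpu/post_queue, and the progress thread now both posts and progresses collectives, so UCC and UCX are only ever entered from a single thread. A minimal standalone sketch of that producer/consumer shape — simplified stand-ins, not the actual torch_ucc types:

#include <condition_variable>
#include <deque>
#include <mutex>

// Sketch only: Work stands in for WorkUCC; setting status stands in for
// ucc_collective_init + ucc_collective_post.
struct Work {
  int status = 0;  // 0 = initialized, 1 = posted/in progress
};

std::mutex mtx;
std::condition_variable produce_cv;
std::deque<Work*> post_queue;      // filled by caller threads
std::deque<Work*> progress_queue;  // drained by the progress thread

Work* enqueue(Work* w) {           // caller: no UCC calls, just a hand-off
  std::unique_lock<std::mutex> lk(mtx);
  post_queue.push_back(w);
  lk.unlock();
  produce_cv.notify_one();
  return w;
}

void progress_loop(const bool& stop) {  // sole thread that enters UCC/UCX
  std::unique_lock<std::mutex> lk(mtx);
  while (!stop) {
    if (post_queue.empty() && progress_queue.empty()) {
      produce_cv.wait(lk);
      continue;
    }
    while (!post_queue.empty()) {       // post first, then progress
      Work* w = post_queue.front();
      post_queue.pop_front();
      w->status = 1;                    // real code: init + post, then
      progress_queue.push_back(w);      // status_ = UCC_INPROGRESS
    }
    // ... pop progress_queue and drive each request to completion ...
  }
}

This is also why WorkUCC::wait now spins until status_ leaves UCC_OPERATION_INITIALIZED: the underlying request only exists once the progress thread has actually posted it.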
11 changes: 10 additions & 1 deletion include/torch_ucc.hpp

@@ -146,6 +146,7 @@ class ProcessGroupUCC : public ProcessGroup {
         ucc_ee_h ee,
         CommBase* comm)
       : ProcessGroup::Work(-1, opType),
+        ee_(ee),
         status_(status),
         request_(request),
         comm_(comm) {}
@@ -163,8 +164,12 @@ class ProcessGroupUCC : public ProcessGroup {
 #ifdef USE_CUDA
   std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
   event_pool_t* ep = nullptr;
+  at::cuda::CUDAStream* stream;
 #endif
  protected:
+  ucc_coll_args_t args;
+  ucc_team_h team;
+  ucc_ee_h ee_;
   ucc_status_t status_;
   ucc_coll_req_h request_;
   CommBase* comm_;
@@ -297,8 +302,12 @@ class CommPG {
   std::condition_variable queue_produce_cv;
   std::condition_variable queue_consume_cv;
   std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> progress_queue;
+  std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> post_queue;
+  std::deque<c10::intrusive_ptr<ProcessGroupUCC::WorkUCC>> post_queue_cpu;
   bool stop_progress_loop;

+  void post_cpu_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req);
+  void post_cuda_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req);
  public:
   c10::DeviceIndex cuda_device_index;
   CommPG(torch_ucc_oob_coll_info_t* oob_info,
@@ -332,7 +341,7 @@ class CommPG {
       ucc_team_h& team,
       ucc_ee_h ee,
       std::unique_ptr<at::cuda::CUDAEvent> cuda_ev,
-      const at::cuda::CUDAStream& stream,
+      at::cuda::CUDAStream& stream,
       event_pool_t* ep);
 #endif
141 changes: 121 additions & 20 deletions src/torch_ucc.cpp

@@ -134,6 +134,7 @@ bool ProcessGroupUCC::WorkUCC::isSuccess() const {
 }

 bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
+  while (status_ == UCC_OPERATION_INITIALIZED) {}
 #ifdef USE_CUDA
   if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
     // block user stream
@@ -202,7 +203,7 @@ CommPG::CommPG(torch_ucc_oob_coll_info_t* oob_info,

 CommPG::~CommPG() {
   std::unique_lock<std::mutex> lock(mutex);
-  queue_consume_cv.wait(lock, [&] { return progress_queue.empty(); });
+  queue_consume_cv.wait(lock, [&] { return progress_queue.empty() && post_queue_cpu.empty() && post_queue.empty(); });
   stop_progress_loop = true;
   lock.unlock();
   queue_produce_cv.notify_all();
@@ -409,20 +410,10 @@ c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_collective(
     ucc_coll_args_t& coll,
     std::unique_ptr<ProcessGroupUCC::WorkData> data,
     ucc_team_h& team) {
-  ucc_coll_req_h request;
-  ucc_status_t st;
-  st = ucc_collective_init(&coll, &request, team);
-  if (st != UCC_OK) {
-    LOG(ERROR) << "failed to init collective: " << ucc_status_string(st);
-    throw std::runtime_error(ucc_status_string(st));
-  }
-  st = ucc_collective_post(request);
-  if (st != UCC_OK) {
-    LOG(ERROR) << "failed to post collective: " << ucc_status_string(st);
-    throw std::runtime_error(ucc_status_string(st));
-  }
   auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
-      opType, UCC_INPROGRESS, request, nullptr, &ucc_comm);
+      opType, UCC_OPERATION_INITIALIZED, nullptr, nullptr, &ucc_comm);
+  work->args = coll;
+  work->team = team;
   work->data = std::move(data);
 #ifdef USE_UCC_FUTURE
   if (torch_ucc_config.use_future) {
@@ -431,13 +422,63 @@
   }
 #endif
   std::unique_lock<std::mutex> lock(mutex);
-  progress_queue.push_back(work);
+  post_queue_cpu.push_back(work);
   lock.unlock();
   queue_produce_cv.notify_one();

   return work;
 }

+#ifdef USE_CUDA
+c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_cuda_collective(
+    OpType opType,
+    ucc_coll_args_t& coll,
+    std::unique_ptr<ProcessGroupUCC::WorkData> data,
+    ucc_team_h& team,
+    ucc_ee_h ee,
+    std::unique_ptr<at::cuda::CUDAEvent> cuda_ev,
+    at::cuda::CUDAStream& stream,
+    event_pool_t* ep) {
+  auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
+      opType, UCC_OPERATION_INITIALIZED, nullptr, ee, &ucc_comm);
+  work->args = coll;
+  work->team = team;
+  work->ep = ep;
+  work->data = std::move(data);
+  work->fence = std::move(cuda_ev);
+  work->stream = &stream;
+  std::unique_lock<std::mutex> lock(mutex);
+  post_queue.push_back(work);
+  lock.unlock();
+  queue_produce_cv.notify_one();
+  return work;
+
+  // ucc_ev_t comp_ev, *post_ev;
+  // comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
+  // comp_ev.ev_context = nullptr;
+  // comp_ev.ev_context_size = 0;
+  // comp_ev.req = request;
+  // st = ucc_collective_triggered_post(ee, &comp_ev);
+  // if (st != UCC_OK) {
+  //   LOG(ERROR) << "failed to post triggered collective: "
+  //              << ucc_status_string(st);
+  //   throw std::runtime_error(ucc_status_string(st));
+  // }
+  // st = ucc_ee_get_event(ee, &post_ev);
+  // TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
+  // ucc_ee_ack_event(ee, post_ev);
+  // auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
+  //     opType, UCC_INPROGRESS, request, ee, &ucc_comm);
+  // cuda_ev->record(stream);
+  // std::unique_lock<std::mutex> lock(mutex);
+  // progress_queue.push_back(work);
+  // lock.unlock();
+  // queue_produce_cv.notify_one();
+  // return work;
+}
+#endif
+
+#if 0
 c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_cuda_collective(
     OpType opType,
     ucc_coll_args_t& coll,
@@ -480,28 +521,88 @@ c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> CommPG::enqueue_cuda_collective(
   queue_produce_cv.notify_one();
   return work;
 }

+#endif
+
+void CommPG::post_cpu_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req) {
+  ucc_status_t st;
+
+  st = ucc_collective_init(&post_req->args, &post_req->request_, post_req->team);
+  if (st != UCC_OK) {
+    LOG(ERROR) << "failed to init collective: " << ucc_status_string(st);
+    throw std::runtime_error(ucc_status_string(st));
+  }
+
+  st = ucc_collective_post(post_req->request_);
+  if (st != UCC_OK) {
+    LOG(ERROR) << "failed to post collective: " << ucc_status_string(st);
+    throw std::runtime_error(ucc_status_string(st));
+  }
+  progress_queue.push_back(post_req);
+  post_req->status_ = UCC_INPROGRESS;
+}
+
+void CommPG::post_cuda_collective(c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> post_req) {
+  ucc_status_t st;
+  ucc_ev_t comp_ev, *post_ev;
+
+  st = ucc_collective_init(&post_req->args, &post_req->request_, post_req->team);
+  if (st != UCC_OK) {
+    LOG(ERROR) << "failed to init collective: " << ucc_status_string(st);
+    throw std::runtime_error(ucc_status_string(st));
+  }
+
+  comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
+  comp_ev.ev_context = nullptr;
+  comp_ev.ev_context_size = 0;
+  comp_ev.req = post_req->request_;
+  st = ucc_collective_triggered_post(post_req->ee_, &comp_ev);
+  if (st != UCC_OK) {
+    LOG(ERROR) << "failed to post triggered collective: "
+               << ucc_status_string(st);
+    throw std::runtime_error(ucc_status_string(st));
+  }
+  st = ucc_ee_get_event(post_req->ee_, &post_ev);
+  TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
+  ucc_ee_ack_event(post_req->ee_, post_ev);
+  post_req->fence->record(*(post_req->stream));
+  progress_queue.push_back(post_req);
+  post_req->status_ = UCC_INPROGRESS;
+}
+
 void CommPG::progress_loop() {
   std::unique_lock<std::mutex> lock(mutex);
 #ifdef USE_CUDA
   bool device_set = false;
 #endif
   while (!stop_progress_loop) {
-    if (progress_queue.empty()) {
+    if (post_queue.empty() && progress_queue.empty() && post_queue_cpu.empty()) {
       queue_produce_cv.wait(lock);
       continue;
     }
-    auto work = progress_queue.front();
-    progress_queue.pop_front();
-    lock.unlock();
-    queue_consume_cv.notify_one();
 #ifdef USE_CUDA
     if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
       c10::cuda::set_device(cuda_device_index);
       device_set = true;
     }
 #endif
+    while (!post_queue.empty()) {
+      post_cuda_collective(post_queue.front());
+      post_queue.pop_front();
+    }
+    while (!post_queue_cpu.empty()) {
+      post_cpu_collective(post_queue_cpu.front());
+      post_queue_cpu.pop_front();
+    }
+
+    lock.unlock();
+    queue_consume_cv.notify_one();
+    auto work = progress_queue.front();
+    progress_queue.pop_front();
+    while (work->request_->status > 0) {
+      work->comm_->progress();
+    }

     try {
       while (work->request_->status > 0) {
         // operation initialized is in progress or
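A note on the CUDA path added above: post_cuda_collective does not post the collective immediately; it arms the request on the UCC execution engine bound to the CUDA stream, so the collective triggers once the preceding compute on that stream completes, and the recorded fence lets user streams block on the result. Condensed to the bare execution-engine handshake, using the same field names as the diff, with error handling dropped:

#include <ucc/api/ucc.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

// Sketch of the triggered-post sequence in post_cuda_collective above.
void triggered_post(ucc_ee_h ee, ucc_coll_req_h request,
                    at::cuda::CUDAEvent& fence,
                    at::cuda::CUDAStream& stream) {
  ucc_ev_t comp_ev = {};
  ucc_ev_t* post_ev = nullptr;
  comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
  comp_ev.req = request;                        // from ucc_collective_init()
  ucc_collective_triggered_post(ee, &comp_ev);  // arm the collective on the ee
  ucc_ee_get_event(ee, &post_ev);               // take the COLLECTIVE_POST event
  ucc_ee_ack_event(ee, post_ev);                // return the event to the ee
  fence.record(stream);                         // user streams can sync on this
}

The commented-out block inside the new enqueue_cuda_collective is this same sequence as it looked before being moved behind post_queue.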
9 changes: 2 additions & 7 deletions src/torch_ucc_comm.cpp

@@ -43,7 +43,7 @@ CommUCX::CommUCX(int comm_size) {
   }
   memset(&worker_params, 0, sizeof(ucp_worker_params_t));
   worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
-  worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
+  worker_params.thread_mode = UCS_THREAD_MODE_SINGLE;
   st = ucp_worker_create(context, &worker_params, &worker);
   if (st != UCS_OK) {
     LOG(ERROR) << "failed to create UCP worker: " << ucs_status_string(st);
@@ -133,7 +133,7 @@ CommUCC::CommUCC(torch_ucc_oob_coll_info_t* oob_info) {
   }
   memset(&lib_params, 0, sizeof(ucc_lib_params_t));
   lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
-  lib_params.thread_mode = UCC_THREAD_MULTIPLE;
+  lib_params.thread_mode = UCC_THREAD_SINGLE;
   st = ucc_init(&lib_params, lib_config, &lib);
   ucc_lib_config_release(lib_config);
   if (st != UCC_OK) {
@@ -147,11 +147,6 @@
     LOG(ERROR) << "failed to query for lib attr: " << ucc_status_string(st);
     throw std::runtime_error(ucc_status_string(st));
   }
-  if (lib_attr.thread_mode != UCC_THREAD_MULTIPLE) {
-    LOG(ERROR) << "ucc library wasn't initialized with mt support "
-               << "check ucc compile options ";
-    throw std::runtime_error("failed to init ucc lib");
-  }
   st = ucc_context_config_read(lib, NULL, &context_config);
   if (st != UCC_OK) {
     ucc_finalize(lib);
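Since the progress thread is now the only thread entering UCX and UCC, the worker and library can be initialized in single-thread mode, and the old check that UCC was built with multi-thread support is dropped. A minimal sketch of the UCC side under that assumption (default config, error handling omitted):

#include <ucc/api/ucc.h>

// Sketch: single-thread-mode UCC init, matching the diff's lib_params.
ucc_lib_h init_ucc_single_thread() {
  ucc_lib_config_h cfg;
  ucc_lib_h lib = nullptr;
  ucc_lib_params_t params = {};
  params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
  params.thread_mode = UCC_THREAD_SINGLE;  // one thread drives all UCC calls
  if (ucc_lib_config_read(NULL, NULL, &cfg) == UCC_OK) {
    ucc_init(&params, cfg, &lib);
    ucc_lib_config_release(cfg);
  }
  return lib;
}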
