Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCX plugin: Add multi-receive posting to avoid excessive flush #125

Merged
merged 11 commits into from
Oct 18, 2023
2 changes: 1 addition & 1 deletion src/p2p_plugin.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2016-2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
Expand Down
83 changes: 41 additions & 42 deletions src/ucx_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,18 @@ typedef struct connect_msg {
size_t addr_len;
} connect_msg_t;

struct ucx_comm;

/**
* Batch of UCX Requests from NCCL perspective
*/
typedef struct ucx_request {
ucp_worker_h worker; /* Worker for all requests */
int pending; /* How many requests are still pending */
int count; /* How many requests are contained */
int used; /* Allocation status */

int size[NCCL_NET_IB_MAX_RECVS];
struct ucx_request *next; /* Next request in the free list */
struct ucx_comm *comm; /* Owning communicator */
ucp_worker_h worker; /* Worker for all requests */
int pending; /* How many requests are still pending */
int count; /* How many requests are contained */
int size[NCCL_NET_IB_MAX_RECVS];
} ucx_request_t;

struct ep_list {
Expand Down Expand Up @@ -179,7 +181,7 @@ typedef struct ucx_comm {
struct ncclSocket sock; /* socket for OOB connection */
int ready; /* indicates that receive communicator is fully initialized */
ucx_request_t reqs[MAX_REQUESTS]; /* max inflight requests */

ucx_request_t *free_req; /* first request available */
connect_msg_t *msg; /* message to establish reverse connection */
void *connect_req; /* msg request */
} ucx_comm_t;
Expand Down Expand Up @@ -390,8 +392,11 @@ ncclResult_t nccl_ucx_listen(int dev, void *handle, void **listen_comm) {
static void ucx_request_init(ucx_comm_t *comm) {
static const int entries = sizeof(comm->reqs) / sizeof(*comm->reqs);

for (int i = 0; i < entries; i++) {
comm->reqs[i].used = 0;
comm->free_req = NULL;
for (int i = entries - 1; i >= 0; i--) {
comm->reqs[i].comm = comm;
comm->reqs[i].next = comm->free_req;
comm->free_req = &comm->reqs[i];
}
}

Expand Down Expand Up @@ -564,23 +569,22 @@ static ucx_request_t *ucx_request_get(ucx_comm_t *comm) {
static const size_t entries = sizeof(comm->reqs) / sizeof(*comm->reqs);
ucx_request_t *req;

for (int i = 0; i < entries; i++) {
req = &comm->reqs[i];
if (req->used == 0) {
req->worker = comm->worker;
req->pending = 0;
req->count = 0;
req->used = 1;
return req;
}
req = comm->free_req;
if (req == NULL) {
WARN("NET/UCX: unable to allocate NCCL request");
return NULL;
}

WARN("NET/UCX: unable to allocate NCCL request");
return NULL;
comm->free_req = req->next;
req->worker = comm->worker;
req->pending = 0;
req->count = 0;
return req;
}

static void ucx_request_release(ucx_request_t *req) {
req->used = 0;
req->next = req->comm->free_req;
req->comm->free_req = req;
}

static void ucx_request_add(ucx_request_t *req, int size) {
Expand All @@ -596,7 +600,7 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) {
connect_msg_t *msg;
ucp_ep_params_t ep_params;
void *ucp_req;
int pending;
ucs_status_t status;

ucp_worker_progress(comm->worker);

Expand All @@ -605,27 +609,21 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) {
return ncclSuccess;
}

pending = 1;
msg = malloc(info_tag.length);

params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
params.cb.recv = recv_handler_nbx;
params.user_data = &pending;
params.op_attr_mask = 0;
ucp_req = ucp_tag_msg_recv_nbx(comm->worker, msg, info_tag.length,
msg_tag, &params);
if (UCS_PTR_IS_ERR(ucp_req)) {
WARN("Unable to receive connect msg (%s)",
ucs_status_string(UCS_PTR_STATUS(ucp_req)));
free(msg);
return ncclSystemError;
} else if (ucp_req == NULL) {
pending--;
}

while (pending > 0) {
ucp_worker_progress(comm->worker);
} else if (ucp_req != NULL) {
do {
ucp_worker_progress(comm->worker);
status = ucp_request_check_status(ucp_req);
} while (status == UCS_INPROGRESS);
assert(status == UCS_OK);
}

ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
Expand Down Expand Up @@ -770,16 +768,19 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data,
return ncclInternalError;
}

params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA;
params.cb.recv = recv_handler_nbx;
params.user_data = &req->pending;

for (int i = 0; i < n; i++) {
ucx_request_add(req, sizes[i]);

params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA;
params.cb.recv = recv_handler_nbx;
params.user_data = &req->pending;
if (mh[i]) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.memory_type = mh[i]->mem_type;
} else {
params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMORY_TYPE;
}

ucp_req = ucp_tag_recv_nbx(comm->worker, data[i], sizes[i],
Expand Down Expand Up @@ -852,9 +853,7 @@ static ncclResult_t nccl_ucx_test(void *request, int *done, int *size) {
*done = 1;
if (size != NULL) {
/* Posted receives have completed */
for (int i = 0; i < req->count; i++) {
size[i] = req->size[i];
}
memcpy(size, req->size, sizeof(*size) * req->count);
}

ucx_request_release(req);
Expand Down