Skip to content

Commit

Permalink
Address review: add freelist and various code fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tvegas1 committed Oct 11, 2023
1 parent 90a623b commit d6ff08f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/p2p_plugin.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2016-2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
Expand Down
83 changes: 41 additions & 42 deletions src/ucx_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,18 @@ typedef struct connect_msg {
size_t addr_len;
} connect_msg_t;

struct ucx_comm;

/**
* Batch of UCX Requests from NCCL perspective
*
* Entries live in the fixed reqs[] array of a ucx_comm. They are handed out
* and returned through that communicator's free_req list, chained via next.
*
* NOTE(review): worker, pending, count and size are each declared twice
* below, and `used` appears alongside `next`/`comm`. This looks like the
* removed and added lines of a diff shown together rather than a single
* struct definition -- confirm against the actual file.
*/
typedef struct ucx_request {
ucp_worker_h worker; /* Worker for all requests */
int pending; /* How many requests are still pending */
int count; /* How many requests are contained */
int used; /* Allocation status */

int size[NCCL_NET_IB_MAX_RECVS];
struct ucx_request *next; /* Next request in the free list */
struct ucx_comm *comm; /* Owning communicator */
ucp_worker_h worker; /* Worker for all requests */
/* &pending is passed as user_data to the receive callback in nccl_ucx_irecv */
int pending; /* How many requests are still pending */
int count; /* How many requests are contained */
/* nccl_ucx_test copies the first `count` entries out to the caller */
int size[NCCL_NET_IB_MAX_RECVS];
} ucx_request_t;

struct ep_list {
Expand Down Expand Up @@ -179,7 +181,7 @@ typedef struct ucx_comm {
struct ncclSocket sock; /* socket for OOB connection */
int ready; /* indicates that receive communicator is fully initialized */
ucx_request_t reqs[MAX_REQUESTS]; /* max inflight requests */

ucx_request_t *free_req; /* first request available */
connect_msg_t *msg; /* message to establish reverse connection */
void *connect_req; /* msg request */
} ucx_comm_t;
Expand Down Expand Up @@ -390,8 +392,11 @@ ncclResult_t nccl_ucx_listen(int dev, void *handle, void **listen_comm) {
/*
 * Prepare every entry of comm->reqs[] for allocation.
 *
 * NOTE(review): two loops are shown here with unbalanced braces. The first
 * (clearing `used`) looks like the removed side of a diff and the second
 * (building the free list) like the added side -- confirm against the
 * actual file.
 */
static void ucx_request_init(ucx_comm_t *comm) {
static const int entries = sizeof(comm->reqs) / sizeof(*comm->reqs);

for (int i = 0; i < entries; i++) {
comm->reqs[i].used = 0;
/* Start from an empty list and push each entry at the head */
comm->free_req = NULL;
/* Walking backwards leaves free_req at reqs[0], chained in ascending index order */
for (int i = entries - 1; i >= 0; i--) {
comm->reqs[i].comm = comm; /* back-pointer read by ucx_request_release */
comm->reqs[i].next = comm->free_req;
comm->free_req = &comm->reqs[i];
}
}

Expand Down Expand Up @@ -564,23 +569,22 @@ static ucx_request_t *ucx_request_get(ucx_comm_t *comm) {
static const size_t entries = sizeof(comm->reqs) / sizeof(*comm->reqs);
ucx_request_t *req;

for (int i = 0; i < entries; i++) {
req = &comm->reqs[i];
if (req->used == 0) {
req->worker = comm->worker;
req->pending = 0;
req->count = 0;
req->used = 1;
return req;
}
req = comm->free_req;
if (req == NULL) {
WARN("NET/UCX: unable to allocate NCCL request");
return NULL;
}

WARN("NET/UCX: unable to allocate NCCL request");
return NULL;
comm->free_req = req->next;
req->worker = comm->worker;
req->pending = 0;
req->count = 0;
return req;
}

/*
 * Return a request to its owning communicator by pushing it on the head of
 * that communicator's free list.
 *
 * NOTE(review): `req->used = 0` looks like the removed side of a diff shown
 * next to the added free-list lines -- confirm against the actual file.
 * NOTE(review): nothing here detects a request being released twice, which
 * would corrupt the list.
 */
static void ucx_request_release(ucx_request_t *req) {
req->used = 0;
req->next = req->comm->free_req; /* old head becomes our successor */
req->comm->free_req = req; /* this request is the next one handed out */
}

static void ucx_request_add(ucx_request_t *req, int size) {
Expand All @@ -596,7 +600,7 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) {
connect_msg_t *msg;
ucp_ep_params_t ep_params;
void *ucp_req;
int pending;
ucs_status_t status;

ucp_worker_progress(comm->worker);

Expand All @@ -605,27 +609,21 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) {
return ncclSuccess;
}

pending = 1;
msg = malloc(info_tag.length);

params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
params.cb.recv = recv_handler_nbx;
params.user_data = &pending;
params.op_attr_mask = 0;
ucp_req = ucp_tag_msg_recv_nbx(comm->worker, msg, info_tag.length,
msg_tag, &params);
if (UCS_PTR_IS_ERR(ucp_req)) {
WARN("Unable to receive connect msg (%s)",
ucs_status_string(UCS_PTR_STATUS(ucp_req)));
free(msg);
return ncclSystemError;
} else if (ucp_req == NULL) {
pending--;
}

while (pending > 0) {
ucp_worker_progress(comm->worker);
} else if (ucp_req != NULL) {
do {
ucp_worker_progress(comm->worker);
status = ucp_request_check_status(ucp_req);
} while (status == UCS_INPROGRESS);
assert(status == UCS_OK);
}

ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
Expand Down Expand Up @@ -770,16 +768,19 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data,
return ncclInternalError;
}

params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA;
params.cb.recv = recv_handler_nbx;
params.user_data = &req->pending;

for (int i = 0; i < n; i++) {
ucx_request_add(req, sizes[i]);

params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA;
params.cb.recv = recv_handler_nbx;
params.user_data = &req->pending;
if (mh[i]) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.memory_type = mh[i]->mem_type;
} else {
params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMORY_TYPE;
}

ucp_req = ucp_tag_recv_nbx(comm->worker, data[i], sizes[i],
Expand Down Expand Up @@ -852,9 +853,7 @@ static ncclResult_t nccl_ucx_test(void *request, int *done, int *size) {
*done = 1;
if (size != NULL) {
/* Posted receives have completed */
for (int i = 0; i < req->count; i++) {
size[i] = req->size[i];
}
memcpy(size, req->size, sizeof(*size) * req->count);
}

ucx_request_release(req);
Expand Down

0 comments on commit d6ff08f

Please sign in to comment.