From d6ff08f4a9510e114664c4c20ed3a41375705b86 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Wed, 11 Oct 2023 12:48:49 +0300 Subject: [PATCH] Address review: add freelist and various code fixes --- src/p2p_plugin.c | 2 +- src/ucx_plugin.c | 83 ++++++++++++++++++++++++------------------------ 2 files changed, 42 insertions(+), 43 deletions(-) diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c index 5869935c..eb0d5fa1 100644 --- a/src/p2p_plugin.c +++ b/src/p2p_plugin.c @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2016-2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ diff --git a/src/ucx_plugin.c b/src/ucx_plugin.c index 3a10f4cd..70d05496 100644 --- a/src/ucx_plugin.c +++ b/src/ucx_plugin.c @@ -120,16 +120,18 @@ typedef struct connect_msg { size_t addr_len; } connect_msg_t; +struct ucx_comm; + /** * Batch of UCX Requests from NCCL perspective */ typedef struct ucx_request { - ucp_worker_h worker; /* Worker for all requests */ - int pending; /* How many requests are still pending */ - int count; /* How many requests are contained */ - int used; /* Allocation status */ - - int size[NCCL_NET_IB_MAX_RECVS]; + struct ucx_request *next; /* Next request in the free list */ + struct ucx_comm *comm; /* Owning communicator */ + ucp_worker_h worker; /* Worker for all requests */ + int pending; /* How many requests are still pending */ + int count; /* How many requests are contained */ + int size[NCCL_NET_IB_MAX_RECVS]; } ucx_request_t; struct ep_list { @@ -179,7 +181,7 @@ typedef struct ucx_comm { struct ncclSocket sock; /* socket for OOB connection */ int ready; /* indicates that receive communicator is fully initialized */ ucx_request_t reqs[MAX_REQUESTS]; /* max inflight requests */ - + ucx_request_t 
*free_req; /* first request available */ connect_msg_t *msg; /* message to establish reverse connection */ void *connect_req; /* msg request */ } ucx_comm_t; @@ -390,8 +392,11 @@ ncclResult_t nccl_ucx_listen(int dev, void *handle, void **listen_comm) { static void ucx_request_init(ucx_comm_t *comm) { static const int entries = sizeof(comm->reqs) / sizeof(*comm->reqs); - for (int i = 0; i < entries; i++) { - comm->reqs[i].used = 0; + comm->free_req = NULL; + for (int i = entries - 1; i >= 0; i--) { + comm->reqs[i].comm = comm; + comm->reqs[i].next = comm->free_req; + comm->free_req = &comm->reqs[i]; } } @@ -564,23 +569,22 @@ static ucx_request_t *ucx_request_get(ucx_comm_t *comm) { static const size_t entries = sizeof(comm->reqs) / sizeof(*comm->reqs); ucx_request_t *req; - for (int i = 0; i < entries; i++) { - req = &comm->reqs[i]; - if (req->used == 0) { - req->worker = comm->worker; - req->pending = 0; - req->count = 0; - req->used = 1; - return req; - } + req = comm->free_req; + if (req == NULL) { + WARN("NET/UCX: unable to allocate NCCL request"); + return NULL; } - WARN("NET/UCX: unable to allocate NCCL request"); - return NULL; + comm->free_req = req->next; + req->worker = comm->worker; + req->pending = 0; + req->count = 0; + return req; } static void ucx_request_release(ucx_request_t *req) { - req->used = 0; + req->next = req->comm->free_req; + req->comm->free_req = req; } static void ucx_request_add(ucx_request_t *req, int size) { @@ -596,7 +600,7 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) { connect_msg_t *msg; ucp_ep_params_t ep_params; void *ucp_req; - int pending; + ucs_status_t status; ucp_worker_progress(comm->worker); @@ -605,14 +609,8 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) { return ncclSuccess; } - pending = 1; msg = malloc(info_tag.length); - - params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_USER_DATA | - UCP_OP_ATTR_FLAG_NO_IMM_CMPL; - params.cb.recv = recv_handler_nbx; - params.user_data = 
&pending; + params.op_attr_mask = 0; ucp_req = ucp_tag_msg_recv_nbx(comm->worker, msg, info_tag.length, msg_tag, &params); if (UCS_PTR_IS_ERR(ucp_req)) { @@ -620,12 +618,12 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) { ucs_status_string(UCS_PTR_STATUS(ucp_req))); free(msg); return ncclSystemError; - } else if (ucp_req == NULL) { - pending--; - } - - while (pending > 0) { - ucp_worker_progress(comm->worker); + } else if (ucp_req != NULL) { + do { + ucp_worker_progress(comm->worker); + status = ucp_request_check_status(ucp_req); + } while (status == UCS_INPROGRESS); + assert(status == UCS_OK); } ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; @@ -770,16 +768,19 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data, return ncclInternalError; } + params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA; + params.cb.recv = recv_handler_nbx; + params.user_data = &req->pending; + for (int i = 0; i < n; i++) { ucx_request_add(req, sizes[i]); - params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_USER_DATA; - params.cb.recv = recv_handler_nbx; - params.user_data = &req->pending; if (mh[i]) { params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE; params.memory_type = mh[i]->mem_type; + } else { + params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMORY_TYPE; } ucp_req = ucp_tag_recv_nbx(comm->worker, data[i], sizes[i], @@ -852,9 +853,7 @@ static ncclResult_t nccl_ucx_test(void *request, int *done, int *size) { *done = 1; if (size != NULL) { /* Posted receives have completed */ - for (int i = 0; i < req->count; i++) { - size[i] = req->size[i]; - } + memcpy(size, req->size, sizeof(*size) * req->count); } ucx_request_release(req);