diff --git a/.ci/run_nccl_tests.sh b/.ci/run_nccl_tests.sh
index b004e53e..c2724a4f 100755
--- a/.ci/run_nccl_tests.sh
+++ b/.ci/run_nccl_tests.sh
@@ -109,7 +109,7 @@ for TEST_EXE in ${NCCL_TEST_EXE[@]}; do
     #===================
     # Enable ucx_rma tests once this is resolved: https://redmine.mellanox.com/issues/3037941
     # for P2P_LAYER in ucx ucx_rma ib
-    for P2P_LAYER in ucx ib ucx_uct; do
+    for P2P_LAYER in ib ucx ucx_uct ucx_uct_read; do
         MPIRUN_OPTIONS_PLUGIN_P2P_LAYER="-x NCCL_PLUGIN_P2P=${P2P_LAYER}"
 
         #===================
diff --git a/include/p2p_plugin.h b/include/p2p_plugin.h
index 3c3e0ab9..10493b7d 100644
--- a/include/p2p_plugin.h
+++ b/include/p2p_plugin.h
@@ -30,6 +30,7 @@ typedef enum nccl_p2p_plugin {
   NCCL_P2P_UCX,
   NCCL_P2P_UCX_RMA,
   NCCL_P2P_UCX_UCT,
+  NCCL_P2P_UCX_UCT_RD,
   NCCL_P2P_LAST
 } nccl_p2p_plugin_t;
 
diff --git a/include/ucx_uct_lib.h b/include/ucx_uct_lib.h
index ba656efa..02566d6c 100644
--- a/include/ucx_uct_lib.h
+++ b/include/ucx_uct_lib.h
@@ -20,6 +20,13 @@
 #define NCCL_UCT_LISTEN_HANDLE_MAGIC 0x43cf19ed91abdb85
 #define NCCL_UCT_REG_ALIGN 4096
 
+typedef enum {
+  NCCL_UCT_AM_RTR = 14, /* Use particular values */
+  NCCL_UCT_AM_ATP = 15,
+  NCCL_UCT_AM_RTS = 16,
+  NCCL_UCT_AM_ATS = 17
+} nccl_uct_am_type_t;
+
 typedef enum {
   NCCL_UCT_START = 0,
   NCCL_UCT_CONNECT,
@@ -206,6 +213,7 @@ int nccl_uct_flush_index(nccl_uct_comm_t *base, int *sizes, int n);
 ncclResult_t nccl_uct_flush(nccl_uct_comm_t *base_comm, void *data, int size,
                             nccl_uct_memh_t *uct_memh,
                             uct_completion_t *completion, void **request);
+void nccl_uct_empty_callback(uct_completion_t *comp);
 
 /* NCCL common plugin callbacks */
 ncclResult_t nccl_uct_listen(int dev, void *listen_handle, void **listen_comm);
diff --git a/include/ucx_uct_ring.h b/include/ucx_uct_ring.h
new file mode 100644
index 00000000..1fbcdc70
--- /dev/null
+++ b/include/ucx_uct_ring.h
@@ -0,0 +1,118 @@
+/*************************************************************************
+ * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef NCCL_UCX_UCT_RING_H_
+#define NCCL_UCX_UCT_RING_H_
+
+#include "nccl.h"
+#include <assert.h>
+#include <limits.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define NCCL_UCT_RING_SIZE (1 << 7)
+#define NCCL_UCT_RING_MASK (NCCL_UCT_RING_SIZE - 1)
+
+/*
+ * Ring of NCCL_UCT_RING_SIZE tagged entries of 'entry_size' bytes each.
+ * A slot whose tag is INT_MAX is free. 'first' and 'last' are free-running
+ * indices (masked on access) bounding the range nccl_uct_ring_find() scans.
+ */
+typedef struct nccl_uct_ring {
+  unsigned first;
+  unsigned last;
+  unsigned size;
+  unsigned entry_size;
+  int tag[NCCL_UCT_RING_SIZE];
+  void *entry;
+} nccl_uct_ring_t;
+
+/* Allocate the entry storage and mark every slot as free */
+static inline ncclResult_t nccl_uct_ring_init(nccl_uct_ring_t *ring,
+                                              unsigned entry_size) {
+  int i;
+
+  ring->first = 0;
+  ring->last = 0;
+  ring->entry_size = entry_size;
+  ring->entry = calloc(NCCL_UCT_RING_SIZE, entry_size);
+  if (ring->entry == NULL) {
+    return ncclSystemError;
+  }
+
+  for (i = 0; i < NCCL_UCT_RING_SIZE; i++) {
+    ring->tag[i] = INT_MAX;
+  }
+  return ncclSuccess;
+}
+
+/* Free the entry storage; the pointer is reset so a second call is harmless */
+static inline void nccl_uct_ring_deinit(nccl_uct_ring_t *ring) {
+  free(ring->entry);
+  ring->entry = NULL;
+}
+
+/* Return the address of the slot that index 'i' maps to */
+static inline void *nccl_uct_ring_get_entry(nccl_uct_ring_t *ring, unsigned i) {
+  return (uint8_t*)ring->entry + (ring->entry_size * (i & NCCL_UCT_RING_MASK));
+}
+
+/* Copy 'data' into the slot at 'last', tag it with 'tag' and advance 'last' */
+static inline void nccl_uct_ring_append(nccl_uct_ring_t *ring, int tag,
+                                        void *data, size_t len) {
+  int j = ring->last & NCCL_UCT_RING_MASK;
+
+  ring->last++;
+
+  assert((ring->last & NCCL_UCT_RING_MASK) !=
+         (ring->first & NCCL_UCT_RING_MASK));
+  assert(ring->tag[j] == INT_MAX);
+  assert(len == ring->entry_size);
+
+  ring->tag[j] = tag;
+  memcpy(nccl_uct_ring_get_entry(ring, j), data, len);
+}
+
+static inline int nccl_uct_ring_is_empty(const nccl_uct_ring_t *ring) {
+  return ring->first == ring->last;
+}
+
+/* Mark the slot at index 'i' as free and skip 'first' over leading free slots */
+static inline void nccl_uct_ring_consume(nccl_uct_ring_t *ring, unsigned i) {
+  unsigned j = i & NCCL_UCT_RING_MASK;
+
+  assert(ring->tag[j] != INT_MAX);
+  ring->tag[j] = INT_MAX;
+
+  /* Cleanup upon tag hit */
+  if (i == ring->first) {
+    for (; i != ring->last; i++) {
+      j = i & NCCL_UCT_RING_MASK;
+      if (ring->tag[j] != INT_MAX) {
+        break;
+      }
+      ring->first = i + 1;
+    }
+  }
+}
+
+/* Return the index of the first entry tagged 'tag', or 'last' when not found */
+static inline unsigned nccl_uct_ring_find(nccl_uct_ring_t *ring, int tag) {
+  unsigned i;
+
+  assert(tag != INT_MAX);
+
+  for (i = ring->first; i != ring->last; i++) {
+    if (ring->tag[i & NCCL_UCT_RING_MASK] == tag) {
+      return i;
+    }
+  }
+
+  return ring->last;
+}
+
+#endif /* NCCL_UCX_UCT_RING_H_ */
diff --git a/src/Makefile.am b/src/Makefile.am
index c58d398b..131f1806 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -26,7 +26,8 @@ libnccl_net_la_SOURCES += \
 	ucx_plugin.c \
 	ucx_rma_plugin.c \
 	ucx_uct_lib.c \
-	ucx_uct_plugin.c
+	ucx_uct_plugin.c \
+	ucx_uct_rd_plugin.c
 endif
 
 if HAVE_SHARP_PLUGIN
diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c
index 2715b2dd..78074dbb 100644
--- a/src/p2p_plugin.c
+++ b/src/p2p_plugin.c
@@ -27,6 +27,10 @@ extern ncclNet_v8_t ucxUctPlugin_v8;
 extern ncclNet_v7_t ucxUctPlugin_v7;
 extern ncclNet_v6_t ucxUctPlugin_v6;
 extern ncclNet_v5_t ucxUctPlugin_v5;
+extern ncclNet_v8_t ucxUctRdPlugin_v8;
+extern ncclNet_v7_t ucxUctRdPlugin_v7;
+extern ncclNet_v6_t ucxUctRdPlugin_v6;
+extern ncclNet_v5_t ucxUctRdPlugin_v5;
 #endif
 
 extern ncclNet_v8_t ibPlugin_v8;
@@ -77,6 +81,10 @@ ncclNet_v5_t ncclNetPlugin_v5 = {
 
 static nccl_p2p_plugin_t p2p_plugin = NCCL_P2P_LAST;
 
+static int nccl_p2p_is_uct_plugin(nccl_p2p_plugin_t plugin) {
+  return (plugin == NCCL_P2P_UCX_UCT) || (plugin == NCCL_P2P_UCX_UCT_RD);
+}
+
 static void pluginSetup()
 {
   p2p_plugin = NCCL_P2P_IB;
@@ -92,6 +100,7 @@ static void pluginSetup()
     else if (!strcasecmp(p2p_layer, "ucx")) p2p_plugin = NCCL_P2P_UCX;
     else if (!strcasecmp(p2p_layer, "ucx_rma")) p2p_plugin = NCCL_P2P_UCX_RMA;
     else if (!strcasecmp(p2p_layer, "ucx_uct")) p2p_plugin = NCCL_P2P_UCX_UCT;
+    else if (!strcasecmp(p2p_layer, "ucx_uct_read")) p2p_plugin = NCCL_P2P_UCX_UCT_RD;
 #endif
     else {
       WARN("Invalid value %s for NCCL_PLUGIN_P2P, using default", p2p_layer);
@@ -117,6 +126,12 @@ static void pluginSetup()
       ncclNetPlugin_v6 = ucxUctPlugin_v6;
       ncclNetPlugin_v5 = ucxUctPlugin_v5;
       break;
+    case NCCL_P2P_UCX_UCT_RD:
+      ncclNetPlugin_v8 = ucxUctRdPlugin_v8;
+      ncclNetPlugin_v7 = ucxUctRdPlugin_v7;
+      ncclNetPlugin_v6 = ucxUctRdPlugin_v6;
+      ncclNetPlugin_v5 = ucxUctRdPlugin_v5;
+      break;
 #endif
     default:
       ncclNetPlugin_v8 = ibPlugin_v8;
@@ -221,7 +236,8 @@ ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetPropert
     INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (nvidia-peermem) enabled for HCA %d '%s", dev, devs[dev].devName);
   }
   props->regIsGlobal = 1;
-  if (((p2p_plugin == NCCL_P2P_UCX_UCT) || (p2p_plugin == NCCL_P2P_IB)) && nccl_p2p_dmabuf_support(dev) == ncclSuccess) {
+  if ((nccl_p2p_is_uct_plugin(p2p_plugin) || (p2p_plugin == NCCL_P2P_IB)) &&
+      nccl_p2p_dmabuf_support(dev) == ncclSuccess) {
     props->ptrSupport |= NCCL_PTR_DMABUF; // GDR support via DMA-BUF
     INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (DMABUF) enabled for HCA %d '%s", dev, devs[dev].devName);
   }
@@ -231,7 +247,7 @@ ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetPropert
   props->maxComms = ibDev->maxQp;
 
   if (p2p_plugin == NCCL_P2P_IB || p2p_plugin == NCCL_P2P_UCX ||
-      p2p_plugin == NCCL_P2P_UCX_UCT) {
+      nccl_p2p_is_uct_plugin(p2p_plugin)) {
     props->maxRecvs = NCCL_NET_IB_MAX_RECVS;
   } else {
     props->maxRecvs = 1;
diff --git a/src/ucx_uct_lib.c b/src/ucx_uct_lib.c
index bf05dbbd..ee66f9a7 100644
--- a/src/ucx_uct_lib.c
+++ b/src/ucx_uct_lib.c
@@ -14,6 +14,10 @@ nccl_uct_context_t context = {
     .dev_count = -1
 };
 
+void nccl_uct_empty_callback(uct_completion_t *comp) {
+  assert(comp->count == 0);
+}
+
 ncclResult_t nccl_uct_iface_set_handler(nccl_uct_iface_t *uct_iface, int id,
                                         uct_am_callback_t callback) {
   UCXCHECK(uct_iface_set_am_handler(uct_iface->iface, id, callback, NULL, 0),
@@ -42,6 +46,14 @@ static uct_iface_h nccl_uct_resource_iface_open(uct_worker_h worker,
   UCXCHECK(uct_md_iface_config_read(md, tl->tl_name, NULL, NULL, &config),
            return NULL, "read MD iface config for TL '%s'", tl->tl_name);
 
+  status = uct_config_modify(config, "IB_TX_INLINE_RESP", "0");
+  if (status != UCS_OK) {
+    WARN("Failed to modify MD configuration for '%s', error %s",
+         tl->tl_name, ucs_status_string(status));
+    uct_config_release(config);
+    return NULL;
+  }
+
   params.field_mask =
       UCT_IFACE_PARAM_FIELD_OPEN_MODE | UCT_IFACE_PARAM_FIELD_DEVICE |
       UCT_IFACE_PARAM_FIELD_STATS_ROOT | UCT_IFACE_PARAM_FIELD_RX_HEADROOM;
@@ -739,7 +751,7 @@ ncclResult_t nccl_uct_flush(nccl_uct_comm_t *base_comm, void *data, int size,
   uct_iov_t iov;
 
   iov.buffer = base_comm->gpu_flush.mem;
-  iov.length = base_comm->uct_iface->min_get_zcopy;
+  iov.length = base_comm->uct_iface->min_get_zcopy? : 1;
   iov.memh = base_comm->gpu_flush.memh;
   iov.stride = 0;
   iov.count = 1;
diff --git a/src/ucx_uct_plugin.c b/src/ucx_uct_plugin.c
index 2188dfaa..f03361c3 100644
--- a/src/ucx_uct_plugin.c
+++ b/src/ucx_uct_plugin.c
@@ -6,11 +6,6 @@
 
 #include "ucx_uct_lib.h"
 
-typedef enum {
-  NCCL_UCT_AM_RTR = 14, /* Use particular values */
-  NCCL_UCT_AM_ATP = 15
-} nccl_uct_am_type_t;
-
 typedef enum {
   NCCL_UCT_REQ_IRECV = -1,
   NCCL_UCT_REQ_IFLUSH = -2
@@ -137,10 +132,6 @@ static void nccl_uct_rdesc_set(nccl_uct_rdesc_t *rdesc, uint64_t id, int n,
   }
 }
 
-static void nccl_uct_empty_callback(uct_completion_t *comp) {
-  assert(comp->count == 0);
-}
-
 static nccl_uct_req_t *nccl_uct_rdesc_get_req(nccl_uct_rdesc_t *rdesc, int i,
                                               int size) {
   nccl_uct_req_t *req;
diff --git a/src/ucx_uct_rd_plugin.c b/src/ucx_uct_rd_plugin.c
new file mode 100644
index 00000000..c58e0e70
--- /dev/null
+++ b/src/ucx_uct_rd_plugin.c
@@ -0,0 +1,500 @@
+/*************************************************************************
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include "ucx_uct_lib.h"
+#include "ucx_uct_ring.h"
+
+#define NCCL_UCT_PENDING_SIZE 128
+#define NCCL_UCT_PENDING_MASK (NCCL_UCT_PENDING_SIZE - 1)
+
+/* Memory chunk to send or receive */
+typedef struct {
+  int tag;
+  int size;
+  void *data;
+  union {
+    uct_rkey_t rkey;
+    nccl_uct_memh_t *uct_memh;
+  } u;
+  struct nccl_uct_rd_req *req;
+  unsigned index; /* irecv(): position in the receive request */
+} nccl_uct_mem_t;
+
+/* Context for GET requests to be posted */
+typedef struct {
+  uct_iov_t iov;
+  uint64_t rva;
+  uct_rkey_t rkey;
+  struct nccl_uct_rd_req *req;
+} nccl_uct_get_param_t;
+
+/* Communicator for client or server side */
+typedef struct nccl_uct_rd_comm {
+  /* Base communicator with endpoints setup */
+  nccl_uct_comm_t base;
+
+  /* NCCL request free list */
+  int req_count;
+  struct nccl_uct_rd_req *free_req;
+
+  /* TAG matching rings */
+  nccl_uct_ring_t exp;
+  nccl_uct_ring_t unexp;
+
+  /* GET zcopy for matched chunks, but yet to be posted */
+  struct {
+    unsigned first;
+    unsigned last;
+    nccl_uct_get_param_t param[NCCL_UCT_PENDING_SIZE];
+  } pending;
+} nccl_uct_rd_comm_t;
+
+/* Either irecv, isend or iflush NCCL request */
+typedef struct nccl_uct_rd_req {
+  uct_completion_t completion; /* Release when count equals zero */
+  int send_rts; /* Request type */
+  nccl_uct_rd_comm_t *comm; /* Parent communicator */
+  struct nccl_uct_rd_req *next; /* Free list node */
+
+  int count; /* isend(): 1, irecv(): from 1 to n */
+  int rts_count; /* RTS actually received and matched */
+
+  /* Sizes actually read to report, received from RTS */
+  int sizes[NCCL_UCX_UCT_MAX_RECVS];
+
+  /* Remote completed requests cookies, to send with ATS */
+  struct nccl_uct_rd_req *remote_req[NCCL_UCX_UCT_MAX_RECVS];
+} nccl_uct_rd_req_t;
+
+/* Get the read-based communicator that embeds the given base communicator */
+static inline nccl_uct_rd_comm_t *
+nccl_uct_rd_comm_get(nccl_uct_comm_t *base_comm) {
+  return ucs_container_of(base_comm, nccl_uct_rd_comm_t, base);
+}
+
+/*
+ * Send the ATS message carrying the cookies of the matched remote requests.
+ * The remaining completion count is dropped only when the send succeeds, so
+ * a failed attempt can be repeated by calling this function again.
+ */
+static void nccl_uct_rd_send_ats(nccl_uct_rd_req_t *req) {
+  ucs_status_t status;
+
+  assert(req->send_rts == 0);
+  assert(req->rts_count == req->count);
+  assert(req->completion.count == 1);
+
+  status = uct_ep_am_short(req->comm->base.uct_ep->ep, NCCL_UCT_AM_ATS,
+                           (uint64_t)req->comm->base.remote.comm,
+                           req->remote_req,
+                           sizeof(*req->remote_req) * req->rts_count);
+  if (status == UCS_OK) {
+    req->completion.count--;
+  }
+}
+
+/*
+ * Account a matched RTS ('src') in the receive request owning 'dst' and
+ * queue the parameters of the GET zcopy that fetches the data. A zero-sized
+ * chunk is completed right away and nothing is queued for it.
+ */
+static void nccl_uct_rd_pending_add(nccl_uct_rd_comm_t *comm,
+                                    nccl_uct_mem_t *src, nccl_uct_mem_t *dst) {
+  nccl_uct_rd_req_t *req = dst->req;
+  nccl_uct_get_param_t *param;
+
+  assert(src->size <= dst->size);
+  assert(req->rts_count < NCCL_UCX_UCT_MAX_RECVS);
+
+  req->sizes[dst->index] = src->size;
+  req->remote_req[req->rts_count++] = src->req; /* src->req is a cookie */
+
+  if (src->size == 0) {
+    req->completion.count--;
+    return;
+  }
+
+  param = &comm->pending.param[comm->pending.last & NCCL_UCT_PENDING_MASK];
+  comm->pending.last++;
+
+  assert((comm->pending.first & NCCL_UCT_PENDING_MASK) !=
+         (comm->pending.last & NCCL_UCT_PENDING_MASK));
+
+  param->iov.buffer = dst->data;
+  param->iov.length = src->size;
+  param->iov.memh = dst->u.uct_memh->memh;
+  param->iov.stride = 0;
+  param->iov.count = 1;
+  param->rva = (uint64_t)src->data;
+  param->rkey = src->u.rkey;
+  param->req = req;
+}
+
+/*
+ * Post the queued GET operations in order, stopping at the first one that
+ * is neither completed nor in progress so that it is retried later.
+ */
+static void nccl_uct_rd_pending_drain(nccl_uct_rd_comm_t *comm) {
+  ucs_status_t status;
+  nccl_uct_get_param_t *param;
+
+  for (; comm->pending.first != comm->pending.last; comm->pending.first++) {
+    param = &comm->pending.param[comm->pending.first & NCCL_UCT_PENDING_MASK];
+
+    status = uct_ep_get_zcopy(comm->base.uct_ep->ep, &param->iov, 1, param->rva,
+                              param->rkey, &param->req->completion);
+    if (status == UCS_OK) {
+      param->req->completion.count--;
+    } else if (status != UCS_INPROGRESS) {
+      break;
+    }
+
+    if (param->req->completion.count == 1) {
+      nccl_uct_rd_send_ats(param->req);
+    }
+  }
+}
+
+/*
+ * ATS handler: the payload after the 8-byte header is an array of request
+ * cookies; the completion count of each of them is reset to zero.
+ */
+static ucs_status_t nccl_uct_rd_ats_callback(void *arg, void *data,
+                                             size_t length, unsigned flags) {
+  nccl_uct_rd_req_t **req = (nccl_uct_rd_req_t **)((uint8_t *)data + 8);
+  nccl_uct_rd_req_t **end = (nccl_uct_rd_req_t **)((uint8_t *)data + length);
+
+  for (; req + 1 <= end; req++) {
+    assert((*req)->completion.count == 1);
+    assert((*req)->comm == nccl_uct_rd_comm_get(*(nccl_uct_comm_t**)data));
+
+    (*req)->completion.count = 0;
+  }
+
+  assert(req == end);
+  return UCS_OK;
+}
+
+/*
+ * RTS handler: match the chunk with an entry of the expected ring, or store
+ * a copy in the unexpected ring for a receive that is posted later.
+ */
+static ucs_status_t nccl_uct_rd_rts_callback(void *arg, void *data,
+                                             size_t length, unsigned flags) {
+
+  nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(*(nccl_uct_comm_t**)data);
+  nccl_uct_mem_t *rts = (nccl_uct_mem_t *)((uint8_t *)data + 8);
+  nccl_uct_ring_t *exp;
+  nccl_uct_mem_t *dst;
+  unsigned i;
+
+  assert(length == (sizeof(*rts) + 8));
+
+  /* Do we already expect it? */
+  exp = &comm->exp;
+  i = nccl_uct_ring_find(exp, rts->tag);
+  if (i == exp->last) {
+    nccl_uct_ring_append(&comm->unexp, rts->tag, rts, sizeof(*rts));
+  } else {
+    /* Receive request was already posted */
+    dst = nccl_uct_ring_get_entry(exp, i);
+    nccl_uct_rd_pending_add(comm, rts, dst);
+    nccl_uct_ring_consume(exp, i);
+  }
+
+  return UCS_OK;
+}
+
+/* Register the RTS and ATS active message handlers on the interface */
+static ncclResult_t nccl_uct_rd_iface_set(nccl_uct_iface_t *uct_iface) {
+  NCCLCHECK(nccl_uct_iface_set_handler(uct_iface, NCCL_UCT_AM_RTS,
+                                       nccl_uct_rd_rts_callback));
+  NCCLCHECK(nccl_uct_iface_set_handler(uct_iface, NCCL_UCT_AM_ATS,
+                                       nccl_uct_rd_ats_callback));
+  return ncclSuccess;
+}
+
+/* Allocate a zeroed communicator and return its embedded base communicator */
+static ncclResult_t nccl_uct_rd_comm_alloc(nccl_uct_comm_t **comm_p) {
+  nccl_uct_rd_comm_t *comm = calloc(1, sizeof(*comm));
+  if (comm != NULL) {
+    *comm_p = &comm->base;
+    return ncclSuccess;
+  }
+
+  return ncclSystemError;
+}
+
+/*
+ * Initialize the read-specific state, then the base communicator. Rings
+ * that were already set up are released if a later step fails.
+ */
+static ncclResult_t nccl_uct_rd_comm_init(nccl_uct_comm_t *base_comm,
+                                          nccl_uct_context_t *context,
+                                          nccl_uct_worker_t *worker, int dev,
+                                          const nccl_uct_comm_t *remote_comm) {
+  nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(base_comm);
+  ncclResult_t result;
+
+  comm->pending.first = 0;
+  comm->pending.last = 0;
+  comm->req_count = 0;
+  comm->free_req = NULL;
+
+  NCCLCHECK(nccl_uct_ring_init(&comm->exp, sizeof(nccl_uct_mem_t)));
+
+  result = nccl_uct_ring_init(&comm->unexp, sizeof(nccl_uct_mem_t));
+  if (result != ncclSuccess) {
+    goto err_free_exp;
+  }
+
+  result = nccl_uct_comm_init(&comm->base, context, worker, dev, remote_comm);
+  if (result != ncclSuccess) {
+    goto err_free_unexp;
+  }
+
+  return ncclSuccess;
+
+err_free_unexp:
+  nccl_uct_ring_deinit(&comm->unexp);
+err_free_exp:
+  nccl_uct_ring_deinit(&comm->exp);
+  return result;
+}
+
+/* Plugin init: install the read-based ops and set the rkey and AM short sizes */
+static ncclResult_t nccl_uct_rd_init(ncclDebugLogger_t logFunction) {
+  NCCL_STATIC_ASSERT(NCCL_UCT_RING_SIZE >= 2 * MAX_REQUESTS,
+                     "Cannot handle expected/unexpected requests");
+  NCCL_STATIC_ASSERT(NCCL_UCT_PENDING_SIZE > MAX_REQUESTS,
+                     "Cannot handle enough pending requests");
+
+  context.ops.comm_alloc = nccl_uct_rd_comm_alloc;
+  context.ops.comm_init = nccl_uct_rd_comm_init;
+  context.ops.iface_set = nccl_uct_rd_iface_set;
+  context.rkey_size = sizeof(((nccl_uct_mem_t*)0)->u.rkey);
+  context.am_short_size = sizeof(((nccl_uct_rd_req_t*)0)->remote_req);
+  if (sizeof(nccl_uct_mem_t) > context.am_short_size) {
+    context.am_short_size = sizeof(nccl_uct_mem_t);
+  }
+
+  return nccl_p2p_ib_init(&context.dev_count, ncclIbDevs, context.if_name,
+                          &context.if_addr, NULL, logFunction);
+}
+
+/*
+ * Take a request from the free list, or allocate a new one, and arm its
+ * completion with 'count'. Returns NULL if the allocation fails.
+ */
+static nccl_uct_rd_req_t *nccl_uct_rd_req_alloc(nccl_uct_rd_comm_t *comm,
+                                                int count) {
+  nccl_uct_rd_req_t *req = comm->free_req;
+
+  if (req == NULL) {
+    req = malloc(sizeof(*req));
+    if (req == NULL) {
+      return req;
+    }
+  } else {
+    comm->free_req = req->next;
+  }
+
+  comm->req_count++;
+  req->comm = comm;
+  req->completion.func = nccl_uct_empty_callback;
+  req->completion.count = count;
+  req->completion.status = UCS_OK;
+  return req;
+}
+
+/* Return a request to the free list of its communicator */
+static inline void nccl_uct_rd_req_free(nccl_uct_rd_req_t *req) {
+  req->next = req->comm->free_req;
+  req->comm->free_req = req;
+  req->comm->req_count--;
+}
+
+/*
+ * Post a send by advertising the buffer to the peer with an RTS message.
+ * '*request' is set to NULL when no request could be allocated or when the
+ * RTS could not be sent.
+ */
+static ncclResult_t nccl_uct_rd_isend(void *send_comm, void *data, int size,
+                                      int tag, void *mhandle, void **request) {
+
+  nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(send_comm);
+  nccl_uct_memh_t *uct_memh = mhandle;
+  nccl_uct_mem_t rts;
+  nccl_uct_rd_req_t *req;
+  ucs_status_t status;
+
+  req = nccl_uct_rd_req_alloc(comm, 1);
+  if (req == NULL) {
+    *request = NULL;
+    return ncclSuccess;
+  }
+
+  req->send_rts = 1;
+  req->count = 1;
+  req->sizes[0] = size;
+  *request = req;
+
+  /* The whole structure goes on the wire: do not send indeterminate bytes */
+  memset(&rts, 0, sizeof(rts));
+  rts.tag = tag;
+  rts.size = size;
+  rts.data = data;
+  rts.u.rkey = uct_memh->bundle.rkey;
+  rts.req = req;
+
+  status = uct_ep_am_short(comm->base.uct_ep->ep, NCCL_UCT_AM_RTS,
+                           (uint64_t)comm->base.remote.comm, &rts, sizeof(rts));
+  if (status != UCS_OK) {
+    nccl_uct_rd_req_free(req);
+    *request = NULL;
+  }
+
+  return ncclSuccess;
+}
+
+/*
+ * Post a receive of 'n' buffers under one request. Each buffer is matched
+ * with an RTS from the unexpected ring, or appended to the expected ring.
+ */
+static ncclResult_t nccl_uct_rd_irecv(void *recv_comm, int n, void **data,
+                                      int *sizes, int *tags, void **mhandles,
+                                      void **request) {
+  nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(recv_comm);
+  nccl_uct_memh_t **uct_memh = (nccl_uct_memh_t**)mhandles;
+  nccl_uct_ring_t *unexp;
+  nccl_uct_rd_req_t *req;
+  nccl_uct_mem_t *rts, recv;
+  unsigned i, j;
+
+  assert(n <= NCCL_UCX_UCT_MAX_RECVS);
+
+  /* Create a request */
+  req = nccl_uct_rd_req_alloc(comm, n + 1);
+  *request = req;
+  if (req == NULL) {
+    return ncclSuccess;
+  }
+
+  req->send_rts = 0;
+  req->count = n;
+  req->rts_count = 0;
+
+  /* Try to match or build expected list */
+  for (i = 0; i < n; i++) {
+    recv.tag = tags[i];
+    recv.size = sizes[i];
+    recv.data = data[i];
+    recv.u.uct_memh = uct_memh[i];
+    recv.req = req;
+    recv.index = i;
+
+    unexp = &comm->unexp;
+    j = nccl_uct_ring_find(unexp, tags[i]);
+    if (j == unexp->last) {
+      nccl_uct_ring_append(&comm->exp, tags[i], &recv, sizeof(recv));
+    } else {
+      rts = nccl_uct_ring_get_entry(unexp, j);
+      nccl_uct_rd_pending_add(comm, rts, &recv);
+      nccl_uct_ring_consume(unexp, j);
+    }
+  }
+
+  return ncclSuccess;
+}
+
+/* Post a flush for the buffer selected by nccl_uct_flush_index(), if any */
+static ncclResult_t nccl_uct_rd_iflush(void *recv_comm, int n, void **data,
+                                       int *sizes, void **mhandle,
+                                       void **request) {
+  ncclResult_t result = ncclSuccess;
+  nccl_uct_comm_t *base_comm = recv_comm;
+  nccl_uct_memh_t **uct_memh = (nccl_uct_memh_t**)mhandle;
+  int last = nccl_uct_flush_index(base_comm, sizes, n);
+  nccl_uct_rd_req_t *req;
+
+  *request = NULL;
+
+  if (last != -1) {
+    req = nccl_uct_rd_req_alloc(nccl_uct_rd_comm_get(recv_comm), 1);
+    if (req != NULL) {
+      req->send_rts = -1;
+      *request = req;
+
+      result = nccl_uct_flush(base_comm, data[last], sizes[last],
+                              uct_memh[last], &req->completion, request);
+      if (*request == NULL) {
+        nccl_uct_rd_req_free(req);
+      }
+    }
+  }
+
+  return result;
+}
+
+/*
+ * Progress the worker and the pending GETs, then report through 'done'
+ * whether 'request' completed. A completed request goes to the free list.
+ */
+static ncclResult_t nccl_uct_rd_test(void *request, int *done, int *sizes) {
+  nccl_uct_rd_req_t *req = request;
+
+  while (uct_worker_progress(req->comm->base.uct_worker->worker))
+    ; /* empty */
+
+  nccl_uct_rd_pending_drain(req->comm);
+
+  if (req->completion.count > 0) {
+    if ((req->send_rts == 0) && (req->completion.count == 1)) {
+      nccl_uct_rd_send_ats(req);
+    }
+
+    if (req->completion.count > 0) {
+      *done = 0;
+      return ncclSuccess;
+    }
+  }
+
+  if ((sizes != NULL) && (req->send_rts > -1)) {
+    memcpy(sizes, req->sizes, req->count * sizeof(*req->sizes));
+  }
+
+  *done = 1;
+  nccl_uct_rd_req_free(req);
+  return ncclSuccess;
+}
+
+/* Deinit the base communicator, then free cached requests, rings and comm */
+static ncclResult_t nccl_uct_rd_close(void *close_comm) {
+  nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(close_comm);
+  nccl_uct_rd_req_t *req;
+
+  nccl_uct_comm_deinit(close_comm);
+
+  while ((req = comm->free_req) != NULL) {
+    comm->free_req = req->next;
+    free(req);
+  }
+
+  assert(nccl_uct_ring_is_empty(&comm->exp));
+  assert(nccl_uct_ring_is_empty(&comm->unexp));
+  assert(comm->req_count == 0);
+  assert(comm->pending.first == comm->pending.last);
+
+  nccl_uct_ring_deinit(&comm->exp);
+  nccl_uct_ring_deinit(&comm->unexp);
+  free(comm);
+  return ncclSuccess;
+}
+
+ncclNet_v8_t ucxUctRdPlugin_v8 = NCCL_UCT_PLUGIN_V8("UCX-UCT-RD", nccl_uct_rd);
+ncclNet_v7_t ucxUctRdPlugin_v7 = NCCL_UCT_PLUGIN_V7("UCX-UCT-RD", nccl_uct_rd);
+ncclNet_v6_t ucxUctRdPlugin_v6 = NCCL_UCT_PLUGIN_V6("UCX-UCT-RD", nccl_uct_rd);
+ncclNet_v5_t ucxUctRdPlugin_v5 = NCCL_UCT_PLUGIN_V5("UCX-UCT-RD", nccl_uct_rd);