From 610b90f63997c1ada9f42ee886b570b6d2297bfb Mon Sep 17 00:00:00 2001 From: Devendar Bureddy Date: Fri, 31 May 2024 19:12:23 +0300 Subject: [PATCH] Update internal plugin api call --- src/sharp_plugin.c | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/sharp_plugin.c b/src/sharp_plugin.c index 338c803d..9d551866 100644 --- a/src/sharp_plugin.c +++ b/src/sharp_plugin.c @@ -129,8 +129,8 @@ int ncclSharpAllGather(void *context, void *buf, int len) { p2p_plugin = nccl_p2p_get_plugin_type(); if (p2p_plugin != NCCL_P2P_UCX) { - NCCLCHECK(ncclNetPlugin_v7.regMr(cComm->recvComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &rMhandle)); - NCCLCHECK(ncclNetPlugin_v7.regMr(cComm->sendComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &sMhandle)); + NCCLCHECK(ncclNetPlugin_v8.regMr(cComm->recvComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &rMhandle)); + NCCLCHECK(ncclNetPlugin_v8.regMr(cComm->sendComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &sMhandle)); } int speer = cComm->rank; @@ -140,21 +140,21 @@ int ncclSharpAllGather(void *context, void *buf, int len) { while (srequest == NULL || rrequest == NULL) { void *rbuf = ((char*)buf)+rpeer*len; int tag = 0x69; - if (srequest == NULL) NCCLCHECK(ncclNetPlugin_v7.isend(cComm->sendComm, ((char*)buf)+speer*len, len, tag, sMhandle, &srequest)); - if (rrequest == NULL) NCCLCHECK(ncclNetPlugin_v7.irecv(cComm->recvComm, 1, &rbuf, &len, &tag, &rMhandle, &rrequest)); + if (srequest == NULL) NCCLCHECK(ncclNetPlugin_v8.isend(cComm->sendComm, ((char*)buf)+speer*len, len, tag, sMhandle, &srequest)); + if (rrequest == NULL) NCCLCHECK(ncclNetPlugin_v8.irecv(cComm->recvComm, 1, &rbuf, &len, &tag, &rMhandle, &rrequest)); } while (srequest || rrequest) { int done = 0; /* silent uninitialized false positive */ - if (rrequest) NCCLCHECK(ncclNetPlugin_v7.test(rrequest, &done, NULL)); + if (rrequest) NCCLCHECK(ncclNetPlugin_v8.test(rrequest, &done, NULL)); if (done) rrequest = NULL; - if (srequest) NCCLCHECK(ncclNetPlugin_v7.test(srequest, &done, NULL)); + if (srequest) NCCLCHECK(ncclNetPlugin_v8.test(srequest, &done, NULL)); if (done) srequest = NULL; } speer = rpeer; } if (p2p_plugin != NCCL_P2P_UCX) { - NCCLCHECK(ncclNetPlugin_v7.deregMr(cComm->recvComm, rMhandle)); - NCCLCHECK(ncclNetPlugin_v7.deregMr(cComm->sendComm, sMhandle)); + NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->recvComm, rMhandle)); + NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->sendComm, sMhandle)); } return 0; @@ -247,7 +247,7 @@ ncclResult_t ncclSharpListen(int dev, void* opaqueHandle, void** listenComm) { ncclResult_t status; NCCLCHECK(ncclIbMalloc((void**)&lComm, sizeof(struct ncclSharpListenComm))); - status = ncclNetPlugin_v7.listen(dev, opaqueHandle, &lComm->listenCommP2P); + status = ncclNetPlugin_v8.listen(dev, opaqueHandle, &lComm->listenCommP2P); lComm->dev = dev; *listenComm = lComm; return status; @@ -437,7 +437,7 @@ ncclResult_t ncclSharpDeregMr(void* collComm, void* mhandle) { WARN("SHARP deregmr failed\n"); } - NCCLCHECK(ncclNetPlugin_v7.deregMr(cComm->recvComm, mh->ncclIbMr)); + NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->recvComm, mh->ncclIbMr)); free(mh); return ncclSuccess; @@ -647,7 +647,7 @@ ncclResult_t ncclSharpIflush(void* collComm, void* data, int size, void* mhandle NCCLCHECK(ncclSharpGetRequest(cComm->reqs, &req)); req->requestType = NCCL_SHARP_REQ_IFLUSH; - ncclNetPlugin_v7.iflush(cComm->recvComm, 1, &data, &size, &mh->ncclIbMr, &req->sharpRequest); + ncclNetPlugin_v8.iflush(cComm->recvComm, 1, &data, &size, &mh->ncclIbMr, &req->sharpRequest); if (!req->sharpRequest) { *request = NULL; req->used = 0; @@ -662,7 +662,7 @@ ncclResult_t ncclSharpTest(void* request, int* done, int* size) { struct ncclSharpRequest* req = (struct ncclSharpRequest*)request; if (req->requestType == NCCL_SHARP_REQ_IFLUSH) { - ncclNetPlugin_v7.test(req->sharpRequest, done, size); + ncclNetPlugin_v8.test(req->sharpRequest, done, size); if (*done == 1) { req->used = 0; } @@ -696,8 +696,8 @@ ncclResult_t ncclSharpCloseColl(void* collComm) { sharp_coll_comm_destroy(cComm->sharpCollComm); sharp_coll_finalize(cComm->sharpCollContext); - NCCLCHECK(ncclNetPlugin_v7.closeRecv(cComm->recvComm)); - NCCLCHECK(ncclNetPlugin_v7.closeSend(cComm->sendComm)); + NCCLCHECK(ncclNetPlugin_v8.closeRecv(cComm->recvComm)); + NCCLCHECK(ncclNetPlugin_v8.closeSend(cComm->sendComm)); free(cComm); return ncclSuccess; } @@ -706,7 +706,7 @@ ncclResult_t ncclSharpCloseListen(void* listenComm) { struct ncclSharpListenComm *lComm = (struct ncclSharpListenComm*)listenComm; ncclResult_t status; - status = ncclNetPlugin_v7.closeListen(lComm->listenCommP2P); + status = ncclNetPlugin_v8.closeListen(lComm->listenCommP2P); free(listenComm); return status; }