Skip to content

Commit

Permalink
Update internal plugin api call
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy committed May 31, 2024
1 parent e375f37 commit 610b90f
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/sharp_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ int ncclSharpAllGather(void *context, void *buf, int len) {

p2p_plugin = nccl_p2p_get_plugin_type();
if (p2p_plugin != NCCL_P2P_UCX) {
NCCLCHECK(ncclNetPlugin_v7.regMr(cComm->recvComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &rMhandle));
NCCLCHECK(ncclNetPlugin_v7.regMr(cComm->sendComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &sMhandle));
NCCLCHECK(ncclNetPlugin_v8.regMr(cComm->recvComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &rMhandle));
NCCLCHECK(ncclNetPlugin_v8.regMr(cComm->sendComm, buf, cComm->nranks*len, NCCL_PTR_HOST, &sMhandle));
}

int speer = cComm->rank;
Expand All @@ -140,21 +140,21 @@ int ncclSharpAllGather(void *context, void *buf, int len) {
while (srequest == NULL || rrequest == NULL) {
void *rbuf = ((char*)buf)+rpeer*len;
int tag = 0x69;
if (srequest == NULL) NCCLCHECK(ncclNetPlugin_v7.isend(cComm->sendComm, ((char*)buf)+speer*len, len, tag, sMhandle, &srequest));
if (rrequest == NULL) NCCLCHECK(ncclNetPlugin_v7.irecv(cComm->recvComm, 1, &rbuf, &len, &tag, &rMhandle, &rrequest));
if (srequest == NULL) NCCLCHECK(ncclNetPlugin_v8.isend(cComm->sendComm, ((char*)buf)+speer*len, len, tag, sMhandle, &srequest));
if (rrequest == NULL) NCCLCHECK(ncclNetPlugin_v8.irecv(cComm->recvComm, 1, &rbuf, &len, &tag, &rMhandle, &rrequest));
}
while (srequest || rrequest) {
int done = 0; /* silent uninitialized false positive */
if (rrequest) NCCLCHECK(ncclNetPlugin_v7.test(rrequest, &done, NULL));
if (rrequest) NCCLCHECK(ncclNetPlugin_v8.test(rrequest, &done, NULL));
if (done) rrequest = NULL;
if (srequest) NCCLCHECK(ncclNetPlugin_v7.test(srequest, &done, NULL));
if (srequest) NCCLCHECK(ncclNetPlugin_v8.test(srequest, &done, NULL));
if (done) srequest = NULL;
}
speer = rpeer;
}
if (p2p_plugin != NCCL_P2P_UCX) {
NCCLCHECK(ncclNetPlugin_v7.deregMr(cComm->recvComm, rMhandle));
NCCLCHECK(ncclNetPlugin_v7.deregMr(cComm->sendComm, sMhandle));
NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->recvComm, rMhandle));
NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->sendComm, sMhandle));
}

return 0;
Expand Down Expand Up @@ -247,7 +247,7 @@ ncclResult_t ncclSharpListen(int dev, void* opaqueHandle, void** listenComm) {
ncclResult_t status;

NCCLCHECK(ncclIbMalloc((void**)&lComm, sizeof(struct ncclSharpListenComm)));
status = ncclNetPlugin_v7.listen(dev, opaqueHandle, &lComm->listenCommP2P);
status = ncclNetPlugin_v8.listen(dev, opaqueHandle, &lComm->listenCommP2P);
lComm->dev = dev;
*listenComm = lComm;
return status;
Expand Down Expand Up @@ -437,7 +437,7 @@ ncclResult_t ncclSharpDeregMr(void* collComm, void* mhandle) {
WARN("SHARP deregmr failed\n");
}

NCCLCHECK(ncclNetPlugin_v7.deregMr(cComm->recvComm, mh->ncclIbMr));
NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->recvComm, mh->ncclIbMr));

free(mh);
return ncclSuccess;
Expand Down Expand Up @@ -647,7 +647,7 @@ ncclResult_t ncclSharpIflush(void* collComm, void* data, int size, void* mhandle

NCCLCHECK(ncclSharpGetRequest(cComm->reqs, &req));
req->requestType = NCCL_SHARP_REQ_IFLUSH;
ncclNetPlugin_v7.iflush(cComm->recvComm, 1, &data, &size, &mh->ncclIbMr, &req->sharpRequest);
ncclNetPlugin_v8.iflush(cComm->recvComm, 1, &data, &size, &mh->ncclIbMr, &req->sharpRequest);
if (!req->sharpRequest) {
*request = NULL;
req->used = 0;
Expand All @@ -662,7 +662,7 @@ ncclResult_t ncclSharpTest(void* request, int* done, int* size) {
struct ncclSharpRequest* req = (struct ncclSharpRequest*)request;

if (req->requestType == NCCL_SHARP_REQ_IFLUSH) {
ncclNetPlugin_v7.test(req->sharpRequest, done, size);
ncclNetPlugin_v8.test(req->sharpRequest, done, size);
if (*done == 1) {
req->used = 0;
}
Expand Down Expand Up @@ -696,8 +696,8 @@ ncclResult_t ncclSharpCloseColl(void* collComm) {
sharp_coll_comm_destroy(cComm->sharpCollComm);
sharp_coll_finalize(cComm->sharpCollContext);

NCCLCHECK(ncclNetPlugin_v7.closeRecv(cComm->recvComm));
NCCLCHECK(ncclNetPlugin_v7.closeSend(cComm->sendComm));
NCCLCHECK(ncclNetPlugin_v8.closeRecv(cComm->recvComm));
NCCLCHECK(ncclNetPlugin_v8.closeSend(cComm->sendComm));
free(cComm);
return ncclSuccess;
}
Expand All @@ -706,7 +706,7 @@ ncclResult_t ncclSharpCloseListen(void* listenComm) {
struct ncclSharpListenComm *lComm = (struct ncclSharpListenComm*)listenComm;
ncclResult_t status;

status = ncclNetPlugin_v7.closeListen(lComm->listenCommP2P);
status = ncclNetPlugin_v8.closeListen(lComm->listenCommP2P);
free(listenComm);
return status;
}
Expand Down

0 comments on commit 610b90f

Please sign in to comment.