diff --git a/src/sharp_plugin.c b/src/sharp_plugin.c index 9d55186..c905956 100644 --- a/src/sharp_plugin.c +++ b/src/sharp_plugin.c @@ -137,18 +137,18 @@ int ncclSharpAllGather(void *context, void *buf, int len) { for (int i=0; i<cComm->nranks-1; i++) { void* srequest = NULL, *rrequest = NULL; int rpeer = (speer-1+cComm->nranks)%cComm->nranks; - while (srequest == NULL || rrequest == NULL) { - void *rbuf = ((char*)buf)+rpeer*len; + int sdone = 0; + int rdone = 0; + while (!sdone || !rdone) { int tag = 0x69; if (srequest == NULL) NCCLCHECK(ncclNetPlugin_v8.isend(cComm->sendComm, ((char*)buf)+speer*len, len, tag, sMhandle, &srequest)); - if (rrequest == NULL) NCCLCHECK(ncclNetPlugin_v8.irecv(cComm->recvComm, 1, &rbuf, &len, &tag, &rMhandle, &rrequest)); - } - while (srequest || rrequest) { - int done = 0; /* silent uninitialized false positive */ - if (rrequest) NCCLCHECK(ncclNetPlugin_v8.test(rrequest, &done, NULL)); - if (done) rrequest = NULL; - if (srequest) NCCLCHECK(ncclNetPlugin_v8.test(srequest, &done, NULL)); - if (done) srequest = NULL; + if (rrequest == NULL) { + void *rbuf = ((char*)buf)+rpeer*len; + NCCLCHECK(ncclNetPlugin_v8.irecv(cComm->recvComm, 1, &rbuf, &len, &tag, &rMhandle, &rrequest)); + } + + if (!sdone && srequest) NCCLCHECK(ncclNetPlugin_v8.test(srequest, &sdone, NULL)); + if (!rdone && rrequest) NCCLCHECK(ncclNetPlugin_v8.test(rrequest, &rdone, NULL)); } speer = rpeer; }