diff --git a/src/psm2_nccl_net.c b/src/psm2_nccl_net.c index e7c13fd..01070e4 100755 --- a/src/psm2_nccl_net.c +++ b/src/psm2_nccl_net.c @@ -74,10 +74,16 @@ typedef struct comm_req { typedef struct { psm2_ep_t ep; psm2_epid_t epid; + int unit; unsigned int refcount; } shared_ep_t; +#define SHARED_EP_MULTI_ERROR 0 +#define SHARED_EP_MULTI_WARN 1 + static int use_shared_ep = 1; +static int shared_ep_multi_policy = SHARED_EP_MULTI_WARN; + static int use_gpudirect = 0; shared_ep_t shared_ep = {0}; @@ -176,6 +182,27 @@ static int psm2comm_init_ep(int dev, psm2_uuid_t uuid, psm2comm_t *comm) psm2_ep_t ep; psm2_epid_t epid; + if (use_nccl_dev_num && use_shared_ep && shared_ep.refcount && + shared_ep.unit != dev) { + char msg[1024]; + int pr = snprintf(msg, sizeof(msg), + "Shared-EP<>NCCL-device numbers mismatch; shared-EP deivce=%d, NCCL device=%d.", + shared_ep.unit, dev); + + switch (shared_ep_multi_policy) { + case SHARED_EP_MULTI_ERROR: + PSM_ERROR("%s", msg); + return ncclInternalError; + case SHARED_EP_MULTI_WARN: + pr = snprintf(msg + pr, sizeof(msg) - pr, " Ignoring NCCL device number."); + PSM_WARN("%s", msg); + break; + default: + assert(0); + return ncclInternalError; + } + } + if (!use_shared_ep || !shared_ep.refcount) { struct psm2_ep_open_opts opts; int rc = psm2_ep_open_opts_get_defaults(&opts); @@ -195,6 +222,8 @@ static int psm2comm_init_ep(int dev, psm2_uuid_t uuid, psm2comm_t *comm) if (use_shared_ep) { shared_ep.ep = ep; shared_ep.epid = epid; + if (use_nccl_dev_num) + shared_ep.unit = dev; } } else { ep = shared_ep.ep; @@ -314,6 +343,18 @@ ncclResult_t psm2_nccl_init(ncclDebugLogger_t logFunction) use_shared_ep = (int)envval; } + char *eshep_multi = getenv("PSM2_NCCL_SHARED_EP_MULTI"); + if (eshep_multi) { + if (!strncasecmp(eshep_multi, "error", strlen(eshep_multi))) + shared_ep_multi_policy = SHARED_EP_MULTI_ERROR; + else if (!strncasecmp(eshep_multi, "warn", strlen(eshep_multi))) + shared_ep_multi_policy = SHARED_EP_MULTI_WARN; + else { + PSM_ERROR("PSM2_NCCL_SHARED_EP_MULTI must be warn or error"); + return ncclInternalError; + } + } + char *euse_gdr = getenv("PSM2_NCCL_USE_GPUDIRECT"); if (euse_gdr) { char *end = NULL;