From 7e482ec3e9e85217bf6dfef2d2fce18ccfbf5efa Mon Sep 17 00:00:00 2001 From: Devendar Bureddy Date: Wed, 15 May 2024 08:52:45 -0700 Subject: [PATCH] sharp: Add trace option at rank 0 for sharp colls --- src/sharp_plugin.c | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/sharp_plugin.c b/src/sharp_plugin.c index d37e8b2e..338c803d 100644 --- a/src/sharp_plugin.c +++ b/src/sharp_plugin.c @@ -34,6 +34,7 @@ NCCL_PARAM(SharpGroupSizeThresh, "SHARP_GROUP_SIZE_THRESH", 2); NCCL_PARAM(SharpV3Datatypes, "SHARP_V3_DATATYPES", 2); NCCL_PARAM(SharpDisableRS, "SHARP_DISABLE_REDUCE_SCATTER", 0); NCCL_PARAM(SharpDisableAG, "SHARP_DISABLE_ALLGATHER", 0); +NCCL_PARAM(enableSharpTrace, "SHARP_COLL_TRACE", 0); enum ncclSharpRequestType { NCCL_SHARP_REQ_SHARP_COLL, @@ -500,6 +501,9 @@ ncclResult_t ncclSharpIallreduce(void* collComm, void* sendData, void* recvData, reduce_spec.op = op_type; reduce_spec.aggr_mode = SHARP_AGGREGATION_NONE; + if (ncclParamenableSharpTrace() && cComm->rank == 0) + INFO(NCCL_COLL, "Allreduce count:%d, op:%d dtype:%d ", count, op_type, sharp_type); + #if BLOCKING==0 if (SHARP_COLL_SUCCESS != sharp_coll_do_allreduce_nb(cComm->sharpCollComm, &reduce_spec, &req->sharpRequest)) { WARN("SHARP allreduce failed\n"); @@ -546,6 +550,10 @@ ncclResult_t ncclSharpIallgather(void* collComm, void* sendData, int nRecvParts, gather_spec.size = recvParts[0].size; gather_spec.offset = windowOffset; + if (ncclParamenableSharpTrace() && cComm->rank == 0) + INFO(NCCL_COLL, "Allgather size:%lu bytesPerRank:%lu windowOffset:%lu windowBytes:%lu", + recvParts[0].size, bytesPerRank, windowOffset, windowBytes); + #if BLOCKING==0 if (SHARP_COLL_SUCCESS != sharp_coll_do_allgather_nb(cComm->sharpCollComm, &gather_spec, &req->sharpRequest)) { WARN("SHARP Allgather failed\n"); @@ -611,6 +619,10 @@ ncclResult_t ncclSharpIreducescatter(void* collComm, int nSendParts, ncclNetSGE_ reduce_spec.op = op_type; reduce_spec.aggr_mode = SHARP_AGGREGATION_NONE; + if (ncclParamenableSharpTrace() && cComm->rank == 0) + INFO(NCCL_COLL, "ReduceScatter bytesPerRank:%lu windowOffset:%lu windowBytes:%lu op_type:%d dtype:%d", + bytesPerRank, windowOffset, windowBytes, op_type, sharp_type); + #if BLOCKING==0 if (SHARP_COLL_SUCCESS != sharp_coll_do_reduce_scatter_nb(cComm->sharpCollComm, &reduce_spec, &req->sharpRequest)) { WARN("SHARP reduce_scatter failed\n");