diff --git a/projects/rccl/src/msccl.cc b/projects/rccl/src/msccl.cc index 5814ac353a..19b98a8af5 100644 --- a/projects/rccl/src/msccl.cc +++ b/projects/rccl/src/msccl.cc @@ -46,15 +46,16 @@ ncclResult_t mscclRunAlgo_impl( size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op, mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream) { struct NvtxParamsMsccl { - size_t sendbytes; - size_t recvbytes; + size_t bytes; + ncclRedOp_t op; }; // Just pass the size of one send/recv messages and not the total bytes sent/received. constexpr nvtxPayloadSchemaEntry_t MscclSchema[] = { - {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Send)"}, - {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Recv)"} + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}, + {0, NVTX_PAYLOAD_ENTRY_NCCL_REDOP, "Reduction operation", nullptr, 0, + offsetof(NvtxParamsMsccl, op)} }; - NvtxParamsMsccl payload{sendCounts[comm->rank] * ncclTypeSize(dataType), recvCounts[comm->rank] * ncclTypeSize(dataType)}; + NvtxParamsMsccl payload{count * ncclTypeSize(dataType), op}; NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload) mscclStatus& status = mscclGetStatus(comm->rank);