Fix collective trace when rccl is configured (#1056)

* Fix collective trace when rccl is configured

[ROCm/rccl commit: c4dbf8a914]
This commit is contained in:
Bertan Dogancay
2024-01-22 09:26:44 -07:00
committed by GitHub
parent 8b8179a689
commit 56482a8be8
5 changed files with 28 additions and 68 deletions
+1
View File
@@ -1,3 +1,4 @@
# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*.gcov
/coverage/
build/
+18 -5
View File
@@ -150,7 +150,7 @@ function(gen_device_table)
message(STATUS "Generating ${DEVICE_TABLE_FILE}")
## Generate device table and list all the functions
file(WRITE ${DEVICE_TABLE_FILE} "#include \"common.h\"\n#include \"collectives.h\"\n\n")
file(WRITE ${DEVICE_TABLE_FILE} "#include \"common.h\"\n#include \"collectives.h\"\n#include \"devcomm.h\"\n\n")
## Declaration of device functions
foreach(func IN LISTS FUNC_LIST)
@@ -216,6 +216,18 @@ function(gen_device_table)
file(APPEND ${DEVICE_TABLE_FILE} " }\n}\n")
endif()
## Function name table for collective trace.
## Emits a `funcNames` array of string literals into the device table file:
## one entry per item of FUNC_LIST (in list order), followed by one
## ncclFunction_OneRankReduce_PreMulSum_<type> entry per item of ALL_TYPES.
## NOTE(review): the array is declared with FUNC_INDEX_TOTAL elements; presumably
## that constant has to match the number of entries emitted here -- confirm.
if(COLLTRACE)
  ## Build the whole table in memory and append it to the file once, instead of
  ## reopening the file for every single entry.
  set(func_names_table "const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
  ## Iterate with IN LISTS (like the other loops over FUNC_LIST / ALL_TYPES) so the
  ## list is walked item by item rather than re-parsed from an unquoted expansion.
  foreach(func IN LISTS FUNC_LIST)
    string(APPEND func_names_table " \"${func}\",\n")
  endforeach()
  foreach(type IN LISTS ALL_TYPES)
    string(APPEND func_names_table " \"ncclFunction_OneRankReduce_PreMulSum_${type}\",\n")
  endforeach()
  string(APPEND func_names_table "};\n")
  file(APPEND ${DEVICE_TABLE_FILE} "${func_names_table}")
endif()
## Add the device_table file to HIP_SOURCES
list(APPEND HIP_SOURCES ${DEVICE_TABLE_FILE})
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
@@ -252,6 +264,7 @@ function(gen_host_table)
list(FIND FUNC_LIST "ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}" fn_id)
if(NOT ${fn_id} EQUAL -1)
set(last_valid_fn_id ${fn_id})
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}\n")
else()
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id},\n")
@@ -262,16 +275,16 @@ function(gen_host_table)
endforeach()
endforeach()
endforeach()
math(EXPR fn_id "${fn_id} + 1")
math(EXPR last_valid_fn_id "${last_valid_fn_id} + 1")
## Add OneRankReduce function ids at the end
foreach(type IN LISTS ALL_TYPES)
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_OneRankReduce_PreMulSum_${type}\n")
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${last_valid_fn_id}, // ncclFunction_OneRankReduce_PreMulSum_${type}\n")
## Increment the index and func id for each OneRankReduce
math(EXPR idx "${idx} + 1")
math(EXPR fn_id "${fn_id} + 1")
math(EXPR last_valid_fn_id "${last_valid_fn_id} + 1")
endforeach()
file(APPEND ${HOST_TABLE_FILE} "-1};\n\n")
file(APPEND ${HOST_TABLE_FILE} "${last_valid_fn_id}};\n\n")
## Add the host_table file to HIP_SOURCES
list(APPEND HIP_SOURCES ${HOST_TABLE_FILE})
+1
View File
@@ -21,6 +21,7 @@ struct ncclDevRedOpFull {
#define FUNC_INDEX_P2P 1015
#define FUNC_INDEX_ALLTOALL_PIVOT 675
#define FUNC_INDEX_TOTAL 1026 // Total number of functions that goes into librccl.so index in host_table.cc
#define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \
ncclFunction_##func##_##algo##_##proto##_##devredop##_##type
+3
View File
@@ -21,6 +21,9 @@
#define NCCL_NUM_FUNCTIONS 5 // SendRecv and AllToAllPivot not included for now
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncSend, ncclFuncRecv, ncclFuncAllToAllPivot, ncclNumFuncs} ncclFunc_t;
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2];
#ifdef ENABLE_COLLTRACE
extern const char* funcNames[FUNC_INDEX_TOTAL];
#endif
#define NCCL_NUM_ALGORITHMS 6 // Tree/Ring/CollNet*
#define NCCL_ALGO_TREE 0
+5 -63
View File
@@ -19,6 +19,7 @@
#include "graph.h"
#include "argcheck.h"
#include "devcomm.h"
#include "collectives.h"
#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif
@@ -179,17 +180,6 @@ void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) {
RCCL_PARAM(KernelCollTraceEnable, "KERNEL_COLL_TRACE_ENABLE", 0);
#ifdef ENABLE_COLLTRACE
#define MAX_NAME_LENGTH 64
// Formats one printf-style name into slot number `funcIdx` of `func_names`
// (slots are MAX_NAME_LENGTH bytes apart), then advances `funcIdx` by one.
// vsnprintf is given MAX_NAME_LENGTH as the size limit, so it writes at most
// that many bytes into the slot, terminator included.
void generateFunctionName(char* func_names, int& funcIdx, const char* format, ...) {
  va_list ap;
  va_start(ap, format);
  vsnprintf(&func_names[MAX_NAME_LENGTH * funcIdx], MAX_NAME_LENGTH, format, ap);
  va_end(ap);
  ++funcIdx;
}
// Should be in sync with 'ALL_COLLS' in Generator.cmake
void *ncclCommThreadMain(void *arg) {
ncclComm_t comm = (ncclComm_t)arg;
@@ -198,53 +188,6 @@ void *ncclCommThreadMain(void *arg) {
memset(head, 0, sizeof(int)*MAXCHANNELS);
vega_gpu_rtc_freq = GetDeviceWallClockRateInKhz(comm->cudaDev) * 1.0E3;
char* func_names = (char *)malloc(MAX_NAME_LENGTH*(ncclFuncId_P2p()+/*OneRankReduce*/11));
int funcIdx = 0;
// AllGather --> RING / <all_protos> / Sum / int8_t
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
generateFunctionName(func_names, funcIdx, "AllGatherRing%sSum_i8", ncclProtoStr[pr]);
}
// AllReduce --> <all_algos> / <all_protos> / <all_redops> / <all_types>
for (int al = 0; al < NCCL_NUM_ALGORITHMS - 2; al++) {
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
for (int ty = 0; ty < ncclNumTypes; ty++) {
if (redop == 5 && ty > 5) continue;
generateFunctionName(func_names, funcIdx, "AllReduce%s%s%s%s", ncclAlgoStr[al], ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
}
}
}
}
// AllToAllPivot --> RING / SIMPLE / Sum / int8_t
generateFunctionName(func_names, funcIdx, "AllToAllPivotRingSimpleSum_i8");
// Broadcast --> RING / <all_protos> / Sum / int8_t
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
generateFunctionName(func_names, funcIdx, "BroadcastRing%sSum_i8", ncclProtoStr[pr]);
}
// Reduce --> RING / <all_protos> / <all_redops> / <all_types>
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
for (int ty = 0; ty < ncclNumTypes; ty++) {
if (redop == 5 && ty > 5) continue;
generateFunctionName(func_names, funcIdx, "ReduceRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
}
}
}
// ReduceScatter --> RING / <all_protos> / <all_redops> / <all_types>
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
for (int ty = 0; ty < ncclNumTypes; ty++) {
if (redop == 5 && ty > 5) continue;
generateFunctionName(func_names, funcIdx, "ReduceScatterRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
}
}
}
// SendRecv --> RING / SIMPLE / Sum / int8_t
generateFunctionName(func_names, funcIdx, "SendRecvRingSimpleSum_i8");
// OneRankReduce --> PreMulSum / <all_types>
for (int ty = 0; ty < ncclNumTypes; ty++) {
generateFunctionName(func_names, funcIdx, "OneRankReducePreMulSum%s", ncclTypeStr[ty]);
}
do {
for (int channel = 0; channel < MAXCHANNELS; channel++) {
int tail = comm->collTraceTail[channel].tail%COLLTRACE_NUM_ITEMS;
@@ -276,9 +219,9 @@ void *ncclCommThreadMain(void *arg) {
sprintf(line, "## [%012.6f] [%02d:%02d] %06lx", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->opCount);
offset = strlen(line);
if (type == ncclCollTraceCollElemType) {
sprintf(line+offset, " CE %s nw %d bi %d nc %d busId %lx nRanks %d", func_names+MAX_NAME_LENGTH*fIdx, td->coll.nWarps, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks);
sprintf(line+offset, " CE %s nw %d bi %d nc %d busId %lx nRanks %d", funcNames[fIdx], td->coll.nWarps, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks);
} else if (type == ncclCollTraceP2pElemType) {
sprintf(line+offset, " PE %s %d -> %d/%d/%d/%d conn/nw/ws/ng %d/%d/%d/%d -> %d busId %lx nRanks %d", func_names+MAX_NAME_LENGTH*fIdx,
sprintf(line+offset, " PE %s %d -> %d/%d/%d/%d conn/nw/ws/ng %d/%d/%d/%d -> %d busId %lx nRanks %d", funcNames[fIdx],
td->p2p[0].peer, td->p2p[0].connIndex, td->p2p[0].nWarps, td->p2p[0].warpStart, td->p2p[0].ngroups,
td->p2p[1].connIndex, td->p2p[1].nWarps, td->p2p[1].warpStart, td->p2p[1].ngroups, td->p2p[1].peer, comm->busId, comm->nRanks);
} else {
@@ -286,9 +229,9 @@ void *ncclCommThreadMain(void *arg) {
case ncclCollTraceKernelLaunchType:
case ncclCollTraceCollLaunchType:
if ((type&0xf) == ncclCollTraceKernelLaunchType)
sprintf(line+offset, " KL HWID %8x %s", td->data_0, func_names+MAX_NAME_LENGTH*fIdx);
sprintf(line+offset, " KL HWID %8x %s", td->data_0, funcNames[fIdx]);
else if ((type&0xf) == ncclCollTraceCollLaunchType)
sprintf(line+offset, " CL %s", func_names+MAX_NAME_LENGTH*fIdx);
sprintf(line+offset, " CL %s", funcNames[fIdx]);
offset = strlen(line);
if ((type&0xf0) == ncclCollTraceCollElemType)
sprintf(line+offset, " nw %d bi %d nc %d busId %lx nRanks %d", td->coll.nWarps, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks);
@@ -316,7 +259,6 @@ void *ncclCommThreadMain(void *arg) {
}
}
} while(!comm->collTraceExit);
free(func_names);
pthread_exit(NULL);
}
#endif