diff --git a/src/enqueue.cc b/src/enqueue.cc index 6f3318e7e6..856bc57cff 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -1616,9 +1616,11 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) { } NCCLCHECKGOTO(ArgsCheck(info), ret, fail); - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", - info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count, - info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d", + info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count, + info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream, + info->comm->tasks.nTasksP2p + info->comm->tasks.nTasksColl, + info->comm->localRankToRank[info->comm->localRank]); TRACE_CALL("nccl%s(%" PRIx64 ",%" PRIx64 ",%zi,%d,%d,%d,%p,%p)", info->opName, reinterpret_cast(info->sendbuff), reinterpret_cast(info->recvbuff), info->count, info->datatype, info->op, info->root, info->comm, info->stream); NCCLCHECKGOTO(taskAppend(info->comm, info), ret, fail); diff --git a/tools/rccl_replayer/Makefile b/tools/rccl_replayer/Makefile new file mode 100644 index 0000000000..e21d4a7390 --- /dev/null +++ b/tools/rccl_replayer/Makefile @@ -0,0 +1,21 @@ +RCCL_DIR ?= ../../build/release + +ifdef MPI_DIR +main: rcclReplayer.cpp + /opt/rocm/bin/hipcc rcclReplayer.cpp \ + -g \ + -o rcclReplayer \ + -I$(MPI_DIR)/ \ + -I$(RCCL_DIR) \ + -I$(RCCL_DIR)/include/rccl \ + -I/opt/rocm/include/hip \ + -L$(MPI_DIR)/lib \ + -L$(RCCL_DIR) -lmpich -lrccl +else +main: + @echo "Error: MPI_DIR was not specified." + @exit 1 +endif + +clean: + rm -f ./rcclReplayer \ No newline at end of file diff --git a/tools/rccl_replayer/README.md b/tools/rccl_replayer/README.md new file mode 100644 index 0000000000..60de03a845 --- /dev/null +++ b/tools/rccl_replayer/README.md @@ -0,0 +1,83 @@ +# RCCL REPLAYER +Collective log replayer tool for RCCL. + +## Table of Contents + +1. [Introduction](#introduction) +2. [Features](#features) +3. [How It Works](#how-it-works) +4. [Installation](#installation) +5. [Usage](#usage) + +## Introduction + +Replayer is a dubugging tool designed to analyze and replay collective logs obtained from RCCL (ROCm Communication Collectives Library) runs. It can be a useful tool when trying to recreate problem situations (without as much setup), or as a user-directed utility to run collectives (by crafting their own 'logfile'). + +## Features + +- Parses and validates collective logs from RCCL runs. +- Detects missing/faulty group calls and provides report. +- Replays collective calls based on the recorded data. +- Skips faulty group calls during replay. +- Supports various MPI ranks and GPU configurations. +- Supports multi-node environment. + +*Note: RCCL Replayer executes collective calls with dummy data.* + +## How It Works + +Replayer operates in the following steps: + +1. **Collective Log Collection:** During your RCCL runs, the collective logs are generated when NCCL_DEBUG=INFO and NCCL_DEBUG_SUBSYS=COLL enabled, capturing important information like hostname, deviceIdx, collective call type, number of elements used, data type, operation type, task number, and global rank number about collective communication patterns. + +2. **Data Aggregation:** Replayer collects and pareses the collective logs. organizing them based on opCount (collective count in the group call), and global rank information. + +3. **Group Call Validation:** After acquiring data from the collective logs and generating group calls, the replayer validates the results using two different methods. For Non-Send/Recv collectives, it checks if each MPI rank has the required number of collective tasks. For Send/Recv collectives, it verifies if they all have a matching pair. + +4. **Replaying RCCL:** Based on the aggregated and validated data, Replayer will replay the collective logs to reproduce the RCCL runs from your application. + +5. **Reporting and Skipping.** Replayer outputs the detected faulty group calls and skips them during replay. It provides a report showing which group calls were skipped and why and, at the end, summarizes how many group calls were replayed and how many were skipped. + +## Installation + +To build the replayer, follow these steps: +1. Navigate to the rccl_replayer directory. +2. Make sure 'MPI_DIR' is set to the path where your MPI installation is located. + +```bash + cd rccl/tools/rccl_replayer + MPI_DIR=/path/to/mpi make +``` + +Depending on the MPI library used and your installation path, you may need to set the MPI_DIR path accordingly. + + +## Usage + +After successfully building the replayer, you can run it using the following command: + +```bash + mpirun -np ./rcclReplayer +``` + +Replace with the number of MPI processes you want to run during the replay, with the path to the collective log file generated during your RCCL runs, and with the number of GPUs per MPI rank used in your application. + +Depending on the MPI library you use, you may need to modify the mpirun command accordingly. + +### Multi-Node Environment: + +If multiple nodes were used for your application, you can also replay the collective logs using multiple nodes. See the following command: + +```bash + mpirun --hostfile -np ./rcclReplayer +``` + +### SLURM: + +For systems using SLURM, you can use the following command to replay the collective logs: + +```bash + srun -N -n ./rcclReplayer +``` + +Replace with the number of nodes used in your application. diff --git a/tools/rccl_replayer/rcclReplayer.cpp b/tools/rccl_replayer/rcclReplayer.cpp new file mode 100644 index 0000000000..0629a4d9b2 --- /dev/null +++ b/tools/rccl_replayer/rcclReplayer.cpp @@ -0,0 +1,387 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "rcclReplayer.hpp" + +bool ParseLineItem(char const* line, LineItem& li) +{ + return sscanf(line, + "%[^:]:%d:%d [%d] NCCL INFO %[^:]: opCount %d sendbuff %s " + "recvbuff %s count %lu datatype %d op %d root %d comm %s " + "[nranks=%d] stream %p task %d globalrank %d", + li.hostname, &li.pid, &li.tid, &li.cudaDev, li.opName, + &li.opCount, li.sendbuff, li.recvbuff, + &li.count, &li.datatype, &li.op, &li.root, li.comm, + &li.nRanks, &li.stream, &li.task, &li.globalRank) == 17; +} + +void ParseCollectives(char const* logFilename, int const numGlobalRanks, std::vector& groupCalls) { + int mpiRank; + MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank); + + groupCalls.clear(); + + FILE *fp = fopen(logFilename, "r"); + if (!fp) { + printf("[ERROR] Unable to open file %s\n", logFilename); + exit(-1); + } + + char line[1000]; + LineItem li; + int lineNum = 0; + while (fgets(line, 1000, fp)) { + ++lineNum; + + //Ignore invalid lines and collectives + if (!ParseLineItem(line, li) || li.nRanks != numGlobalRanks) continue; + + TaskInfo taskInfo; + taskInfo.funcType = GetFuncType(li.opName); + taskInfo.inPlace = !strcmp(li.sendbuff, li.recvbuff); + taskInfo.count = li.count; + taskInfo.datatype = (ncclDataType_t) li.datatype; + taskInfo.op = (ncclRedOp_t) li.op; + taskInfo.root = li.root; + + // Find the appropriate GroupCall that this task belongs to + // If it doesn't exist yet, then create it + bool found = false; + for (auto& gc : groupCalls) { + if (gc.rankData.count(li.globalRank)) { + RankData& rd = gc.rankData[li.globalRank]; + if (rd.comm != li.comm || rd.tasks.size() != li.task) + continue; + + rd.tasks.push_back(taskInfo); + found = true; + break; + } + // Rank has no tasks - make sure this is task 0 + else if (li.task == 0) { + gc.rankData[li.globalRank].comm = li.comm; + gc.rankData[li.globalRank].lineNum = lineNum; + gc.rankData[li.globalRank].tasks.push_back(taskInfo); + found = true; + break; + } + } + + // If no collectives were found, create new one + if (!found) { + if (li.task != 0) { + if (mpiRank == 0) printf("[WARN] Was unable to find corresponding collective for line %d\n", lineNum); + } + + groupCalls.resize(groupCalls.size() + 1); + GroupCall& gc = groupCalls.back(); + gc.opCount = li.opCount; + gc.rankData[li.globalRank].comm = li.comm; + gc.rankData[li.globalRank].lineNum = lineNum; + gc.rankData[li.globalRank].tasks.push_back(taskInfo); + } + } + + // - For non Send/Recv, check that all ranks participate with same parameters count + // - For Send/Recv, check that pairs of Send/Recv calls exist + if (mpiRank == 0) printf("Found %lu groupCalls\n", groupCalls.size()); + for (int i = 0; i < groupCalls.size(); i++) { + GroupCall& gc = groupCalls[i]; + std::map, std::vector> arrivalCounter; + + gc.isValid = true; + + if (mpiRank == 0) { + printf("GroupCall %d\n", i); + printf(" - OpCount: %d\n", gc.opCount); + } + + for (auto rd : gc.rankData) { + if (mpiRank == 0) { + printf(" - Rank %02d: comm %s\n", rd.first, rd.second.comm.c_str()); + } + + for (int task = 0; task < rd.second.tasks.size(); task++) { + TaskInfo ti = rd.second.tasks[task]; + const char* funcName; + + if (ti.funcType == ncclCollSend || ti.funcType == ncclCollRecv) + funcName = "Send/Recv"; + else + funcName = ncclFuncNames[ti.funcType]; + + std::tuple key(funcName, ti.count, ti.datatype, ti.op); + + if (mpiRank == 0) { + printf(" - Task %02d: %32s inPlace=%d count=%lu datatype=%d op=%d root=%d\n", + task, funcName, ti.inPlace, ti.count, ti.datatype, ti.op, ti.root); + } + + auto& rankVector = arrivalCounter[key]; + if (rankVector.size() < numGlobalRanks) { + rankVector.resize(numGlobalRanks); + } + + // rankVector in arrivalCount represents the rank information + // Count the number of tasks that are going to be executed by each rank. This is to validate the group call later on. + // Nom-Send/Recv rank counts (rankVector elements) should be equal at the end, and for Send/Recv, all the elements of rankVector should be equal to 0 + if (ti.funcType == ncclCollRecv) { + rankVector[ti.root]--; + } else { + rankVector[rd.first]++; + } + } + } + + // Iterate through the map variable and report/validate the results + for (const auto& e : arrivalCounter) { + int maxVal; + const char* funcName = std::get<0>(e.first); + size_t count = std::get<1>(e.first); + int datatype = std::get<2>(e.first); + int op = std::get<3>(e.first); + + bool isp2p = (strcmp(std::get<0>(e.first), "Send/Recv") == 0); + if (!isp2p) maxVal = *std::max_element(e.second.begin(), e.second.end()); + + // Validate all the ranks have required amount of collective call (task) + for (int i = 0; i < e.second.size(); i++) { + if (e.second[i] != (isp2p ? 0 : maxVal)) { + std::string warning = (isp2p ? (e.second[i] > 0 ? "[WARN] Missing Recv" : "[WARN] Missing Send") : "[WARN] Missing " + std::string(funcName)) + + " count=" + std::to_string(count) + " datatype=" + std::to_string(datatype) + " op=" + std::to_string(op) + " at rank [" + std::to_string(i) + "]"; + if(mpiRank == 0) printf("%s\n", warning.c_str()); + + gc.isValid = false; + } + } + } + } +} + +// GetSize will return a pair of bytes where first element in pair represents bytesSent and the second bytesRecv +std::pair GetSize(TaskInfo taskInfo, int numGlobalRanks) { + size_t sendNumBytes; + size_t recvNumBytes; + + if (taskInfo.funcType == ncclCollBroadcast || taskInfo.funcType == ncclCollReduce || taskInfo.funcType == ncclCollAllReduce) { + sendNumBytes = taskInfo.count * DataTypeToBytes(taskInfo.datatype); + recvNumBytes = sendNumBytes; + } else if (taskInfo.funcType == ncclCollAllGather || taskInfo.funcType == ncclCollGather) { + sendNumBytes = taskInfo.count * DataTypeToBytes(taskInfo.datatype); + recvNumBytes = numGlobalRanks * sendNumBytes; + } else if (taskInfo.funcType == ncclCollReduceScatter || taskInfo.funcType == ncclCollScatter) { + recvNumBytes = taskInfo.count * DataTypeToBytes(taskInfo.datatype); + sendNumBytes = numGlobalRanks * recvNumBytes; + } else if (taskInfo.funcType == ncclCollAllToAll) { + sendNumBytes = numGlobalRanks * taskInfo.count * DataTypeToBytes(taskInfo.datatype); + recvNumBytes = sendNumBytes; + } else { + sendNumBytes = taskInfo.count * DataTypeToBytes(taskInfo.datatype); + recvNumBytes = sendNumBytes; + } + return std::make_pair(sendNumBytes, recvNumBytes); +} + +void ExecuteCollective(TaskInfo task, ncclComm_t comm, hipStream_t stream, const void *sendbuff, void *recvbuff) { + + int funcTypeValue = (int)task.funcType; + + switch (funcTypeValue) { + case ncclCollAllGather: + NCCLCHECK(ncclAllGather(sendbuff, recvbuff, task.count, task.datatype, comm, stream)); + break; + case ncclCollAllReduce: + NCCLCHECK(ncclAllReduce(sendbuff, recvbuff, task.count, task.datatype, task.op, comm, stream)); + break; + case ncclCollBroadcast: + NCCLCHECK(ncclBroadcast(sendbuff, recvbuff, task.count, task.datatype, task.root, comm, stream)); + break; + case ncclCollReduce: + NCCLCHECK(ncclReduce(sendbuff, recvbuff, task.count, task.datatype, task.op, task.root, comm, stream)); + break; + case ncclCollReduceScatter: + NCCLCHECK(ncclReduceScatter(sendbuff, recvbuff, task.count, task.datatype, task.op, comm, stream)); + break; + case ncclCollGather: + NCCLCHECK(ncclGather(sendbuff, recvbuff, task.count, task.datatype, task.root, comm, stream)); + break; + case ncclCollScatter: + NCCLCHECK(ncclScatter(sendbuff, recvbuff, task.count, task.datatype, task.root, comm, stream)); + break; + case ncclCollAllToAll: + NCCLCHECK(ncclAllToAll(sendbuff, recvbuff, task.count, task.datatype, comm, stream)); + break; + case ncclCollSend: + NCCLCHECK(ncclSend(sendbuff, task.count, task.datatype, task.root, comm, stream)); + break; + case ncclCollRecv: + NCCLCHECK(ncclRecv(recvbuff, task.count, task.datatype, task.root, comm, stream)); + break; + default: + printf("Error: unsupported collective\n"); + exit(1); + } +} + +void ReplayRccl(GroupCall& groupCall, std::vector comms, std::vector streams, + int const localGpuOffset, int const numGpusPerMpiRank, int const firstGlobalRank, int const numGlobalRanks) { + + std::vector> sendbuff(numGpusPerMpiRank); + std::vector> recvbuff(numGpusPerMpiRank); + + NCCLCHECK(ncclGroupStart()); + for (int localIdx = 0; localIdx < numGpusPerMpiRank; localIdx++) { + int globalRank = firstGlobalRank + localIdx; + RankData& rankData = groupCall.rankData[globalRank]; + + for (auto task : rankData.tasks) { + void* sendBuffer; + void* recvBuffer; + + // Each task has a size based on the type of collective (funcType) + std::pair numBytes = GetSize(task, numGlobalRanks); + + if (task.inPlace) { + numBytes.first = std::max(numBytes.first, numBytes.second); + numBytes.second = numBytes.first; + } + + // Set the device and allocate send/recv buffers + HIPCALL(hipSetDevice(localGpuOffset + localIdx)); + HIPCALL(hipMalloc(&sendBuffer, numBytes.first)); + HIPCALL(hipMalloc(&recvBuffer, numBytes.second)); + HIPCALL(hipMemset(sendBuffer, 0, numBytes.first)); + HIPCALL(hipMemset(recvBuffer, 0, numBytes.second)); + HIPCALL(hipDeviceSynchronize()); + + // Add the send and receive buffers to their respective vectors + sendbuff[localIdx].push_back(sendBuffer); + recvbuff[localIdx].push_back(recvBuffer); + + // Execute the collective call (task) + ExecuteCollective(task, comms[localIdx], streams[localIdx], sendBuffer, recvBuffer); + } + } + NCCLCHECK(ncclGroupEnd()); + + // Synchronize devices + for (int i = 0; i < numGpusPerMpiRank; i++) { + HIPCALL(hipStreamSynchronize(streams[i])); + } + + // Free device memory for each task on each GPU + for (int i = 0; i < numGpusPerMpiRank; i++) { + for (auto& sendBuffer : sendbuff[i]) HIPCALL(hipFree(sendBuffer)); + for (auto& recvBuffer : recvbuff[i]) HIPCALL(hipFree(recvBuffer)); + } +} + +int main(int argc, char **argv) { + MPI_Init(&argc, &argv); + if (argc <= 1) { + printf("Usage: %s logfile [numGpusPerMpiRank = 1]\n", argv[0]); + exit(1); + } + + // Parse rank information + int mpiRank, numMpiRanks; + MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank); + MPI_Comm_size(MPI_COMM_WORLD, &numMpiRanks); + + // Default value for numGpusPerMpiRank is 1 + char* logFilename = argv[1]; + int numGpusPerMpiRank = (argc > 2 ? atoi(argv[2]) : 1); + int numGlobalRanks = numMpiRanks * numGpusPerMpiRank; + + if (mpiRank == 0) + printf("RCCL Replayer: %d x %d = %d total ranks\n", numMpiRanks, numGpusPerMpiRank, numGlobalRanks); + + // Parse logfile for Collectives + std::vector groupCalls; + ParseCollectives(logFilename, numGlobalRanks, groupCalls); + + int localGpuOffset = 0; + int firstGlobalRank = mpiRank * numGpusPerMpiRank; + int lastGlobalRank = firstGlobalRank + numGpusPerMpiRank - 1; + + // Figure out the host and get the localGpuOffset + int nameLen; + char name[MPI_MAX_PROCESSOR_NAME]; + std::vector allnames(numMpiRanks * MPI_MAX_PROCESSOR_NAME, 0); + + MPI_Get_processor_name(name, &nameLen); + MPI_Allgather(name, MPI_MAX_PROCESSOR_NAME, MPI_CHAR, + allnames.data(), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, MPI_COMM_WORLD); + + for (int rank = 0; rank < mpiRank; rank++) + { + if (!strcmp(name, allnames.data() + (rank * MPI_MAX_PROCESSOR_NAME))) + localGpuOffset += numGpusPerMpiRank; + } + + printf("Rank %d [%s] LocalGpuOffset: %d GlobalRankFirst %d GlobalRankLast %d\n", + mpiRank, name, localGpuOffset, firstGlobalRank, lastGlobalRank); + + // Create a unique ID and broadcast it to all ranks + ncclUniqueId uniqueId; + if (mpiRank == 0) ncclGetUniqueId(&uniqueId); + MPI_Bcast(&uniqueId, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD); + + // Each rank has it's own comm and stream + std::vector comms(numGpusPerMpiRank); + std::vector streams(numGpusPerMpiRank); + + // Initialize comms and strams + NCCLCHECK(ncclGroupStart()); + for (int i = 0; i < numGpusPerMpiRank; i++) { + HIPCALL(hipSetDevice(localGpuOffset + i)); + NCCLCHECK(ncclCommInitRank(&(comms[i]), numGlobalRanks, uniqueId, firstGlobalRank + i)); + HIPCALL(hipStreamCreate(&(streams[i]))); + } + NCCLCHECK(ncclGroupEnd()); + + int numSkippedCalls = 0; + auto start = std::chrono::high_resolution_clock::now(); + for (auto groupCall : groupCalls) + if (groupCall.isValid) + ReplayRccl(groupCall, comms, streams, localGpuOffset, numGpusPerMpiRank, firstGlobalRank, numGlobalRanks); + else { + if (mpiRank == 0) printf("[ERROR] in group call: (skipping...)\n"); + for (auto rd : groupCall.rankData) { + if (mpiRank == 0) printf(" - Rank %02d: comm %s in line %d\n", rd.first, rd.second.comm.c_str(), rd.second.lineNum); + for (int task = 0; task < rd.second.tasks.size(); task++) { + TaskInfo ti = rd.second.tasks[task]; + if (mpiRank == 0) + printf(" - Task %02d: %32s inPlace=%d count=%lu datatype=%d op=%d root=%d\n", + task, ncclFuncNames[ti.funcType], ti.inPlace, ti.count, ti.datatype, ti.op, ti.root); + } + } + numSkippedCalls++; + } + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration = end - start; + + // Need to destroy comms and streams after collective execution is done + for (int i = 0; i < numGpusPerMpiRank; ++i) { + ncclCommDestroy(comms[i]); + HIPCALL(hipStreamDestroy(streams[i])); + } + + MPI_Finalize(); + + if (mpiRank == 0) printf("Executed group calls: %zu\n", groupCalls.size() - numSkippedCalls); + if (mpiRank == 0) printf("Skipped group calls: %d\n", numSkippedCalls); + + // Time it takes to execute all the group calls + if (mpiRank == 0) printf("Execution Time: %f seconds\n", duration.count()); + + // Means no hang + printf("MPI Rank %d Success\n", mpiRank); + + return 0; +} diff --git a/tools/rccl_replayer/rcclReplayer.hpp b/tools/rccl_replayer/rcclReplayer.hpp new file mode 100644 index 0000000000..3d8f1dc2fc --- /dev/null +++ b/tools/rccl_replayer/rcclReplayer.hpp @@ -0,0 +1,170 @@ +#pragma once +#include +#include + +#include + +// NOTE: Parsing is based on this line logging collective information in enqueue.cc +// INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d \ + root %d comm %p [nranks=%d] stream %p task %d globalrank %d", +// info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count, +// info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream, +// info->comm->tasks.nTasksP2p + info->comm->tasks.nTasksColl, +// info->comm->localRankToRank[info->comm->localRank]); + +#define MPICHECK(cmd) do { \ + int e = cmd; \ + if( e != MPI_SUCCESS ) { \ + printf("Failed: MPI error %s:%d '%d'\n", \ + __FILE__,__LINE__, e); \ + exit(EXIT_FAILURE); \ + } \ +} while(0) + +#define HIPCALL(cmd) \ + do { \ + hipError_t error = (cmd); \ + if (error != hipSuccess) \ + { \ + printf("Encountered HIP error (%s) at line %d in file %s\n", \ + hipGetErrorString(error), __LINE__, __FILE__); \ + exit(-1); \ + } \ + } while (0) + +#define NCCLCHECK(cmd) do { \ + ncclResult_t res = cmd; \ + if (res != ncclSuccess) { \ + printf("NCCL failure %s:%d '%s'\n", \ + __FILE__,__LINE__,ncclGetErrorString(res)); \ + } \ +} while(0) + +struct LineItem +{ + char hostname[MPI_MAX_PROCESSOR_NAME]; + int pid; + int tid; + int cudaDev; + char opName[32]; + int opCount; + char sendbuff[32]; + char recvbuff[32]; + size_t count; + int datatype; + int op; + int root; + char comm[32]; + int nRanks; + void* stream; + int task; + int globalRank; +}; + +// Enumeration of all collective functions currently supported +typedef enum +{ + ncclCollBroadcast = 0, + ncclCollReduce, + ncclCollAllGather, + ncclCollReduceScatter, + ncclCollAllReduce, + ncclCollGather, + ncclCollScatter, + ncclCollAllToAll, + ncclCollAllToAllv, + ncclCollSend, + ncclCollRecv, + ncclNumFuncs +} ncclFunc_t; + +char const ncclFuncNames[ncclNumFuncs][32] = +{ + "Broadcast", + "Reduce", + "AllGather", + "ReduceScatter", + "AllReduce", + "Gather", + "Scatter", + "AllToAll", + "AllToAllv", + "Send", + "Recv" +}; + +struct TaskInfo +{ + ncclFunc_t funcType; + bool inPlace; + size_t count; + ncclDataType_t datatype; + ncclRedOp_t op; + int root; +}; + +struct RankData +{ + int lineNum; + std::string comm; + std::vector tasks; +}; + +struct GroupCall +{ + bool isValid; + int opCount; + std::map rankData; // Indexed by globalRank +}; + +size_t DataTypeToBytes(ncclDataType_t const dataType) +{ + switch (dataType) { + case ncclInt8: return 1; + case ncclUint8: return 1; + case ncclInt32: return 4; + case ncclUint32: return 4; + case ncclInt64: return 8; + case ncclUint64: return 8; + case ncclFloat16: return 2; + case ncclFloat32: return 4; + case ncclFloat64: return 8; + case ncclBfloat16: return 2; + default: + printf("Unsupported datatype (%d)\n", dataType); + exit(0); + } +} + +ncclFunc_t GetFuncType(char* func) +{ + for (int i = 0; i < ncclNumFuncs; i++) + if (!strcmp(func, ncclFuncNames[i])) return (ncclFunc_t)i; + printf("[ERROR] Unrecognzied func %s\n", func); + exit(1); +} + +// parse the logs and assign them into lineItem +bool ParseLineItem(char const* line, LineItem& li); + +// this covers grouping the logs based on opCount and task number, +// validatation of the groupCalls for both non-send/recv collectives and send/recv +void ParseCollectives(char const* logFilename, + int const numGlobalRanks, + std::vector& groupCalls); + +// size differ for each collective call and getSize gives a specific size in bytes depending on type of task, +// global rank, element count and data type +std::pair GetSize(TaskInfo taskInfo, + int numGlobalRanks); + +// executes the collective call (task) +void ExecuteCollective(TaskInfo task, ncclComm_t comm, hipStream_t stream, const void *sendbuff, void *recvbuff); + +// allocates send/recv buff, sets the device based on which rank the task belongs to, +// syncronize devices after executing all the tasks and free device memory. +void ReplayRccl(GroupCall& groupCall, std::vector comms, std::vector streams, + int const localGpuOffset, + int const numGpusPerMpiRank, + int const firstGlobalRank, + int const numGlobalRanks); \ No newline at end of file