Files
rocm-systems/src/msccl.cc
T
Bertan Dogancay b617aecc31 Implement ROCTX (#1094)
* Implement roctx
2024-02-27 15:46:15 -07:00

96 lines
3.6 KiB
C++

/*************************************************************************
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
************************************************************************/
#include "enqueue.h"
#include "msccl/msccl_parser.h"
#include "msccl/msccl_setup.h"
#include "msccl/msccl_status.h"
#include <cstdio>
#include <cstdlib>
// Builds the host and device copies of an algorithm.
//
// Allocates a zeroed host-side mscclAlgo, fills it via mscclGetAlgoFromXmlFile,
// allocates a device-side mscclAlgo and copies the host struct to it.
//
// On failure the out-pointers are left holding whatever had been allocated up
// to that point (or their incoming value if nothing was), so the caller is
// responsible for releasing them.
static ncclResult_t mscclBuildAlgoCopies(const char *algoFilePath, int rank,
                                         struct mscclAlgo **hostAlgoOut,
                                         struct mscclAlgo **devAlgoOut) {
  NCCLCHECK(ncclCalloc(hostAlgoOut, 1));
  NCCLCHECK(mscclGetAlgoFromXmlFile(algoFilePath, *hostAlgoOut, rank));
  NCCLCHECK(ncclCudaMalloc(devAlgoOut, 1));
  CUDACHECK(hipMemcpy(*devAlgoOut, *hostAlgoOut, sizeof(struct mscclAlgo), hipMemcpyHostToDevice));
  return ncclSuccess;
}

NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
// Loads an algorithm description from a file and registers it under a handle.
//
// Parameters:
//   mscclAlgoFilePath  path handed to mscclGetAlgoFromXmlFile.
//   mscclAlgoHandle    out: receives the handle of the loaded algorithm.
//                      Written only when the call succeeds.
//   rank               rank forwarded to mscclGetAlgoFromXmlFile.
//
// Returns ncclSuccess on success, ncclInvalidUsage when no free handle is
// left, or the error reported by the failing allocation/parse/copy step.
//
// On any failure the handle taken from the free list is returned to it and
// all memory allocated so far is released, so a failed load does not consume
// a handle slot or leak host/device memory.
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) {
  mscclStatus& status = mscclGetStatus();

  if (status.freeAlgoHandles.empty()) {
    WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS);
    return ncclInvalidUsage;
  }

  // Reserve the last free handle; it is handed back below if loading fails.
  const mscclAlgoHandle_t handle = status.freeAlgoHandles.back();
  status.freeAlgoHandles.pop_back();

  struct mscclAlgo *hostAlgo = nullptr;
  struct mscclAlgo *devAlgo = nullptr;
  const ncclResult_t buildResult = mscclBuildAlgoCopies(mscclAlgoFilePath, rank, &hostAlgo, &devAlgo);
  if (buildResult != ncclSuccess) {
    // Roll back. The original error is what the caller needs to see, so a
    // secondary failure while freeing the device buffer is deliberately ignored.
    if (devAlgo != nullptr) {
      (void)ncclCudaFree(devAlgo);
    }
    free(hostAlgo);
    status.freeAlgoHandles.push_back(handle);
    return buildResult;
  }

  // Publish only once both copies exist, so the maps never hold a
  // half-initialised algorithm.
  status.hostAlgos[handle] = hostAlgo;
  status.devAlgos[handle] = devAlgo;
  *mscclAlgoHandle = handle;
  return ncclSuccess;
}
NCCL_API(ncclResult_t, mscclRunAlgo,
    const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
    void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
    size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
    mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream);
// Runs a previously loaded algorithm on `stream`.
//
// Looks up the host and device copies registered for `mscclAlgoHandle`, then
// performs, in order: capture-status query, count setup, scratch setup,
// sync-flag setup, proxy setup and kernel setup. The first failing step's
// error is returned.
//
// sDisPls, rDisPls, root and peer are accepted for interface compatibility
// but are not read by this function.
//
// Returns ncclSuccess, ncclInvalidUsage if the handle does not refer to a
// loaded algorithm, or the error of the failing setup step.
ncclResult_t mscclRunAlgo(
    const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
    void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
    size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
    mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream) {
  struct NvtxParamsMsccl {
    size_t sendbytes;
    size_t recvbytes;
  };
  // The trace payload carries the size of a single send/recv message, not the
  // total number of bytes sent/received.
  constexpr nvtxPayloadSchemaEntry_t MscclSchema[] = {
    {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Send)"},
    {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Recv)"}
  };
  // NOTE(review): the per-rank count arrays look optional (the function also
  // takes a plain `count`); confirm against callers. Guarding here keeps the
  // tracing code from dereferencing a null array, falling back to `count`.
  const size_t mscclElemBytes = ncclTypeSize(dataType);
  const size_t mscclSendElems = (sendCounts != nullptr) ? sendCounts[comm->rank] : count;
  const size_t mscclRecvElems = (recvCounts != nullptr) ? recvCounts[comm->rank] : count;
  NvtxParamsMsccl payload{mscclSendElems * mscclElemBytes, mscclRecvElems * mscclElemBytes};
  NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload)

  mscclStatus& status = mscclGetStatus();

  // Use find() rather than operator[]: operator[] would silently insert a null
  // entry for an unknown handle and the setup calls would receive null
  // algorithm pointers.
  auto hostEntry = status.hostAlgos.find(mscclAlgoHandle);
  auto devEntry = status.devAlgos.find(mscclAlgoHandle);
  if (hostEntry == status.hostAlgos.end() || devEntry == status.devAlgos.end() ||
      hostEntry->second == nullptr || devEntry->second == nullptr) {
    WARN("MSCCL: algorithm handle does not refer to a loaded algorithm");
    return ncclInvalidUsage;
  }
  struct mscclAlgo* hostAlgo = hostEntry->second;
  struct mscclAlgo* devAlgo = devEntry->second;

  NCCLCHECK(mscclGetCaptureStatus(stream));
  NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));
  NCCLCHECK(mscclSetupScratch(hostAlgo, stream));
  NCCLCHECK(mscclSetupSyncFlags(stream));
  NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream));
  NCCLCHECK(mscclSetupKernel(sendBuff, recvBuff, count, dataType, op, hostAlgo, devAlgo, comm, stream));
  return ncclSuccess;
}
NCCL_API(ncclResult_t, mscclUnloadAlgo, mscclAlgoHandle_t mscclAlgoHandle);
// Unloads an algorithm and makes its handle available for reuse.
//
// Frees the host copy, removes the handle from the host/device maps and from
// every set in status.connectedAlgos, returns the handle to the free list,
// and finally frees the device copy.
//
// Returns ncclSuccess, ncclInvalidUsage if the handle is already in the free
// list (i.e. it is not currently loaded), or the error from ncclCudaFree.
// All bookkeeping is completed before the device free is attempted, so a
// failing free does not leave stale map entries or strand the handle.
ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle) {
  mscclStatus& status = mscclGetStatus();

  // A handle that is already free must not be pushed a second time: a
  // duplicate in the free list would let two later loads receive the same
  // handle and overwrite each other's map entries.
  for (const auto &freeHandle : status.freeAlgoHandles) {
    if (freeHandle == mscclAlgoHandle) {
      WARN("MSCCL: attempt to unload an algorithm handle that is not loaded");
      return ncclInvalidUsage;
    }
  }

  // find() avoids operator[]'s side effect of inserting an entry for a handle
  // that was never registered.
  auto hostEntry = status.hostAlgos.find(mscclAlgoHandle);
  if (hostEntry != status.hostAlgos.end()) {
    free(hostEntry->second);
    status.hostAlgos.erase(hostEntry);
  }

  struct mscclAlgo* devAlgo = nullptr;
  auto devEntry = status.devAlgos.find(mscclAlgoHandle);
  if (devEntry != status.devAlgos.end()) {
    devAlgo = devEntry->second;
    status.devAlgos.erase(devEntry);
  }

  status.freeAlgoHandles.push_back(mscclAlgoHandle);
  for (auto &connected : status.connectedAlgos) {
    connected.second.erase(mscclAlgoHandle);
  }

  // Device free goes last: it is the only step that can fail, and by now the
  // bookkeeping above is already consistent.
  if (devAlgo != nullptr) {
    NCCLCHECK(ncclCudaFree(devAlgo));
  }
  return ncclSuccess;
}