f3c5156bbf
Co-authored-by: Islam <nusislam@amd.com>
476 строки
19 KiB
C++
476 строки
19 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2015-2023, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
|
|
#include "argcheck.h" // Need some checks here since we access comm
|
|
#include "collectives.h"
|
|
#include "enqueue.h"
|
|
#include "graph/topo.h"
|
|
#include "nccl.h"
|
|
#include "api_trace.h"
|
|
#include "nvtx_payload_schemas.h"
|
|
#include "msccl/msccl_lifecycle.h"
|
|
|
|
#ifdef ENABLE_ROCSHMEM
|
|
#include <rocshmem/rocshmem.hpp>
|
|
#endif
|
|
|
|
using namespace rccl;
|
|
|
|
// Returns a static, human-readable name for a collective function enum value.
// Any value without an explicit entry below maps to "Invalid".
const char* ncclFuncToString(ncclFunc_t fn) {
  const char* name = "Invalid";
  switch (fn) {
    case ncclFuncAllGather:     name = "AllGather";     break;
    case ncclFuncAllReduce:     name = "AllReduce";     break;
    case ncclFuncAlltoAll:      name = "AlltoAll";      break;
    case ncclFuncBroadcast:     name = "Broadcast";     break;
    case ncclFuncGather:        name = "Gather";        break;
    case ncclFuncRecv:          name = "Recv";          break;
    case ncclFuncReduce:        name = "Reduce";        break;
    case ncclFuncReduceScatter: name = "ReduceScatter"; break;
    case ncclFuncScatter:       name = "Scatter";       break;
    case ncclFuncSendRecv:      name = "SendRecv";      break;
    case ncclFuncSend:          name = "Send";          break;
    default:                                            break;
  }
  return name;
}
|
|
|
|
// Returns a static, human-readable name for a device reduction operator.
// Operators without an explicit entry below map to "Unknown".
const char* ncclDevRedOpToString(ncclDevRedOp_t op) {
  const char* name = "Unknown";
  switch (op) {
    case ncclDevSum:         name = "Sum";        break;
    case ncclDevProd:        name = "Prod";       break;
    case ncclDevMinMax:      name = "MinMax";     break;
    case ncclDevPreMulSum:   name = "PreMulSum";  break;
    case ncclDevSumPostDiv:  name = "SumPostDiv"; break;
    default:                                      break;
  }
  return name;
}
|
|
|
|
// Returns the enumerator spelling (e.g. "ncclFloat32") for a datatype.
// Datatypes without an explicit entry below map to "Unknown".
const char* ncclDatatypeToString(ncclDataType_t type) {
  const char* name = "Unknown";
  switch (type) {
    case ncclInt8:       name = "ncclInt8";       break;
    case ncclInt32:      name = "ncclInt32";      break;
    case ncclUint32:     name = "ncclUint32";     break;
    case ncclInt64:      name = "ncclInt64";      break;
    case ncclUint64:     name = "ncclUint64";     break;
    case ncclFloat16:    name = "ncclFloat16";    break;
    case ncclFloat32:    name = "ncclFloat32";    break;
    case ncclFloat64:    name = "ncclFloat64";    break;
    case ncclBfloat16:   name = "ncclBfloat16";   break;
    case ncclFloat8e4m3: name = "ncclFloat8e4m3"; break;
    case ncclFloat8e5m2: name = "ncclFloat8e5m2"; break;
    default:                                      break;
  }
  return name;
}
|
|
|
|
const char* ncclAlgoToString(int algo) {
|
|
switch (algo) {
|
|
case NCCL_ALGO_TREE: return "TREE";
|
|
case NCCL_ALGO_RING: return "RING";
|
|
case NCCL_ALGO_COLLNET_DIRECT: return "COLLNET_DIRECT";
|
|
case NCCL_ALGO_COLLNET_CHAIN: return "COLLNET_CHAIN";
|
|
case NCCL_ALGO_NVLS: return "NVLS";
|
|
case NCCL_ALGO_NVLS_TREE: return "NVLS_TREE";
|
|
case NCCL_ALGO_PAT: return "PAT";
|
|
default: return "Unknown";
|
|
}
|
|
}
|
|
|
|
const char* ncclProtoToString(int proto) {
|
|
switch (proto) {
|
|
case NCCL_PROTO_LL: return "LL";
|
|
case NCCL_PROTO_LL128: return "LL128";
|
|
case NCCL_PROTO_SIMPLE: return "SIMPLE";
|
|
default: return "Unknown";
|
|
}
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
|
|
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
|
|
ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sendcount,
|
|
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(AllGather, NcclNvtxParamsAllGather,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcount * ncclTypeSize(datatype), datatype));
|
|
|
|
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
|
|
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
|
|
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS, nullptr };
|
|
|
|
int nRanks, rank;
|
|
int in_place = 0;
|
|
const void* srcBuf;
|
|
void* dstBuf;
|
|
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
|
NCCLCHECK(ncclCommUserRank(comm, &rank));
|
|
size_t msgSize = sendcount * ncclTypeSize(datatype) * nRanks;
|
|
|
|
if (!mscclIsCaller())
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrAllGather, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
|
|
}
|
|
|
|
if (rcclUseAllGatherDirect(comm, msgSize)) {
|
|
INFO(NCCL_INIT, "RCCL DIRECT ALLGATHER count = %zu, msgSize = %zu, comm = %p, stream = %p, rank = %d, sendbuff = %p, recvbuff = %p",
|
|
sendcount, msgSize, comm, stream, rank, sendbuff, recvbuff);
|
|
// use direct allgather
|
|
if (sendcount == 0) return ncclSuccess;
|
|
size_t rankOffset = sendcount * ncclTypeSize(datatype);
|
|
if (sendbuff == (((char*)recvbuff) + rank * rankOffset)) {
|
|
srcBuf = ((char*)recvbuff) + rank * rankOffset;
|
|
dstBuf = recvbuff;
|
|
in_place = 1;
|
|
} else {
|
|
srcBuf = sendbuff;
|
|
dstBuf = recvbuff;
|
|
}
|
|
|
|
NCCLCHECK(ncclGroupStart());
|
|
|
|
for (int r = 0; r < nRanks; r++) {
|
|
if (r == rank && in_place)
|
|
continue;
|
|
|
|
NCCLCHECK(ncclSend(((char*)srcBuf), sendcount, datatype, r, comm, stream));
|
|
NCCLCHECK(ncclRecv(((char*)dstBuf) + r * rankOffset, sendcount, datatype, r, comm, stream));
|
|
}
|
|
NCCLCHECK(ncclGroupEnd());
|
|
return ncclSuccess;
|
|
} else {
|
|
// use ring allgather
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
}
|
|
|
|
RCCL_PARAM(AlltoAllPivotEnable, "ALL_TO_ALL_PIVOT_ENABLE", 0);
|
|
|
|
NCCL_API(ncclResult_t, ncclAlltoAll, const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, ncclComm* comm, cudaStream_t stream);
|
|
ncclResult_t ncclAlltoAll_impl(const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, ncclComm* comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(AlltoAll, NcclNvtxParamsAlltoAll,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), datatype));
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrAllToAll, sendbuff, recvbuff, count, datatype, comm, stream));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
|
|
}
|
|
|
|
size_t rankOffset = count * ncclTypeSize(datatype);
|
|
size_t rankAlign = rankOffset & ((~rankOffset) + 1);
|
|
size_t msgSize = count * ncclTypeSize(datatype) * comm->nRanks;
|
|
|
|
struct ncclInfo info;
|
|
if (comm->topo->pivotA2AEnabled && comm->nChannels >= comm->topo->pivotA2ANumBiRings * 2 &&
|
|
rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAlltoAllPivotEnable()) {
|
|
info = { ncclFuncAlltoAllPivot, "AlltoAllPivot",
|
|
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */
|
|
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
|
|
} else {
|
|
#ifdef ENABLE_ROCSHMEM
|
|
if (rcclUseAllToAllGda(comm) && msgSize <= comm->rocshmemThreshold) {
|
|
struct ncclInfo info = { ncclFuncAllToAllGda, "AllToAllGda",
|
|
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream,
|
|
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
#endif ENABLE_ROCSHMEM
|
|
info = { ncclFuncAlltoAll, "AlltoAll",
|
|
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */
|
|
ALLTOALL_CHUNKSTEPS, ALLTOALL_SLICESTEPS };
|
|
}
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclAlltoAllv, const void *sendbuff, const size_t sendcounts[], const size_t sdispls[],
|
|
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
|
|
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream);
|
|
ncclResult_t ncclAlltoAllv_impl(const void *sendbuff, const size_t sendcounts[], const size_t sdispls[],
|
|
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
|
|
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(AlltoAllv, NcclNvtxParamsAlltoAllv,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
|
|
recvcounts[comm->rank] * ncclTypeSize(datatype), datatype));
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrAllToAllv, sendbuff, recvbuff, 0, datatype, comm, stream, -1, sendcounts, sdispls, recvcounts, rdispls));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
|
|
0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream);
|
|
}
|
|
|
|
int nRanks;
|
|
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
|
if (!mscclIsCaller()) Recorder::instance().skip(true);
|
|
NCCLCHECK(ncclGroupStart());
|
|
for (int r=0; r<nRanks; r++) {
|
|
NCCLCHECK(ncclSend(
|
|
((char*)sendbuff) + sdispls[r]*ncclTypeSize(datatype),
|
|
sendcounts[r],
|
|
datatype,
|
|
r,
|
|
comm,
|
|
stream));
|
|
NCCLCHECK(ncclRecv(
|
|
((char*)recvbuff) + rdispls[r]*ncclTypeSize(datatype),
|
|
recvcounts[r],
|
|
datatype,
|
|
r,
|
|
comm,
|
|
stream));
|
|
}
|
|
NCCLCHECK(ncclGroupEnd());
|
|
if (!mscclIsCaller()) Recorder::instance().skip(false);
|
|
return ncclSuccess;
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
|
|
ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype));
|
|
|
|
// RCCL update slice steps for AllReduce if single node
|
|
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
|
|
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
|
|
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, nullptr };
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrAllReduce, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
//MSCCL not supported for FP8 datatype
|
|
if (datatype != ncclFloat8e4m3 && datatype != ncclFloat8e5m2) {
|
|
// MSCCL threshold for Bfloat16 = 8MB
|
|
if (datatype != ncclBfloat16 || (count * ncclTypeSize(datatype) <= 8388608)) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
|
|
}
|
|
}
|
|
}
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
/*
 * AllReduce variant that forwards an extra input buffer `acc` in ncclInfo
 * (presumably the bias/accumulator operand — confirm against the kernel side).
 * Unlike ncclAllReduce, this never routes to MSCCL and always records the call.
 *
 * Returns ncclInvalidArgument when acc is NULL.
 */
ncclResult_t ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count,
    ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc) {
  NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce,
    NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype));

  if (acc == nullptr) {
    WARN("ncclAllReduceWithBias : acc cannot be nullptr");
    return ncclInvalidArgument;
  }

  // RCCL update slice steps for AllReduce if single node.
  // comm is not validated until ncclEnqueueCheck, so guard the dereference
  // to keep this initializer from crashing on a NULL communicator.
  struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
    sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
    ALLREDUCE_CHUNKSTEPS,
    (comm != nullptr && comm->rcclUseOneSlice) ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS,
    acc };

  NCCLCHECK(Recorder::instance().record(rrAllReduceWithBias, info));

  return ncclEnqueueCheck(&info);
}
|
|
|
|
NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
|
|
ncclComm_t comm, cudaStream_t stream);
|
|
ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
|
|
ncclComm_t comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(Broadcast, NcclNvtxParamsBroadcast,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), root, datatype));
|
|
|
|
struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
|
|
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
|
|
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS, nullptr };
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrBroadcast, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
|
|
}
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
/* Deprecated original "in place" function, similar to MPI */
|
|
NCCL_API(ncclResult_t, ncclBcast, void* buff, size_t count, ncclDataType_t datatype, int root,
|
|
ncclComm_t comm, cudaStream_t stream);
|
|
ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root,
|
|
ncclComm_t comm, cudaStream_t stream) {
|
|
NCCLCHECK(Recorder::instance().record(rrBcast, buff, buff, count, datatype, comm, stream, root));
|
|
return ncclBroadcast(buff, buff, count, datatype, root, comm, stream);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclGather, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
|
|
ncclComm* comm, cudaStream_t stream);
|
|
ncclResult_t ncclGather_impl(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
|
|
ncclComm* comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(Gather, NcclNvtxParamsGather,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), root));
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrGather, sendbuff, recvbuff, count, datatype, comm, stream, root));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream);
|
|
}
|
|
|
|
struct ncclInfo info = { ncclFuncGather, "Gather",
|
|
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
|
|
GATHER_CHUNKSTEPS, GATHER_SLICESTEPS };
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
|
|
ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(Reduce, NcclNvtxParamsReduce,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), root, op, datatype));
|
|
|
|
struct ncclInfo info = { ncclFuncReduce, "Reduce",
|
|
sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
|
|
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS, nullptr };
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrReduce, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
|
|
}
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount,
|
|
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
|
|
ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t recvcount,
|
|
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(ReduceScatter, NcclNvtxParamsReduceScatter,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, recvcount * ncclTypeSize(datatype), op, datatype));
|
|
|
|
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
|
|
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
|
|
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS, nullptr };
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrReduceScatter, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
|
|
}
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclScatter, const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, int root, ncclComm* comm, cudaStream_t stream);
|
|
ncclResult_t ncclScatter_impl(const void* sendbuff, void* recvbuff, size_t count,
|
|
ncclDataType_t datatype, int root, ncclComm* comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(Scatter, NcclNvtxParamsScatter,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), root, datatype));
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrScatter, sendbuff, recvbuff, count, datatype, comm, stream, root));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream);
|
|
}
|
|
|
|
struct ncclInfo info = { ncclFuncScatter, "Scatter",
|
|
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
|
|
SCATTER_CHUNKSTEPS, SCATTER_SLICESTEPS };
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclSend, const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
|
|
ncclComm_t comm, cudaStream_t stream);
|
|
ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
|
|
ncclComm_t comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(Send, NcclNvtxParamsSendRecv,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), peer, datatype));
|
|
|
|
struct ncclInfo info = { ncclFuncSend, "Send",
|
|
NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
|
|
1, 1, nullptr };
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrSend, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
|
|
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
|
|
}
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|
|
|
|
NCCL_API(ncclResult_t, ncclRecv, void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
|
ncclComm_t comm, cudaStream_t stream);
|
|
ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
|
ncclComm_t comm, cudaStream_t stream) {
|
|
NVTX3_FUNC_WITH_PARAMS(Recv, NcclNvtxParamsSendRecv,
|
|
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), peer, datatype));
|
|
|
|
struct ncclInfo info = { ncclFuncRecv, "Recv",
|
|
NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
|
|
1, 1, nullptr };
|
|
|
|
if (!mscclIsCaller()) // when msccl falls back to
|
|
{
|
|
NCCLCHECK(Recorder::instance().record(rrRecv, info));
|
|
}
|
|
|
|
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
|
return mscclEnqueueCheck(
|
|
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
|
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
|
|
}
|
|
|
|
return ncclEnqueueCheck(&info);
|
|
}
|