// Listing metadata from the source viewer: 77 lines, 2.1 KiB, C
/*************************************************************************
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
|
||
|
|
|
||
|
|
#ifndef NCCL_CE_COLL_H_
|
||
|
|
#define NCCL_CE_COLL_H_
|
||
|
|
|
||
|
|
#include "nccl.h"
|
||
|
|
#include "nccl_common.h"
|
||
|
|
#include "bitops.h"
|
||
|
|
|
||
|
|
// Memory operations per rank for different synchronization protocols.
// NOTE(review): the _MC / _UC suffixes presumably stand for multicast and
// unicast — confirm against the implementation that consumes these values.
#define NCCL_CE_SYNC_OPS_PER_RANK_MC 2
#define NCCL_CE_SYNC_OPS_PER_RANK_UC 3
|
||
|
|
|
||
|
|
// Per-communicator state for the CE collective path.
// NOTE(review): the per-field descriptions below are inferred from the field
// names only; this header does not show how they are used. Confirm against
// the implementation before relying on them.
struct ncclCeColl {
  uint8_t* baseUCSymReadyPtr;           // presumably base address of the "ready" sync area — confirm
  uint8_t* baseUCSymComplPtr;           // presumably base address of the "complete" sync area — confirm
  size_t baseUCSymReadyOffset;          // offset paired with baseUCSymReadyPtr (units not visible here)
  size_t baseUCSymComplOffset;          // offset paired with baseUCSymComplPtr (units not visible here)
  uint32_t ceSeqNum;                    // sequence counter; update policy is defined by the implementation
  bool useCompletePtr;                  // presumably selects the "complete" area over the "ready" area — confirm
  uint32_t intraBatchSyncFreq;          // intra-batch sync frequency (semantics defined by the implementation)
  uint64_t intraBatchSyncMsgThreshold;  // message threshold related to intra-batch sync — TODO confirm units
  struct ncclDevrWindow* ceSyncWin;     // window associated with the sync areas; ownership not visible here
};
|
||
|
|
|
||
|
|
// Node of a singly linked list of CE initialization tasks; each node refers
// to one communicator.
// NOTE(review): who owns the list and when nodes are consumed is not visible
// in this header.
struct ncclCeInitTask {
  struct ncclCeInitTask *next;  // next task in the list, or null at the tail
  struct ncclComm* comm;        // communicator this task refers to
};
|
||
|
|
|
||
|
|
struct alignas(16) ncclCeCollArgs {
|
||
|
|
ncclFunc_t func;
|
||
|
|
int rootRank;
|
||
|
|
size_t nElts;
|
||
|
|
size_t eltSize;
|
||
|
|
uint8_t* sendBuff;
|
||
|
|
uint8_t* recvBuff;
|
||
|
|
struct ncclDevrWindow* sendWin;
|
||
|
|
struct ncclDevrWindow* recvWin;
|
||
|
|
};
|
||
|
|
|
||
|
|
// Parameter bundle for a batch of memory operations: parallel arrays of
// destinations, sources and sizes, plus the number of operations.
// NOTE(review): array ownership and required lengths are not visible here;
// presumably each array holds numOps entries — confirm.
struct ncclCeBatchOpsParams {
  void** dsts;          // destination pointers, one per operation
  void** srcs;          // source pointers, one per operation
  size_t* sizes;        // sizes, one per operation (presumably bytes — confirm)
  size_t numOps;        // number of operations described by the arrays above
  bool intraBatchSync;  // whether intra-batch synchronization is requested
  // The attribute fields exist only when building against CUDA runtime 12.8
  // or newer, so the struct layout depends on CUDART_VERSION. Every
  // translation unit using this struct must see the same CUDART_VERSION.
#if CUDART_VERSION >= 12080
  cudaMemcpyAttributes* attrs;  // memcpy attribute records
  size_t* attrIdxs;             // indices associated with attrs (presumably the cudaMemcpyBatchAsync convention — confirm)
  size_t numAttrs;              // number of entries in attrs
#endif
};
|
||
|
|
|
||
|
|
bool ncclCeImplemented(ncclFunc_t coll, int/*ncclDevRedOp_t*/ red, ncclDataType_t ty);
|
||
|
|
|
||
|
|
ncclResult_t ncclCeInit(struct ncclComm* comm);
|
||
|
|
|
||
|
|
ncclResult_t ncclCeFinalize(struct ncclComm* comm);
|
||
|
|
|
||
|
|
ncclResult_t ncclMemOpSync(struct ncclComm* comm, cudaStream_t stream);
|
||
|
|
|
||
|
|
ncclResult_t ncclLaunchCeColl(struct ncclComm* comm, struct ncclKernelPlan* plan);
|
||
|
|
|
||
|
|
ncclResult_t ncclCeAllGather(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
|
||
|
|
|
||
|
|
ncclResult_t ncclCeScatter(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
|
||
|
|
|
||
|
|
ncclResult_t ncclCeGather(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
|
||
|
|
|
||
|
|
ncclResult_t ncclCeAlltoAll(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
|
||
|
|
#endif /* NCCL_CE_COLL_H_ */
|