/************************************************************************* * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ #ifndef NCCL_CE_COLL_H_ #define NCCL_CE_COLL_H_ #include "nccl.h" #include "nccl_common.h" #include "bitops.h" // Memory operations per rank for different synchronization protocols #define NCCL_CE_SYNC_OPS_PER_RANK_MC 2 #define NCCL_CE_SYNC_OPS_PER_RANK_UC 3 struct ncclCeColl { uint8_t* baseUCSymReadyPtr; uint8_t* baseUCSymComplPtr; size_t baseUCSymReadyOffset; size_t baseUCSymComplOffset; uint32_t ceSeqNum; bool useCompletePtr; uint32_t intraBatchSyncFreq; uint64_t intraBatchSyncMsgThreshold; struct ncclDevrWindow* ceSyncWin; }; struct ncclCeInitTask { struct ncclCeInitTask *next; struct ncclComm* comm; }; struct alignas(16) ncclCeCollArgs { ncclFunc_t func; int rootRank; size_t nElts; size_t eltSize; uint8_t* sendBuff; uint8_t* recvBuff; struct ncclDevrWindow* sendWin; struct ncclDevrWindow* recvWin; }; struct ncclCeBatchOpsParams { void** dsts; void** srcs; size_t* sizes; size_t numOps; bool intraBatchSync; #if CUDART_VERSION >= 12080 cudaMemcpyAttributes* attrs; size_t* attrIdxs; size_t numAttrs; #endif }; bool ncclCeImplemented(ncclFunc_t coll, int/*ncclDevRedOp_t*/ red, ncclDataType_t ty); ncclResult_t ncclCeInit(struct ncclComm* comm); ncclResult_t ncclCeFinalize(struct ncclComm* comm); ncclResult_t ncclMemOpSync(struct ncclComm* comm, cudaStream_t stream); ncclResult_t ncclLaunchCeColl(struct ncclComm* comm, struct ncclKernelPlan* plan); ncclResult_t ncclCeAllGather(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream); ncclResult_t ncclCeScatter(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream); ncclResult_t ncclCeGather(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream); ncclResult_t ncclCeAlltoAll(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream); #endif /* NCCL_CE_COLL_H_ */