Files
rocm-systems/src/include/ce_coll.h
T

77 rader
2.1 KiB
C
Normal vy Historik

2025-09-02 13:21:14 -07:00
/*************************************************************************
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef NCCL_CE_COLL_H_
#define NCCL_CE_COLL_H_
#include "nccl.h"
#include "nccl_common.h"
#include "bitops.h"
// Number of memory operations per rank for each CE synchronization protocol.
// The protocol variant is identified by the macro suffix:
//   _MC - NOTE(review): presumably the multicast-based protocol; confirm.
//   _UC - NOTE(review): presumably the unicast protocol (cf. the baseUCSym*
//         fields of struct ncclCeColl); confirm.
// These are plain integer-literal macros, so they are also usable in #if.
#define NCCL_CE_SYNC_OPS_PER_RANK_MC 2
#define NCCL_CE_SYNC_OPS_PER_RANK_UC 3
// State for CE collectives.
// NOTE(review): "CE" is read here as "copy engine", and this is presumably one
// instance per communicator; confirm against the implementation.
struct ncclCeColl {
// Base byte pointers of the "ready" and "complete" synchronization regions
// used by the UC protocol. NOTE(review): "Sym" presumably means these live
// in symmetric memory; confirm.
uint8_t* baseUCSymReadyPtr;
uint8_t* baseUCSymComplPtr;
// Offsets paired with the two base pointers above.
// NOTE(review): presumably byte offsets within ceSyncWin; confirm.
size_t baseUCSymReadyOffset;
size_t baseUCSymComplOffset;
// Sequence number for CE operations. NOTE(review): presumably advanced per
// operation; as a uint32_t it wraps modulo 2^32 if so.
uint32_t ceSeqNum;
// Whether the "complete" pointer (baseUCSymComplPtr) is used, per the field
// name; the exact protocol effect lives in the implementation.
bool useCompletePtr;
// Intra-batch synchronization tuning knobs. NOTE(review): presumably "sync
// every intraBatchSyncFreq ops" and "only for messages at or above
// intraBatchSyncMsgThreshold (bytes?)"; confirm units and semantics.
uint32_t intraBatchSyncFreq;
uint64_t intraBatchSyncMsgThreshold;
// Window object presumably backing the synchronization regions; confirm.
struct ncclDevrWindow* ceSyncWin;
};
// Node of a singly linked list of CE initialization tasks; each node refers
// to one communicator.
// NOTE(review): this header does not show who builds or drains the list. It
// is presumably processed during ncclCeInit(); confirm.
struct ncclCeInitTask {
struct ncclCeInitTask *next;  // next node; presumably NULL terminates the list
struct ncclComm* comm;        // communicator associated with this task
};
// Arguments describing one CE collective call. A pointer to this struct is
// the `args` parameter of the ncclCeAllGather / ncclCeScatter /
// ncclCeGather / ncclCeAlltoAll entry points.
// The type is given 16-byte alignment via alignas(16).
// NOTE(review): the reason for that alignment is not visible in this header.
struct alignas(16) ncclCeCollArgs {
ncclFunc_t func;                  // which collective these arguments describe
int rootRank;                     // root rank; presumably only meaningful for rooted ops (Scatter/Gather)
size_t nElts;                     // element count (per-rank vs. total is not visible here; confirm)
size_t eltSize;                   // size of one element, presumably in bytes
uint8_t* sendBuff;                // send buffer, held as a byte pointer
uint8_t* recvBuff;                // receive buffer, held as a byte pointer
struct ncclDevrWindow* sendWin;   // window associated with sendBuff (presumably NULL if none; confirm)
struct ncclDevrWindow* recvWin;   // window associated with recvBuff (presumably NULL if none; confirm)
};
// Parameters for a batch of memory copy operations.
// NOTE(review): dsts/srcs/sizes look like parallel arrays of numOps entries
// each, where op i copies sizes[i] bytes from srcs[i] to dsts[i]; confirm.
struct ncclCeBatchOpsParams {
void** dsts;
void** srcs;
size_t* sizes;
size_t numOps;
bool intraBatchSync;  // whether synchronization is inserted within the batch (presumably driven by ncclCeColl's intraBatchSync* knobs)
// Copy-attribute fields. They exist only when compiled against CUDA runtime
// 12.8 or newer (CUDART_VERSION >= 12080).
// NOTE(review): they presumably feed the attrs / attrsIdxs / numAttrs
// arguments of cudaMemcpyBatchAsync; confirm.
// Caution: if CUDART_VERSION is not defined where this header is included,
// the #if evaluates it as 0 and these fields silently disappear. That changes
// the struct layout, so every translation unit must see the same CUDA headers.
#if CUDART_VERSION >= 12080
cudaMemcpyAttributes* attrs;
size_t* attrIdxs;
size_t numAttrs;
#endif
};
// ---------------------------------------------------------------------------
// CE collective API. Everything below is implemented elsewhere; the
// descriptions are inferred from names and signatures only.
// ---------------------------------------------------------------------------

// Capability query: a bool answer for a (collective, reduction op, datatype)
// triple. `red` is declared as int but, per its annotation, carries an
// ncclDevRedOp_t value.
bool ncclCeImplemented(ncclFunc_t coll,
                       int /* ncclDevRedOp_t */ red,
                       ncclDataType_t ty);

// Per-communicator CE setup and teardown (by name).
ncclResult_t ncclCeInit(struct ncclComm* comm);
ncclResult_t ncclCeFinalize(struct ncclComm* comm);

// Launch path for a kernel plan's CE work, plus a memory-operation sync tied
// to a stream. NOTE(review): exact semantics are not visible in this header.
ncclResult_t ncclLaunchCeColl(struct ncclComm* comm, struct ncclKernelPlan* plan);
ncclResult_t ncclMemOpSync(struct ncclComm* comm, cudaStream_t stream);

// Per-collective entry points. Each takes the communicator, a filled-in
// ncclCeCollArgs, and a cudaStream_t.
ncclResult_t ncclCeAllGather(struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
ncclResult_t ncclCeScatter  (struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
ncclResult_t ncclCeGather   (struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
ncclResult_t ncclCeAlltoAll (struct ncclComm* comm, struct ncclCeCollArgs* args, cudaStream_t stream);
#endif /* NCCL_CE_COLL_H_ */