2018-12-13 15:56:12 -08:00
/*************************************************************************
2022-01-07 06:39:55 -08:00
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
2023-02-04 01:43:38 +00:00
* Modifications Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
2024-03-09 07:17:53 +08:00
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
2018-12-13 15:56:12 -08:00
*
* See LICENSE.txt for license information
************************************************************************/
#include "enqueue.h"
2019-11-19 14:57:39 -08:00
#include "argcheck.h"
2020-01-16 16:02:42 -08:00
#include "coll_net.h"
2020-12-04 18:52:32 -05:00
#include "graph/topo.h"
2021-01-28 09:45:01 -07:00
#include <hip/hip_runtime.h>
#include <hip/hip_ext.h>
2021-04-12 16:00:11 -07:00
#include "gdrwrap.h"
2021-09-08 13:56:25 -07:00
#include "bootstrap.h"
2022-09-09 01:20:52 +00:00
#include <cstring>
2022-05-03 01:30:26 -07:00
#include "channel.h"
2022-09-09 01:20:52 +00:00
#include "rocmwrap.h"
2022-09-06 10:29:46 -06:00
#include "rccl_vars.h"
2024-09-10 05:57:10 -07:00
#include "profiler.h"
2023-09-26 05:47:28 -07:00
#include "transport.h"
2024-01-30 09:24:22 -07:00
#include "common.h"
2024-08-22 12:36:07 -05:00
#include "api_trace.h"
2021-09-23 09:52:42 -07:00
#include <cstring> // std::memcpy
2022-05-24 02:02:31 -07:00
#include <cinttypes> // PRIx64
2024-12-18 08:26:06 -08:00
#include <cassert>
2022-05-24 02:02:31 -07:00
2025-04-19 00:21:27 -04:00
using namespace rccl ;
2022-08-18 02:53:17 -07:00
// One entry of the kernel table: a device kernel entry point plus a flag.
struct ncclKernelMatch {
  void* kernelFn;    // Address of the device kernel function.
  bool specialized;  // NOTE(review): every table entry in this file sets this to true; exact meaning not shown here.
};
2018-12-13 15:56:12 -08:00
2023-06-21 16:16:09 -04:00
#ifdef ENABLE_COLLTRACE
2025-04-30 23:33:08 -05:00
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 2 : 0))
static ncclKernelMatch const ncclKerns [ 4 ] = {
{( void * ) ncclDevKernel_Generic , true },
2024-11-12 18:27:29 -05:00
{( void * ) ncclDevKernel_Generic_4 , true },
2025-04-30 23:33:08 -05:00
{( void * ) ncclDevKernelDebug_Generic , true },
2024-11-12 18:27:29 -05:00
{( void * ) ncclDevKernelDebug_Generic_4 , true }
2018-12-13 15:56:12 -08:00
};
2023-06-21 16:16:09 -04:00
#else
2024-11-12 18:27:29 -05:00
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll)
2025-04-30 23:33:08 -05:00
static ncclKernelMatch const ncclKerns [ 2 ] = {
{( void * ) ncclDevKernel_Generic , true },
2024-11-12 18:27:29 -05:00
{( void * ) ncclDevKernel_Generic_4 , true }
2023-06-21 16:16:09 -04:00
};
#endif
2018-12-13 15:56:12 -08:00
2023-02-27 02:48:21 -08:00
// Declares the parameter read through ncclParamL1SharedMemoryCarveout() in
// ncclInitKernelsForDevice(); a value of 0 makes that function skip setting
// the preferred shared-memory carveout. (Presumably the last argument is the
// default value -- confirm against the NCCL_PARAM definition.)
NCCL_PARAM ( L1SharedMemoryCarveout , "L1_SHARED_MEMORY_CARVEOUT" , 0 );
2021-04-12 16:00:11 -07:00
2023-02-27 02:48:21 -08:00
// Configures function attributes on each distinct kernel listed in ncclKerns
// and optionally reports the largest per-kernel local (stack) size.
//
//   cudaArch     - forwarded to ncclShmemDynamicSize() to obtain the dynamic
//                  shared memory size to request for each kernel.
//   maxStackSize - optional output. When non-null it is reset to 0 and then
//                  raised to the maximum attr.localSizeBytes observed.
//
// Returns `result`, which starts as ncclSuccess and is handed to CUDACHECKGOTO
// for the attribute-setting calls (presumably overwritten there on failure).
ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) {
  constexpr int KernelCount = sizeof(ncclKerns) / sizeof(ncclKerns[0]);
  ncclResult_t result = ncclSuccess;

  if (maxStackSize) *maxStackSize = 0;
  int carveout = ncclParamL1SharedMemoryCarveout();

  // Remember the two most recently handled function pointers so a kernel that
  // appears again in the table is not configured twice.
  void* recent[2] = {nullptr, nullptr};
  for (int k = 0; k < KernelCount; k++) {
    void* fn = ncclKerns[k].kernelFn;
    if (fn == recent[0] || fn == recent[1]) continue;
    recent[1] = recent[0];
    recent[0] = fn;

    if (maxStackSize) {
      cudaFuncAttributes attr = {0};
      // A failed query only warns; attr stays zeroed so the maximum is unaffected.
      if (cudaFuncGetAttributes(&attr, fn) != cudaSuccess)
        WARN("Failed to get kernel attributes");
      if (attr.localSizeBytes > *maxStackSize) *maxStackSize = attr.localSizeBytes;
    }

    if (carveout) {
      // On failure control simply continues with the next attribute.
      CUDACHECKGOTO(cudaFuncSetAttribute(fn,
        cudaFuncAttributePreferredSharedMemoryCarveout, carveout),
        result, ignore1);
    ignore1:;
    }

    if (ncclShmemDynamicSize(cudaArch) != 0) {
      CUDACHECKGOTO(cudaFuncSetAttribute(fn,
        cudaFuncAttributeMaxDynamicSharedMemorySize, ncclShmemDynamicSize(cudaArch)),
        result, next_kernel);
    }
  next_kernel:;
  }
  return result;
}
2024-06-11 01:28:01 -07:00
////////////////////////////////////////////////////////////////////////////////
// Data movement metrics.
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
// Traffic multiplier per byte of user data for a collective:
// 2 for all-reduce, nRanks for all-gather and reduce-scatter, 1 otherwise.
static inline int ncclFuncTrafficPerByte(ncclFunc_t func, int nRanks) {
  if (func == ncclFuncAllReduce) return 2;
  if (func == ncclFuncAllGather || func == ncclFuncReduceScatter) return nRanks;
  return 1;
}
2024-06-11 01:28:01 -07:00
/*****************************************************************************/
/* Launch system : synchronization and CUDA kernel launch */
/*****************************************************************************/
2022-05-24 02:02:31 -07:00
// Asks ncclProxySaveOp() whether `op` is needed and, if so, queues a copy of it
// on the work-in-progress plan's per-channel proxy-op queue.
//
//   comm - communicator owning the planner and the proxy-op memory pool.
//   plan - not used by this function.
//   op   - proxy op to test; copied by value when queued, so the caller keeps
//          ownership of the original.
//
// Returns ncclSuccess, or whatever NCCLCHECK propagates from ncclProxySaveOp().
static ncclResult_t addProxyOpIfNeeded(struct ncclComm* comm, struct ncclKernelPlan* plan, struct ncclProxyOp* op) {
  bool needed = true;
  NCCLCHECK(ncclProxySaveOp(comm, op, &needed));
  if (!needed) return ncclSuccess;

  struct ncclProxyOp* copy = ncclMemoryPoolAlloc<struct ncclProxyOp>(&comm->memPool_ncclProxyOp, &comm->memPermanent);
  *copy = *op;  // C++ struct assignment
  ncclIntruQueueEnqueue(&comm->planner.wipPlan.channels[op->channelId].proxyOpQueue, copy);
  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
static void addWorkBatchToPlan (
struct ncclComm * comm , struct ncclKernelPlan * plan , int channelId ,
enum ncclDevWorkType workType , int devFuncId , uint32_t workOffset ,
int p2pRound = - 1
2024-02-05 05:06:02 -08:00
) {
2024-06-11 01:28:01 -07:00
ncclKernelPlanner :: WipPlan :: Channel * chan = & comm -> planner . wipPlan . channels [ channelId ];
size_t workSize = ncclDevWorkSize ( workType );
// Conditions causing us to create a new blank batch.
bool newBatch = ( chan -> workBatchQueue . tail == nullptr );
struct ncclDevWorkBatch * batch = nullptr ;
if ( ! newBatch ) {
batch = & chan -> workBatchQueue . tail -> batch ;
// All of the conditions that prevent us from appending to current batch.
newBatch |= batch -> workType != ( uint8_t ) workType ;
newBatch |= batch -> funcId != devFuncId ;
// The following ensure the device can handle a batch this large. They have to
// account for all extension batches being fused together which is why
// wipBatch.workBytes and wipBatch.nP2ps aren't reset to 0 for a new extension
// batch further down.
newBatch |= NCCL_MAX_DEV_WORK_BATCH_BYTES < chan -> wipBatch . workBytes + workSize ;
if ( workType == ncclDevWorkTypeP2p ) {
newBatch |= chan -> wipBatch . nP2ps == NCCL_MAX_DEV_WORK_P2P_PER_BATCH ;
for ( int i = 0 ; i < chan -> wipBatch . nP2ps ; i ++ ) {
newBatch |= p2pRound == chan -> wipBatch . p2pRounds [ i ];
2022-05-24 02:02:31 -07:00
}
2018-12-13 15:56:12 -08:00
}
}
2024-06-11 01:28:01 -07:00
// Conditions causing us to create an extension batch (prev->nextExtends=1)
uint32_t offset = newBatch ? 0 : ( workOffset - batch -> offsetBase );
bool extendBatch = 63 * workSize < offset ;
extendBatch |= 0 != offset % workSize ;
if ( newBatch || extendBatch ) {
if ( ! newBatch ) batch -> nextExtends = extendBatch ; // Extending the previous batch.
struct ncclWorkBatchList * batchNode = ncclMemoryStackAlloc < ncclWorkBatchList > ( & comm -> memScoped );
2024-09-10 05:57:10 -07:00
// Coverity thinks that ncclIntruQueueEnqueue will access chan->workBatchQueue->tail, which might
// be NULL. But that code is guarded by chan->workBatchQueue->head not being NULL, in which
// case tail won't be NULL either.
// coverity[var_deref_model:FALSE]
2024-06-11 01:28:01 -07:00
ncclIntruQueueEnqueue ( & chan -> workBatchQueue , batchNode );
batch = & batchNode -> batch ;
batch -> nextExtends = 0 ;
batch -> workType = ( uint32_t ) workType ;
batch -> funcId = devFuncId ;
batch -> offsetBase = workOffset ;
batch -> offsetBitset = 0 ;
offset = 0 ;
if ( newBatch ) {
// Since extension batches are fused together on the device, and these values
// account for constraints on the fused batch, we only reset the values on
// a new batch
chan -> wipBatch . workBytes = 0 ;
chan -> wipBatch . nP2ps = 0 ;
// We don't count extension batches since this is used to derive a proxyOpCount,
// and we wan't all ops which are fused together to have the same value.
chan -> nWorkBatchesP2p += ( workType == ncclDevWorkTypeP2p ? 1 : 0 );
2024-02-05 05:06:02 -08:00
}
2024-06-11 01:28:01 -07:00
plan -> nWorkBatches += 1 ;
2024-02-05 05:06:02 -08:00
}
2024-06-11 01:28:01 -07:00
batch -> offsetBitset |= 1ull << ( offset / workSize );
chan -> wipBatch . workBytes += workSize ;
if ( workType == ncclDevWorkTypeP2p ) {
// We need to ensure that a single batch doesn't have multiple p2p's
// of the same round since they would use the same connections.
chan -> wipBatch . p2pRounds [ chan -> wipBatch . nP2ps ++ ] = p2pRound ;
2024-02-05 05:06:02 -08:00
}
}
2024-06-11 01:28:01 -07:00
// Finalizes a plan once all work has been added to the planner's per-channel
// work-in-progress queues:
//   1. sizes and allocates the kernel argument block,
//   2. copies the per-channel batch lists into that block, interleaved so each
//      channel's first batch lands at index blockIdx.x,
//   3. merges the per-channel proxy-op queues into plan->proxyOpQueue ordered
//      by opCount.
static void finishPlan(struct ncclComm* comm, struct ncclKernelPlan* plan) {
  ncclKernelPlanner::WipPlan::Channel* wipChannels = comm->planner.wipPlan.channels;
  size_t workBytes = plan->workBytes;
  size_t batchBytes = plan->nWorkBatches*sizeof(struct ncclDevWorkBatch);

  plan->threadPerBlock = std::max(plan->threadPerBlock, NCCL_MIN_NTHREADS);

  // Store the work inside the kernel args whenever everything fits.
  if (sizeof(ncclDevKernelArgs) + batchBytes + workBytes <= comm->workArgsBytes) {
    plan->workStorageType = ncclDevWorkStorageTypeArgs;
  }
  plan->kernelArgsSize = sizeof(struct ncclDevKernelArgs) + batchBytes;
  if (plan->workStorageType == ncclDevWorkStorageTypeArgs) {
    plan->kernelArgsSize += workBytes;
  }
  plan->kernelArgsSize = alignUp(plan->kernelArgsSize, 16);
  plan->kernelArgs = (struct ncclDevKernelArgs*)ncclMemoryStackAlloc(&comm->memScoped, plan->kernelArgsSize, /*align=*/16);
  plan->kernelArgs->comm = comm->devComm;
  plan->kernelArgs->channelMask = plan->channelMask;
  plan->kernelArgs->workStorageType = plan->workStorageType;

  // Lay the batches out right after the args header. Channels are visited
  // round-robin in ascending order until all of their queues are drained, which
  // places the first batch of each channel at batchZero[blockIdx.x].
  struct channelMasks hasBatchMask = plan->channelMask;
  struct ncclDevWorkBatch* batchPrev[MAXCHANNELS] = {};  // {0...}
  struct ncclDevWorkBatch* batchZero = (struct ncclDevWorkBatch*)(plan->kernelArgs+1);
  int batchIx = 0;
  for (int maskIdx = 0; maskIdx < MAXCHANNELS/64; maskIdx++) {
    while (hasBatchMask.masks[maskIdx] != 0) {
      uint64_t roundMask = hasBatchMask.masks[maskIdx];  // channels with a batch for this round.
      do {
        int c = popFirstOneBit(&roundMask) + maskIdx*64;
        if (!ncclIntruQueueEmpty(&wipChannels[c].workBatchQueue)) {
          struct ncclWorkBatchList* node = ncclIntruQueueDequeue(&wipChannels[c].workBatchQueue);
          struct ncclDevWorkBatch* slot = &batchZero[batchIx];
          // Link the channel's previous batch to this one by relative index.
          if (batchPrev[c] != nullptr) {
            batchPrev[c]->nextJump = int(slot - batchPrev[c]);
          }
          batchPrev[c] = slot;
          *slot = node->batch;
          batchIx++;
        }
        // Drop the channel from further rounds once its queue is empty.
        if (ncclIntruQueueEmpty(&wipChannels[c].workBatchQueue)) {
          hasBatchMask.masks[maskIdx] ^= 1ull<<(c%64);
        }
      } while (roundMask != 0);
    }
  }

  // Merge the per-channel proxy-op lists into plan->proxyOpQueue by opCount.
  // Phase 1: cache the opCount of each channel's head op in headIds[c].
  uint64_t headIds[MAXCHANNELS];
  int nHeads = 0;
  int channelUbound = 0;
  for (int c = 0; c < MAXCHANNELS; c++) {
    struct ncclProxyOp* head = ncclIntruQueueHead(&wipChannels[c].proxyOpQueue);
    headIds[c] = head ? head->opCount : uint64_t(-1);
    if (head) {
      nHeads += 1;
      plan->hasProxyOps = true;
      channelUbound = c+1;
    }
  }
  // Phase 2: repeatedly move the op with the smallest key to the plan.
  while (nHeads != 0) {
    int c = -1;
    uint64_t minId = uint64_t(-1);
    // Scan the cached head ids rather than the ops themselves to keep indirect
    // loads out of this loop.
    for (int cand = 0; cand < channelUbound; cand++) {
      uint64_t key = headIds[cand];
      key = (key>>1 | key<<63);  // Rotate the tag bit to the top: collectives sort before p2p's
      if (key < minId) { c = cand; minId = key; }
    }
    struct ncclProxyOp* op = ncclIntruQueueDequeue(&wipChannels[c].proxyOpQueue);
    struct ncclProxyOp* opNext = ncclIntruQueueHead(&wipChannels[c].proxyOpQueue);
    headIds[c] = opNext ? opNext->opCount : uint64_t(-1);
    if (opNext == nullptr) nHeads -= 1;
    ncclIntruQueueEnqueue(&plan->proxyOpQueue, op);
  }
}
2023-09-26 05:47:28 -07:00
// Declares the GraphRegister parameter. It is not referenced in the portion of
// the file visible here. NOTE(review): presumably toggles buffer registration
// while capturing a graph, with 1 as the default -- confirm at the use site.
NCCL_PARAM ( GraphRegister , "GRAPH_REGISTER" , 1 );
2024-06-11 01:28:01 -07:00
static ncclResult_t getCollNetSupport ( struct ncclComm * comm , struct ncclTaskColl * task , int * collNetSupport );
2025-04-23 15:44:56 -04:00
rccl_static ncclResult_t getAlgoInfo (
2024-06-11 01:28:01 -07:00
struct ncclComm * comm , struct ncclTaskColl * task ,
int collNetSupport , int nvlsSupport , int numPipeOps , ncclSimInfo_t * simInfo = NULL
);
static ncclResult_t calcCollChunking (
struct ncclComm * comm , struct ncclTaskColl * task , int nChannels , size_t nBytes ,
/*outputs*/ uint32_t * outChunkSize , uint32_t * outDirectFlags , struct ncclProxyOp * proxyOp
);
// Byte budgets that bound how much work a single kernel plan may carry.
struct ncclKernelPlanBudget {
  ssize_t inArgsBytes;   // Space available within kernel args struct
  ssize_t outArgsBytes;  // Space available outside of args struct (fifo or persistent buf)
};
// Reports whether `nWorkBatches` batch headers plus `workBytes` of work structs
// fit within `budget`. Two layouts are acceptable:
//   - batches and work together inside the kernel args, or
//   - batches inside the kernel args and work in the out-of-args space.
static bool testBudget(
    struct ncclKernelPlanBudget* budget, int nWorkBatches, ssize_t workBytes
  ) {
  ssize_t batchBytes = nWorkBatches*sizeof(struct ncclDevWorkBatch);
  if (batchBytes + workBytes <= budget->inArgsBytes) return true;
  return (batchBytes <= budget->inArgsBytes) && (workBytes <= budget->outArgsBytes);
}
2022-05-24 02:02:31 -07:00
2024-12-18 08:26:06 -08:00
// Walks planner->collTaskQueue and appends exactly one work node per task to
// planner->collWorkQueue, in task order.
//
//  - NVLS / NVLS_TREE tasks reuse the node that ncclPrepareTasks() already
//    built and parked on planner->tmpCollWorkQueue.
//  - Every other task gets its buffers registered through
//    ncclRegisterCollBuffers() and a freshly built ncclDevWorkColl (or
//    ncclDevWorkCollReg when an NVLS-registered buffer is flagged).
//
// On exit tmpCollWorkQueue is asserted to be empty. Always returns ncclSuccess.
ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
  struct ncclKernelPlanner* planner = &comm->planner;
  struct ncclTaskColl* task = ncclIntruQueueHead(&planner->collTaskQueue);

  while (task != nullptr) {
    // Build a ncclDevWorkColl[Reg?] struct for each task.
    void* regBufSend[NCCL_MAX_LOCAL_RANKS];
    void* regBufRecv[NCCL_MAX_LOCAL_RANKS];
    bool regNeedConnect = true;
    struct ncclWorkList* workNode = NULL;
    struct ncclDevWorkColl devWork = {};

    if (task->algorithm == NCCL_ALGO_NVLS_TREE || task->algorithm == NCCL_ALGO_NVLS) {
      // Already prepared by ncclPrepareTasks(); just pick it up.
      workNode = ncclIntruQueueDequeue(&planner->tmpCollWorkQueue);
    } else {
      ncclRegisterCollBuffers(comm, task, regBufSend, regBufRecv, &planner->collCleanupQueue, &regNeedConnect);

      devWork.sendbuff = (void*)task->sendbuff;
      devWork.recvbuff = (void*)task->recvbuff;
      devWork.sendbuffOffset = task->sendbuffOffset;
      devWork.recvbuffOffset = task->recvbuffOffset;
      devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs;
      devWork.recvbuffRmtAddrs = task->recvbuffRmtAddrs;
      devWork.root = task->root;
      devWork.nWarps = task->nWarps;
      devWork.redOpArg = task->opDev.scalarArg;
      devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr;
      devWork.oneNode = (comm->nNodes == 1);
      devWork.isOneRPN = comm->isOneRPN;
      devWork.netRegUsed = devWork.regUsed = 0;
      if (task->regBufType & NCCL_NET_REG_BUFFER)
        devWork.netRegUsed = 1;
      if (task->regBufType & (NCCL_IPC_REG_BUFFER | NCCL_NVLS_REG_BUFFER))
        devWork.regUsed = 1;
      // Keep this builder in step with the NVLS builder in ncclPrepareTasks(),
      // which fills these two fields; without them they would stay zero for
      // every non-NVLS collective.
      devWork.pivotA2ANumBiRings = comm->topo->pivotA2ANumBiRings;
      devWork.opCount = task->opCount;

      if (task->regBufType & NCCL_NVLS_REG_BUFFER) {
        struct ncclDevWorkCollReg workReg = {};
        workReg.coll = devWork;  // C++ struct assignment
        /* NVLS only has one send and recv buffer registered */
        workReg.dnInputs[0] = regBufSend[0];
        workReg.dnOutputs[0] = regBufRecv[0];
        workNode = ncclMemoryStackAllocInlineArray<ncclWorkList, ncclDevWorkCollReg>(&comm->memScoped, 1);
        workNode->workType = ncclDevWorkTypeCollReg;
        workNode->size = sizeof(struct ncclDevWorkCollReg);
        memcpy((void*)(workNode+1), (void*)&workReg, workNode->size);
      } else {
        workNode = ncclMemoryStackAllocInlineArray<ncclWorkList, ncclDevWorkColl>(&comm->memScoped, 1);
        workNode->workType = ncclDevWorkTypeColl;
        workNode->size = sizeof(struct ncclDevWorkColl);
        memcpy((void*)(workNode+1), (void*)&devWork, workNode->size);
      }
    }

    ncclIntruQueueEnqueue(&planner->collWorkQueue, workNode);
    task = task->next;
  }

  // Every node parked by ncclPrepareTasks() must have been consumed above.
  assert(ncclIntruQueueEmpty(&planner->tmpCollWorkQueue));
  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
// Called once per ncclGroup to organize the user-submitted tasks held in
// comm->planner so that they can later be peeled off into plans.
//
//   comm            - communicator whose planner is being prepared.
//   algoNeedConnect - per-algorithm flags; set to true for each algorithm whose
//                     channels still have to be connected (runtimeConn only).
//   needConnect     - set to true when any algorithm needs connecting.
//   simInfo         - forwarded unchanged to getAlgoInfo().
//
// Returns ncclSuccess, ncclInvalidUsage when no device function exists for an
// aggregated task, or whatever NCCLCHECK propagates from the helpers.
ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool* needConnect, ncclSimInfo_t* simInfo) {
  struct ncclKernelPlanner* planner = &comm->planner;
  planner->persistent = ncclCudaGraphValid(planner->capturingGraph);

  // The sorter hands tasks back ordered by size, descending.
  struct ncclTaskColl* task = ncclTaskCollSorterDequeueAll(&planner->collSorter);

  // One LIFO list per (fn,op,ty). Pushing a descending stream onto a LIFO
  // leaves each list ordered by size, ascending.
  struct ncclTaskColl* tasksByFnOpTy[ncclNumFuncs*ncclNumDevRedOps*ncclNumTypes];
  memset(tasksByFnOpTy, 0, sizeof(tasksByFnOpTy));
  int fnOpTyIndices[ncclNumFuncs*ncclNumDevRedOps*ncclNumTypes];
  int fnOpTyCount = 0;

  // Bin the size-sorted tasks by (fn,op,ty).
  for (struct ncclTaskColl* following = nullptr; task != nullptr; task = following) {
    following = task->next;
    int index = ((int)task->func*ncclNumDevRedOps + (int)task->opDev.op)*ncclNumTypes + (int)task->datatype;
    // Record each (fn,op,ty) index the first time it is seen.
    if (tasksByFnOpTy[index] == nullptr) fnOpTyIndices[fnOpTyCount++] = index;
    // Push onto the LIFO of this (fn,op,ty).
    task->next = tasksByFnOpTy[index];
    tasksByFnOpTy[index] = task;
  }

  // For each (fn,op,ty) bin compute algorithm, protocol, etc., then re-bin the
  // tasks by their scheduling constraints (collnet x nvls).
  struct ncclIntruQueue<struct ncclTaskColl, &ncclTaskColl::next> collBins[2][2] = {};
  for (int bin = 0; bin < fnOpTyCount; bin++) {
    struct ncclTaskColl* aggBeg = tasksByFnOpTy[fnOpTyIndices[bin]];
    int collNetSupport = 0;
    NCCLCHECK(getCollNetSupport(comm, aggBeg, &collNetSupport));
    int nvlsSupport = comm->nvlsSupport && (ncclNvlsSupported(aggBeg->opDev.op, aggBeg->datatype) || aggBeg->func == ncclFuncAllGather);
    // Crude estimate of tasks per channel. The channel count is wrong for NVLS
    // algorithms, but the algorithm cannot be known without this value, so the
    // choice is between being crude or iterating to a fixed point; crude it is.
    int nTasksPerChannel = divUp(comm->planner.nTasksColl, comm->nChannels);
    do {
      struct ncclTaskColl* aggEnd = aggBeg->next;
      struct ncclTaskColl agg = *aggBeg;
      // Aggregate operations whose sizes are within 4X of the first one.
      while (aggEnd != nullptr && aggEnd->trafficBytes < 4*aggBeg->trafficBytes) {
        agg.count += aggEnd->count;
        agg.trafficBytes += aggEnd->trafficBytes;
        aggEnd = aggEnd->next;
      }

      NCCLCHECK(getAlgoInfo(comm, &agg, collNetSupport, nvlsSupport, nTasksPerChannel, simInfo));
      agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol);
      if (agg.devFuncId < 0) {
        WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
        return ncclInvalidUsage;
      }

      int isCollnet = 0, isNvls = 0;
      if (agg.algorithm == NCCL_ALGO_NVLS || agg.algorithm == NCCL_ALGO_NVLS_TREE) {
        isNvls = 1;
        isCollnet = agg.algorithm == NCCL_ALGO_NVLS && comm->nNodes > 1;
      } else if (agg.algorithm == NCCL_ALGO_COLLNET_CHAIN || agg.algorithm == NCCL_ALGO_COLLNET_DIRECT) {
        isCollnet = 1;
      }

      // Copy the values computed for the aggregate back into each member task
      // and file it into the matching scheduling bin.
      do {
        struct ncclTaskColl* following = aggBeg->next;
        aggBeg->algorithm = agg.algorithm;
        aggBeg->protocol = agg.protocol;
        aggBeg->nMaxChannels = agg.nMaxChannels;
        aggBeg->nWarps = agg.nWarps;
        aggBeg->devFuncId = agg.devFuncId;
        aggBeg->isCollnet = isCollnet;
        aggBeg->isNvls = isNvls;
        ncclIntruQueueEnqueue(&collBins[isCollnet][isNvls], aggBeg);
        aggBeg = following;
      } while (aggBeg != aggEnd);
    } while (aggBeg != nullptr);
  }

  // Concatenate `collBins[*][*]` into the final list `planner->collTaskQueue`.
  // Collnet is the outer dimension because it affects how work is divided over
  // the channels.
  for (int isCollnet = 0; isCollnet <= 1; isCollnet++) {
    for (int isNvls = 0; isNvls <= 1; isNvls++) {
      ncclIntruQueueTransfer(&planner->collTaskQueue, &collBins[isCollnet][isNvls]);
    }
  }

  // Walk the tasks again to:
  //  1. Possibly register NVLS buffers.
  //  2. Flag algorithms whose channels still need connecting.
  //  3. For NVLS / NVLS_TREE tasks, build the ncclDevWorkColl[Reg] struct and
  //     park it on planner->tmpCollWorkQueue.
  for (task = ncclIntruQueueHead(&planner->collTaskQueue); task != nullptr; task = task->next) {
    void* regBufSend[NCCL_MAX_LOCAL_RANKS];
    void* regBufRecv[NCCL_MAX_LOCAL_RANKS];
    bool regNeedConnect = true;
    ncclRegisterCollNvlsBuffers(comm, task, regBufSend, regBufRecv, &planner->collCleanupQueue, &regNeedConnect);

    if (comm->runtimeConn && !comm->initAlgoChannels[task->algorithm]) {
      // NVLS_TREE additionally requires the plain NVLS channels.
      if (task->algorithm == NCCL_ALGO_NVLS_TREE && !comm->initAlgoChannels[NCCL_ALGO_NVLS] && regNeedConnect) {
        comm->initAlgoChannels[NCCL_ALGO_NVLS] = true;
        algoNeedConnect[NCCL_ALGO_NVLS] = true;
      }
      if (task->algorithm != NCCL_ALGO_NVLS || regNeedConnect) {
        comm->initAlgoChannels[task->algorithm] = true;
        algoNeedConnect[task->algorithm] = true;
        *needConnect = true;
      }
    }

    if (task->algorithm == NCCL_ALGO_NVLS_TREE || task->algorithm == NCCL_ALGO_NVLS) {
      struct ncclDevWorkColl devWork = {};
      devWork.sendbuff = (void*)task->sendbuff;
      devWork.recvbuff = (void*)task->recvbuff;
      devWork.sendbuffOffset = task->sendbuffOffset;
      devWork.recvbuffOffset = task->recvbuffOffset;
      devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs;
      devWork.recvbuffRmtAddrs = task->recvbuffRmtAddrs;
      devWork.root = task->root;
      devWork.nWarps = task->nWarps;
      devWork.redOpArg = task->opDev.scalarArg;
      devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr;
      devWork.oneNode = (comm->nNodes == 1);
      devWork.netRegUsed = devWork.regUsed = 0;
      if (task->regBufType & NCCL_NET_REG_BUFFER)
        devWork.netRegUsed = 1;
      if (task->regBufType & (NCCL_IPC_REG_BUFFER | NCCL_NVLS_REG_BUFFER))
        devWork.regUsed = 1;
      devWork.pivotA2ANumBiRings = comm->topo->pivotA2ANumBiRings;
      devWork.opCount = task->opCount;

      struct ncclWorkList* workNode;
      if (task->regBufType & NCCL_NVLS_REG_BUFFER) {
        struct ncclDevWorkCollReg workReg = {};
        workReg.coll = devWork;  // C++ struct assignment
        /* NVLS only has one send and recv buffer registered */
        workReg.dnInputs[0] = regBufSend[0];
        workReg.dnOutputs[0] = regBufRecv[0];
        workNode = ncclMemoryStackAllocInlineArray<ncclWorkList, ncclDevWorkCollReg>(&comm->memScoped, 1);
        workNode->workType = ncclDevWorkTypeCollReg;
        workNode->size = sizeof(struct ncclDevWorkCollReg);
        memcpy((void*)(workNode+1), (void*)&workReg, workNode->size);
      } else {
        workNode = ncclMemoryStackAllocInlineArray<ncclWorkList, ncclDevWorkColl>(&comm->memScoped, 1);
        workNode->workType = ncclDevWorkTypeColl;
        workNode->size = sizeof(struct ncclDevWorkColl);
        memcpy((void*)(workNode+1), (void*)&devWork, workNode->size);
      }

      ncclIntruQueueEnqueue(&planner->tmpCollWorkQueue, workNode);
    }
  }
  return ncclSuccess;
}
2022-05-24 02:02:31 -07:00
2025-01-23 11:48:18 -06:00
// Declares the parameter read through rcclParamIntraNetThreshold(). In
// scheduleCollTasksToPlan() a Simple-protocol ring collective larger than this
// many bytes uses NCCL_CONN_IDX_P2P_NET when comm->useIntraNet is set.
// 8388608 is 8 MiB (presumably the default value -- confirm against RCCL_PARAM).
RCCL_PARAM ( IntraNetThreshold , "INTRANET_THRESHOLD" , 8388608 );
2024-02-05 05:06:02 -08:00
static ncclResult_t scheduleCollTasksToPlan (
2024-06-11 01:28:01 -07:00
struct ncclComm * comm , struct ncclKernelPlan * plan , struct ncclKernelPlanBudget * budget
2024-02-05 05:06:02 -08:00
) {
2024-06-11 01:28:01 -07:00
struct ncclKernelPlanner * planner = & comm -> planner ;
// Estimate number of tasks that will fit in this plan.
int nPlanColls = 0 ;
size_t trafficBytes [ 2 * 2 ] = { 0 , 0 , 0 , 0 }; // [collnet][nvls]
int nChannels [ 2 * 2 ] = { 0 , 0 , 0 , 0 }; // [collnet][nvls]
int const nMaxChannels [ 2 * 2 ] = { comm -> nChannels , comm -> nvlsChannels , // [collnet][nvls]
comm -> nChannels , comm -> nvlsChannels };
2025-03-27 12:51:55 -05:00
constexpr size_t MinTrafficPerChannel = 512 ; // Traffic as minimal
2024-06-11 01:28:01 -07:00
do {
size_t workBytes = 0 ;
struct ncclTaskColl * task = ncclIntruQueueHead ( & planner -> collTaskQueue );
struct ncclWorkList * workNode = ncclIntruQueueHead ( & planner -> collWorkQueue );
while ( task != nullptr ) {
int nBatches = divUp ( nPlanColls , 4 ); // Rough guess: 4 colls per batch.
if ( ! testBudget ( budget , nBatches , workBytes + workNode -> size )) goto plan_full ;
nPlanColls += 1 ;
workBytes += workNode -> size ;
int kind = 2 * task -> isCollnet + task -> isNvls ;
2024-09-10 05:57:10 -07:00
trafficBytes [ kind ] += std :: max ( MinTrafficPerChannel , task -> trafficBytes );
2024-06-11 01:28:01 -07:00
nChannels [ kind ] += task -> nMaxChannels ;
nChannels [ kind ] = std :: min ( nChannels [ kind ], nMaxChannels [ kind ]);
task = task -> next ;
workNode = workNode -> next ;
}
plan_full :;
} while ( 0 );
int kindPrev = - 1 ;
size_t trafficPerChannel = 0 ;
int channelId = 0 ;
size_t currentTraffic = 0 ;
while ( nPlanColls != 0 && ! ncclIntruQueueEmpty ( & planner -> collTaskQueue )) {
struct ncclTaskColl * task = ncclIntruQueueHead ( & planner -> collTaskQueue );
struct ncclWorkList * workNode = ncclIntruQueueHead ( & planner -> collWorkQueue );
struct ncclDevWorkColl * devWork = ( struct ncclDevWorkColl * )( workNode + 1 );
size_t elementSize = ncclTypeSize ( task -> datatype );
int kind = 2 * task -> isCollnet + task -> isNvls ;
if ( kind != kindPrev ) {
trafficPerChannel = std :: max < size_t > ( MinTrafficPerChannel , trafficBytes [ kind ] / nChannels [ kind ]);
kindPrev = kind ;
channelId = 0 ;
currentTraffic = 0 ;
}
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
if ( task -> isCollnet ) {
int nChannels = task -> nMaxChannels ;
// Ensure room for worst case of one new batch per channel
if ( ! testBudget ( budget , plan -> nWorkBatches + nChannels , plan -> workBytes + workNode -> size )) {
return ncclSuccess ;
}
2024-02-05 05:06:02 -08:00
2024-06-11 01:28:01 -07:00
size_t globalBytesPerElement = elementSize * ncclFuncMaxSendRecvCount ( task -> func , comm -> nRanks , 1 );
struct ncclProxyOp proxyOp ;
uint32_t chunkSize , directFlags = 0 ;
2025-01-23 11:48:18 -06:00
size_t nBytes = globalBytesPerElement * task -> count ;
NCCLCHECK ( calcCollChunking ( comm , task , nChannels , nBytes , & chunkSize , & directFlags , & proxyOp ));
2024-06-11 01:28:01 -07:00
devWork -> channelLo = 0 ;
devWork -> channelHi = nChannels - 1 ;
devWork -> collnet . count = task -> count ;
devWork -> collnet . chunkCount = chunkSize / ncclTypeSize ( task -> datatype );
devWork -> direct = directFlags ;
uint64_t proxyOpId = uint64_t ( plan -> collOpCount ++ ) << 1 | 0 ;
for ( int c = devWork -> channelLo ; c <= ( int ) devWork -> channelHi ; c ++ ) {
proxyOp . channelId = c ;
proxyOp . opCount = proxyOpId ;
2024-09-10 05:57:10 -07:00
proxyOp . task . coll = task ;
proxyOp . rank = comm -> rank ;
2024-06-11 01:28:01 -07:00
addWorkBatchToPlan ( comm , plan , c , workNode -> workType , task -> devFuncId , plan -> workBytes );
NCCLCHECK ( addProxyOpIfNeeded ( comm , plan , & proxyOp ));
}
} else { // not task->isCollnet
2024-09-10 05:57:10 -07:00
int trafficPerByte = ncclFuncTrafficPerByte ( task -> func , comm -> nRanks );
size_t cellSize = divUp ( divUp ( MinTrafficPerChannel , ( size_t ) trafficPerByte ), 16 ) * 16 ;
2024-06-11 01:28:01 -07:00
int elementsPerCell = cellSize / elementSize ;
size_t cells = divUp ( task -> count * elementSize , cellSize );
size_t trafficPerElement = elementSize * trafficPerByte ;
size_t trafficPerCell = cellSize * trafficPerByte ;
size_t cellsPerChannel = std :: min ( cells , divUp ( trafficPerChannel , trafficPerCell ));
size_t cellsLo ;
if ( channelId + 1 == nMaxChannels [ kind ]) { // On last channel everything goes to "lo"
cellsLo = cells ;
2024-02-05 05:06:02 -08:00
} else {
2024-09-10 05:57:10 -07:00
cellsLo = std :: min ( cells , divUp (( trafficPerChannel - currentTraffic ), trafficPerCell ));
2024-06-11 01:28:01 -07:00
}
int nMidChannels = ( cells - cellsLo ) / cellsPerChannel ;
size_t cellsHi = ( cells - cellsLo ) % cellsPerChannel ;
int nChannels = ( cellsLo != 0 ? 1 : 0 ) + nMidChannels + ( cellsHi != 0 ? 1 : 0 );
if ( nMaxChannels [ kind ] < channelId + nChannels ) { // Overflowed available channels
nMidChannels = nMaxChannels [ kind ] - channelId - 2 ;
cellsPerChannel = ( cells - cellsLo ) / ( nMidChannels + 1 );
cellsHi = cellsPerChannel + ( cells - cellsLo ) % ( nMidChannels + 1 );
}
if ( cellsHi == 0 && nMidChannels != 0 ) {
cellsHi = cellsPerChannel ;
nMidChannels -= 1 ;
}
if ( cellsLo == 0 ) { // Least channel skipped. Make the next channel the new least.
channelId += 1 ;
if ( nMidChannels == 0 ) { cellsLo = cellsHi ; cellsHi = 0 ; }
else { cellsLo = cellsPerChannel ; nMidChannels -= 1 ; }
}
size_t countMid = nMidChannels != 0 ? cellsPerChannel * elementsPerCell : 0 ;
size_t countLo = cellsLo * elementsPerCell ;
size_t countHi = cellsHi * elementsPerCell ;
( countHi != 0 ? countHi : countLo ) -= cells * elementsPerCell - task -> count ;
nChannels = ( countLo != 0 ? 1 : 0 ) + nMidChannels + ( cellsHi != 0 ? 1 : 0 );
// Ensure room for worst case of one new batch per channel
if ( ! testBudget ( budget , plan -> nWorkBatches + nChannels , plan -> workBytes + workNode -> size )) {
return ncclSuccess ;
2022-05-24 02:02:31 -07:00
}
2024-06-11 01:28:01 -07:00
devWork -> channelLo = channelId ;
devWork -> channelHi = channelId + nChannels - 1 ;
devWork -> cbd . countLo = countLo ;
devWork -> cbd . countMid = countMid ;
devWork -> cbd . countHi = countHi ;
// calcCollChunking() uses global bytes instead of traffic which differs
// in that allreduce isn't multiplied by 2.
size_t globalBytesPerElement = elementSize * ncclFuncMaxSendRecvCount ( task -> func , comm -> nRanks , 1 );
struct ncclProxyOp proxyOpLo , proxyOpMid , proxyOpHi ;
2025-01-23 11:48:18 -06:00
size_t nBytes = globalBytesPerElement * task -> count ;
devWork -> connIndex = 0 ;
if ( task -> protocol == NCCL_PROTO_SIMPLE && task -> algorithm == NCCL_ALGO_RING ) {
if ( comm -> useIntraNet && nBytes > rcclParamIntraNetThreshold ()) {
devWork -> connIndex = NCCL_CONN_IDX_P2P_NET ;
2024-02-05 05:06:02 -08:00
}
2025-01-23 11:48:18 -06:00
}
2024-02-05 05:06:02 -08:00
2024-06-11 01:28:01 -07:00
uint32_t chunkSize , directFlags = 0 ;
size_t grainSize = ncclProtoGrainSize ( task -> protocol );
if ( countLo != 0 ) {
NCCLCHECK ( calcCollChunking ( comm , task , /*nChannels=*/ 1 , globalBytesPerElement * countLo , & chunkSize , & directFlags , & proxyOpLo ));
devWork -> cbd . chunkGrainsLo = chunkSize / grainSize ;
}
if ( countHi != 0 ) {
NCCLCHECK ( calcCollChunking ( comm , task , /*nChannels=*/ 1 , globalBytesPerElement * countHi , & chunkSize , & directFlags , & proxyOpHi ));
devWork -> cbd . chunkGrainsHi = chunkSize / grainSize ;
}
if ( nMidChannels != 0 ) {
NCCLCHECK ( calcCollChunking ( comm , task , /*nChannels=*/ 1 , globalBytesPerElement * countMid , & chunkSize , & directFlags , & proxyOpMid ));
devWork -> cbd . chunkGrainsMid = chunkSize / grainSize ;
}
devWork -> direct = directFlags ;
// Update the current channel and vacant traffic budget.
if ( countHi != 0 ) {
channelId += nChannels - 1 ;
2024-09-10 05:57:10 -07:00
currentTraffic = cellsHi * elementsPerCell * trafficPerElement ;
2024-06-11 01:28:01 -07:00
} else if ( nMidChannels != 0 ) {
channelId += nChannels ;
currentTraffic = 0 ;
2024-02-05 05:06:02 -08:00
} else {
2024-09-10 05:57:10 -07:00
currentTraffic += cellsLo * elementsPerCell * trafficPerElement ;
2024-06-11 01:28:01 -07:00
}
if ( currentTraffic >= trafficPerChannel && channelId + 1 != nMaxChannels [ kind ]) {
channelId += 1 ;
currentTraffic = 0 ;
}
uint64_t proxyOpId = uint64_t ( plan -> collOpCount ++ ) << 1 | 0 ;
for ( int c = devWork -> channelLo ; c <= ( int ) devWork -> channelHi ; c ++ ) {
struct ncclProxyOp * proxyOp ;
if ( c == ( int ) devWork -> channelLo ) {
proxyOp = & proxyOpLo ;
2024-12-18 08:26:06 -08:00
proxyOp -> loopOffset = 0 ;
proxyOp -> channelSize = countLo * elementSize ;
2024-06-11 01:28:01 -07:00
} else if ( c == ( int ) devWork -> channelHi ) {
proxyOp = & proxyOpHi ;
2024-12-18 08:26:06 -08:00
proxyOp -> loopOffset = ( countLo + nMidChannels * countMid ) * elementSize ;
proxyOp -> channelSize = countHi * elementSize ;
2024-06-11 01:28:01 -07:00
} else {
proxyOp = & proxyOpMid ;
2024-12-18 08:26:06 -08:00
proxyOp -> loopOffset = ( countLo + ( c - devWork -> channelLo - 1 ) * countMid ) * elementSize ;
proxyOp -> channelSize = countMid * elementSize ;
2024-06-11 01:28:01 -07:00
}
proxyOp -> channelId = c ;
proxyOp -> opCount = proxyOpId ;
2024-09-10 05:57:10 -07:00
proxyOp -> task . coll = task ;
proxyOp -> rank = comm -> rank ;
2024-12-18 08:26:06 -08:00
proxyOp -> ringAlgo = NULL ;
if ( proxyOp -> reg && task -> algorithm == NCCL_ALGO_RING && ( task -> recvNetHandles [ c ] || task -> sendNetHandles [ c ])) {
if ( task -> func == ncclFuncAllGather ) {
proxyOp -> ringAlgo = new RingAGAlgorithm ( task -> sendbuff , task -> recvbuff , comm -> nRanks , comm -> channels [ c ]. ring . userRanks , proxyOp -> chunkSteps , proxyOp -> sliceSteps , proxyOp -> chunkSize , proxyOp -> sliceSize , proxyOp -> loopOffset , proxyOp -> channelSize , elementSize , task -> count * elementSize , task -> sendNetHandles [ c ], task -> recvNetHandles [ c ], task -> srecvNetHandles [ c ]);
} else if ( task -> func == ncclFuncAllReduce ) {
proxyOp -> ringAlgo = new RingARAlgorithm ( task -> sendbuff , task -> recvbuff , comm -> nRanks , comm -> channels [ c ]. ring . index , proxyOp -> chunkSteps , proxyOp -> sliceSteps , proxyOp -> chunkSize , proxyOp -> sliceSize , proxyOp -> loopOffset , proxyOp -> channelSize , elementSize , task -> sendNetHandles [ c ], task -> recvNetHandles [ c ], task -> srecvNetHandles [ c ]);
} else if ( task -> func == ncclFuncBroadcast ) {
proxyOp -> ringAlgo = new RingBCAlgorithm ( task -> sendbuff , task -> recvbuff , comm -> rank , task -> root , comm -> nRanks , comm -> channels [ c ]. ring . userRanks , proxyOp -> chunkSteps , proxyOp -> sliceSteps , proxyOp -> chunkSize , proxyOp -> sliceSize , proxyOp -> loopOffset , proxyOp -> channelSize , task -> sendNetHandles [ c ], task -> recvNetHandles [ c ], task -> srecvNetHandles [ c ]);
}
proxyOp -> ringAlgo -> incRefCount ();
}
2025-03-04 13:30:36 -05:00
proxyOp -> connIndex = 0 ;
if ( task -> protocol == NCCL_PROTO_SIMPLE && task -> algorithm == NCCL_ALGO_RING ) {
if ( comm -> useIntraNet && nBytes > rcclParamIntraNetThreshold ()) {
proxyOp -> connIndex = NCCL_CONN_IDX_P2P_NET ;
}
}
2024-06-11 01:28:01 -07:00
addWorkBatchToPlan ( comm , plan , c , workNode -> workType , task -> devFuncId , plan -> workBytes );
2024-09-10 05:57:10 -07:00
// Coverity reports "proxyOp->connection" as being possibly uninitialized. It's hard to
// determine if that's actually true but it's also not clear if that would be an issue.
// coverity[uninit_use_in_call:FALSE]
2024-06-11 01:28:01 -07:00
NCCLCHECK ( addProxyOpIfNeeded ( comm , plan , proxyOp ));
2022-05-24 02:02:31 -07:00
}
2024-02-05 05:06:02 -08:00
}
2022-05-24 02:02:31 -07:00
2025-01-23 11:48:18 -06:00
for ( int c = devWork -> channelLo ; c <= devWork -> channelHi ; ++ c ) {
int maskIdx = c / 64 ;
int relativeIdx = c % 64 ;
plan -> channelMask . masks [ maskIdx ] |= ( 1ull << relativeIdx );
}
//plan->channelMask.masks[channelId/64] |= (2ull<<devWork->channelHi) - (1ull<<devWork->channelLo);
plan -> threadPerBlock = std :: max ( plan -> threadPerBlock , 3 * plan -> comm -> WarpSize );
2024-06-11 01:28:01 -07:00
if ( ! plan -> kernelSpecialized ) {
2025-01-23 11:48:18 -06:00
plan -> kernelFn = ncclKerns [ ncclGetKernelIndex ( comm )]. kernelFn ;
plan -> kernelSpecialized = ncclKerns [ ncclGetKernelIndex ( comm )]. specialized ;
2024-06-11 01:28:01 -07:00
}
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
if ( comm -> rank == 0 ) {
2024-12-18 08:26:06 -08:00
INFO ( NCCL_TUNING , "%s: %ld Bytes -> Algo %s proto %s channel{Lo..Hi}={%d..%d}" ,
ncclFuncToString ( task -> func ), task -> count * ncclTypeSize ( task -> datatype ), ncclAlgoToString ( task -> algorithm ),
ncclProtoToString ( task -> protocol ), devWork -> channelLo , devWork -> channelHi );
2024-06-11 01:28:01 -07:00
if ( task -> isCollnet ) {
TRACE ( NCCL_COLL , "Collective %s(%s, %s, %s, %s) count=%ld devFuncId=%d channel{Lo..Hi}={%d..%d} count=%ld chunkCount=%d" ,
ncclFuncToString ( task -> func ), ncclDevRedOpToString ( task -> opDev . op ),
ncclDatatypeToString ( task -> datatype ), ncclAlgoToString ( task -> algorithm ),
ncclProtoToString ( task -> protocol ),
( long ) task -> count , task -> devFuncId , devWork -> channelLo , devWork -> channelHi ,
( long ) devWork -> collnet . count , devWork -> collnet . chunkCount );
2024-02-05 05:06:02 -08:00
} else {
2024-06-11 01:28:01 -07:00
TRACE ( NCCL_COLL , "Collective %s(%s, %s, %s, %s) count=%ld devFuncId=%d channel{Lo..Hi}={%d..%d} count{Lo,Mid,Hi}={%ld,%ld,%ld} chunkBytes{Lo,Mid,Hi}={%d,%d,%d}" ,
ncclFuncToString ( task -> func ), ncclDevRedOpToString ( task -> opDev . op ),
ncclDatatypeToString ( task -> datatype ), ncclAlgoToString ( task -> algorithm ),
ncclProtoToString ( task -> protocol ),
( long ) task -> count , task -> devFuncId , devWork -> channelLo , devWork -> channelHi ,
( long ) devWork -> cbd . countLo , ( long ) devWork -> cbd . countMid , ( long ) devWork -> cbd . countHi ,
int ( devWork -> cbd . chunkGrainsLo * ncclProtoGrainSize ( task -> protocol )),
int ( devWork -> cbd . chunkGrainsMid * ncclProtoGrainSize ( task -> protocol )),
int ( devWork -> cbd . chunkGrainsHi * ncclProtoGrainSize ( task -> protocol )));
2023-09-26 05:47:28 -07:00
}
2024-02-05 05:06:02 -08:00
}
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
for ( int i = 0 ; i < task -> nCleanupQueueElts ; i ++ ) {
ncclIntruQueueEnqueue ( & plan -> cleanupQueue , ncclIntruQueueDequeue ( & planner -> collCleanupQueue ));
}
ncclIntruQueueDequeue ( & planner -> collTaskQueue );
ncclIntruQueueDequeue ( & planner -> collWorkQueue );
nPlanColls -= 1 ;
planner -> nTasksColl -= 1 ;
2024-09-10 05:57:10 -07:00
ncclIntruQueueEnqueue ( & plan -> collTaskQueue , task );
2024-06-11 01:28:01 -07:00
ncclIntruQueueEnqueue ( & plan -> workQueue , workNode );
plan -> workBytes += workNode -> size ;
2024-02-05 05:06:02 -08:00
}
2024-06-11 01:28:01 -07:00
return ncclSuccess ;
}
// Tunables consumed by addP2pToPlan() below (values are compared against byte counts).
// NOTE(review): presumably each macro binds the value to an NCCL_/RCCL_-prefixed
// environment variable named by the string argument - confirm against the macro definition.
//
// P2pLLThreshold: read through ncclParamP2pLLThreshold(); multiplied by the channel
// count to get the largest transfer for which the LL protocol is still selected.
NCCL_PARAM ( P2pLLThreshold , "P2P_LL_THRESHOLD" , 16384 );
2025-01-23 11:48:18 -06:00
// P2pNetThreshold: read through rcclParamP2pNetThreshold(); when comm->p2pNet is set,
// transfers larger than this use the NCCL_CONN_IDX_P2P_NET connection index.
RCCL_PARAM ( P2pNetThreshold , "P2P_NET_THRESHOLD" , 131072 );
2024-06-11 01:28:01 -07:00
// ChunkSize: read through ncclParamChunkSize(); a non-zero value overrides the
// computed p2p chunk size, 0 keeps the computed one.
NCCL_PARAM ( ChunkSize , "CHUNK_SIZE" , 0 );
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
// Put p2p op in plan assuming there is sizeof(ncclDevWorkBatch) in batch budget
// and sizeof(ncclDevWorkP2p) in work budget. "sendRank" and "recvRank" must
// match the corresponding values for this round of the p2p schedule (no -1's).
// No-op's are encoded with a -1 size.
//
// Parameters:
//   comm, plan        communicator and the kernel plan being built.
//   nChannelsMin/Max  lower/upper bound on the number of channel parts a
//                     direction is split over.
//   p2pRound          round of the p2p schedule; selects the channel base and
//                     is forwarded to addWorkBatchToPlan().
//   sendRank/Addr/Bytes, recvRank/Addr/Bytes
//                     peer, buffer and size per direction (-1 size = no-op).
//   sendOpCount/recvOpCount
//                     stored verbatim in the device work struct.
//   p2pTasks          tasks indexed by direction (0=recv, 1=send); entries may
//                     be null.
//
// Returns ncclSuccess, ncclInvalidUsage when the P2p device function is not
// available (ncclDevFuncId_P2p() < 0), or the first error reported by a helper.
static ncclResult_t addP2pToPlan(
    struct ncclComm* comm, struct ncclKernelPlan* plan,
    int nChannelsMin, int nChannelsMax, int p2pRound,
    int sendRank, void* sendAddr, ssize_t sendBytes,
    int recvRank, void* recvAddr, ssize_t recvBytes,
    uint64_t sendOpCount, uint64_t recvOpCount,
    struct ncclTaskP2p** p2pTasks
  ) {
  // Validate the device function id before it is used and before the plan is
  // modified, so an unsupported build fails without side effects.
  int funcIdx = ncclDevFuncId_P2p();
  if (funcIdx < 0) {
    WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
    return ncclInvalidUsage;
  }

  // Per-direction arrays of network registration handles (one slot per channel
  // part). Only the element values are copied into the proxy ops below, so the
  // arrays themselves are released on every exit path of this function.
  struct NetHandleArrays {
    void** perDir[2] = {nullptr, nullptr};
    ~NetHandleArrays() { free(perDir[0]); free(perDir[1]); }
  } netHandles;
  void*** handles = netHandles.perDir;

  int connIndex[2] = {1, 1};
  bool selfSend = (sendRank == comm->rank);
  // recv: dir=0, send: dir=1
  void* addrs[2] = {recvAddr, sendAddr};
  ssize_t bytes[2] = {recvBytes, sendBytes};
  bool protoLL[2] = {!selfSend, !selfSend};
  bool network[2] = {false, false};
  bool proxySameProcess[2] = {true, true};
  uint8_t base = ncclP2pChannelBaseForRound(comm, p2pRound);

  // Large transfers are routed over the dedicated p2p-net connection when enabled.
  if (comm->p2pNet) {
    for (int dir = 0; dir <= 1; dir++) {
      if (bytes[dir] > rcclParamP2pNetThreshold())
        connIndex[dir] = NCCL_CONN_IDX_P2P_NET;
    }
  }

  // Inspect every connector that may be used to learn whether LL is possible,
  // whether the network transport is involved and whether the proxy is local.
  if (!selfSend) {
    for (int part = 0; part < nChannelsMax; part++) {
      int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax, comm->nNodes);
      struct ncclChannelPeer** channelPeers = comm->channels[channelId].peers;
      for (int dir = 0; dir <= 1; dir++) {
        int peerRank = dir ? sendRank : recvRank;
        struct ncclConnector* conn = dir ? &channelPeers[peerRank]->send[connIndex[dir]]
                                         : &channelPeers[peerRank]->recv[connIndex[dir]];
        protoLL[dir] &= conn->conn.buffs[NCCL_PROTO_LL] != nullptr && !IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx12");
        network[dir] |= conn->transportComm == (dir ? &netTransport.send : &netTransport.recv);
        proxySameProcess[dir] &= conn->proxyConn.sameProcess;
      }
    }
  }

  ssize_t thresholdLL = nChannelsMax*ncclParamP2pLLThreshold();
  ssize_t paramChunkSize = ncclParamChunkSize();
  // Arrays indexed by dir where recv=0, send=1:
  int nChannels[2];
  int protocol[2];
  int stepSize[2];
  int chunkSize[2];
  int chunkDataSize[2];
  int chunkDataSize_u32fp8[2];
  bool netRegistered[2] = {false, false};
  bool ipcRegistered[2] = {false, false};

  for (int dir = 0; dir < 2; dir++) { // 0=recv, 1=send
    if (bytes[dir] != -1) protoLL[dir] &= bytes[dir] <= thresholdLL;
    protocol[dir] = protoLL[dir] ? NCCL_PROTO_LL : NCCL_PROTO_SIMPLE;

    stepSize[dir] = comm->buffSizes[protocol[dir]]/NCCL_STEPS;
    if (protocol[dir] == NCCL_PROTO_SIMPLE) stepSize[dir] = comm->p2pChunkSize;
    chunkSize[dir] = stepSize[dir];
    if (paramChunkSize != 0) {
      chunkSize[dir] = paramChunkSize;
    } else if (network[dir]) {
      // Tune chunk size for the network
      if (protocol[dir] == NCCL_PROTO_SIMPLE && bytes[dir] < stepSize[dir]) chunkSize[dir] /= 4;
      else if (bytes[dir] < 8*stepSize[dir]) chunkSize[dir] /= 2;
    }

    // Round the data size through its u32fp8 encoding so host and device agree
    // on the exact value; LL buffers hold half data, half flags.
    chunkDataSize[dir] = chunkSize[dir];
    if (protocol[dir] == NCCL_PROTO_LL) chunkDataSize[dir] /= 2;
    chunkDataSize_u32fp8[dir] = u32fp8Encode(chunkDataSize[dir]);
    chunkDataSize[dir] = u32fp8Decode(chunkDataSize_u32fp8[dir]);
    chunkSize[dir] = chunkDataSize[dir];
    if (protocol[dir] == NCCL_PROTO_LL) chunkSize[dir] *= 2;

    if (network[dir]) {
      if (bytes[dir] > 0 && proxySameProcess[dir] && protocol[dir] == NCCL_PROTO_SIMPLE && (ncclPxnDisable(comm) || !comm->isAllNvlink)) {
        int regFlag = 0;
        NCCLCHECK(ncclCalloc(&handles[dir], nChannelsMax));
        for (int part = 0; part < nChannelsMax; part++) {
          int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax, comm->nNodes);
          struct ncclChannelPeer** channelPeers = comm->channels[channelId].peers;
          int peerRank = dir ? sendRank : recvRank;
          struct ncclConnector* conn = dir ? &channelPeers[peerRank]->send[connIndex[dir]]
                                           : &channelPeers[peerRank]->recv[connIndex[dir]];
          // Registration is best effort: its status is reported through regFlag.
          if (conn->conn.flags & NCCL_DIRECT_NIC)
            ncclRegisterP2pNetBuffer(comm, addrs[dir], bytes[dir], conn, &regFlag, &handles[dir][part], &plan->cleanupQueue);
          if (!regFlag) break;
        }
        netRegistered[dir] = regFlag ? true : false;
      }
    } else if (bytes[dir] > 0 && addrs[dir] && protocol[dir] == NCCL_PROTO_SIMPLE && !selfSend) {
      int peerRank = dir ? sendRank : recvRank;
      int regFlag = 0;
      int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, 0, nChannelsMax, comm->nNodes);
      struct ncclChannelPeer** channelPeers = comm->channels[channelId].peers;
      struct ncclConnector* conn = dir ? &channelPeers[peerRank]->send[connIndex[dir]]
                                       : &channelPeers[peerRank]->recv[connIndex[dir]];
      void* regAddr = NULL;
      if (conn->conn.flags & (NCCL_P2P_WRITE | NCCL_P2P_READ)) {
        // We require users registering buffers on both sides
        NCCLCHECK(ncclRegisterP2pIpcBuffer(comm, addrs[dir], bytes[dir], peerRank, &regFlag, &regAddr, &plan->cleanupQueue));
        if (regFlag) {
          if (dir == 0 && (conn->conn.flags & NCCL_P2P_WRITE)) recvAddr = regAddr;
          else if (dir == 1 && (conn->conn.flags & NCCL_P2P_READ)) sendAddr = regAddr;
        }
      }
      ipcRegistered[dir] = regFlag ? true : false;
    }

    if (bytes[dir] == -1) nChannels[dir] = 0;
    else if (bytes[dir] == 0) nChannels[dir] = 1;
    else {
      ssize_t minPartSize = comm->nNodes > 1 ? stepSize[dir]/2 : stepSize[dir]/8;
      ssize_t maxPartSize = comm->nNodes > 1 ? stepSize[dir] : stepSize[dir]*32;
      nChannels[dir] = std::min<int>(nChannelsMin, divUp(bytes[dir], minPartSize));
      size_t partSize = std::max(minPartSize, divUp(bytes[dir], nChannels[dir]));
      while (partSize > maxPartSize && nChannels[dir] <= nChannelsMax/2) {
        nChannels[dir] *= 2;
        partSize = divUp(bytes[dir], nChannels[dir]);
      }
    }
  }

  // Build the single device work struct shared by all channel parts.
  struct ncclWorkList* workNode = ncclMemoryStackAllocInlineArray<ncclWorkList, ncclDevWorkP2p>(&comm->memScoped, 1);
  workNode->workType = ncclDevWorkTypeP2p;
  workNode->size = sizeof(struct ncclDevWorkP2p);
  ncclIntruQueueEnqueue(&plan->workQueue, workNode);
  uint32_t workOffset = plan->workBytes;
  plan->workBytes += sizeof(struct ncclDevWorkP2p);

  struct ncclDevWorkP2p* work = (struct ncclDevWorkP2p*)(workNode+1);
  work->nP2pChannels = comm->p2pnChannels;
  work->channelBase = base;
  work->nSendChannels = nChannels[1];
  work->sendProtoLL = protoLL[1];
  work->sendNetReg = netRegistered[1];
  work->sendIpcReg = ipcRegistered[1];
  work->sendChunkSize_u32fp8 = chunkDataSize_u32fp8[1];
  work->sendRank = sendRank;
  work->sendAddr = sendAddr;
  work->sendBytes = sendBytes==-1 ? 0 : sendBytes;
  work->sendConnIndex = connIndex[1];
  work->sendOpCount = sendOpCount;
  work->nRecvChannels = nChannels[0];
  work->recvProtoLL = protoLL[0];
  work->recvNetReg = netRegistered[0];
  work->recvIpcReg = ipcRegistered[0];
  work->recvChunkSize_u32fp8 = chunkDataSize_u32fp8[0];
  work->recvRank = recvRank;
  work->recvAddr = recvAddr;
  work->recvBytes = recvBytes==-1 ? 0 : recvBytes;
  work->recvConnIndex = connIndex[0];
  work->recvOpCount = recvOpCount;

  struct ncclProxyOp proxyOps[2] = {};
  int nProxyOps = selfSend ? 0 : 2;
  for (int dir = 0; dir < nProxyOps; dir++) {
    struct ncclProxyOp* op = &proxyOps[dir];
    op->root = dir ? sendRank : recvRank;
    op->sliceSteps = 1;
    op->chunkSteps = 1;
    op->dtype = ncclInt8;
    op->redOp = ncclSum;
    op->protocol = protocol[dir];
    op->pattern = dir ? ncclPatternSend : ncclPatternRecv;
    op->chunkSize = chunkSize[dir];
    op->reg = netRegistered[dir];
    op->coll = p2pTasks[dir] ? p2pTasks[dir]->func : 0;
    op->task.p2p = p2pTasks[dir];
    op->rank = comm->rank;
    op->connIndex = connIndex[dir];
    // The following are modified per channel part in addWorkToChannels():
    // op->buffer, op->nbytes, op->nsteps = ...;
  }

  nChannelsMax = std::max(nChannels[0], nChannels[1]);
  for (int part = 0; part < nChannelsMax; part++) {
    int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, comm->p2pnChannelsPerPeer, comm->nNodes);
    plan->channelMask.masks[channelId/64] |= uint64_t(1) << (channelId%64);
    // Add batch first.
    addWorkBatchToPlan(comm, plan, channelId, ncclDevWorkTypeP2p, funcIdx, workOffset, p2pRound);
    // Add proxy ops.
    for (int dir = 0; dir < nProxyOps; dir++) {
      // Partition steps across channels.
      int nParts = dir ? work->nSendChannels : work->nRecvChannels;
      void* addr = dir ? work->sendAddr : work->recvAddr;
      size_t dirBytes = dir ? work->sendBytes : work->recvBytes;

      proxyOps[dir].recvbuff = nullptr;
      if (nParts <= part) {
        proxyOps[dir].nsteps = 0;
      } else if (dirBytes == 0) {
        proxyOps[dir].nsteps = 1;
        proxyOps[dir].nbytes = 0;
      } else {
        size_t dirChunkDataSize = u32fp8Decode(dir ? work->sendChunkSize_u32fp8 : work->recvChunkSize_u32fp8);
        size_t partBeg, partEnd;
        ncclP2pPartBounds(nParts, part, dirBytes, &partBeg, &partEnd);
        if (proxyOps[dir].reg) {
          (dir ? proxyOps[dir].sendbuff : proxyOps[dir].recvbuff) = (uint8_t*)addr + partBeg;
          (dir ? proxyOps[dir].sendMhandle : proxyOps[dir].recvMhandle) = handles[dir][part];
          proxyOps[dir].nbytes = partEnd - partBeg;
          proxyOps[dir].nsteps = DIVUP(proxyOps[dir].nbytes, NCCL_MAX_NET_SIZE);
        } else {
          proxyOps[dir].nsteps = divUp(partEnd - partBeg, dirChunkDataSize);
          proxyOps[dir].nbytes = std::min(partEnd - partBeg, dirChunkDataSize);
        }
        if (proxyOps[dir].protocol == NCCL_PROTO_LL) {
          proxyOps[dir].nbytes *= 2;
          proxyOps[dir].nbytes = roundUp(proxyOps[dir].nbytes, sizeof(union ncclLLFifoLine));
        }
      }

      if (proxyOps[dir].nsteps != 0) {
        // Calculate the opCount after adding batch since then the batch count will
        // equal one plus the batch index this p2p settled in.
        proxyOps[dir].channelId = channelId;
        proxyOps[dir].opCount = uint64_t(comm->planner.wipPlan.channels[channelId].nWorkBatchesP2p)<<1 | 1;
        NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOps[dir]));
      }
    }
  }

  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
// Chooses how many channels to spread a p2p transfer of `totalSize` bytes over.
// Starts at `minChannels` and doubles the count while the per-channel share
// (never considered smaller than `minSize`) exceeds `maxSize` and doubling
// would not go past `maxChannels`. Returns the resulting channel count.
static int calcP2pChannelCount(size_t totalSize, int minChannels, int maxChannels, size_t minSize, size_t maxSize) {
  int channelCount = minChannels;
  size_t perChannel = divUp(totalSize, minChannels);
  if (perChannel < minSize) perChannel = minSize;
  for (; perChannel > maxSize && channelCount <= maxChannels/2; perChannel = divUp(totalSize, channelCount)) {
    channelCount *= 2;
  }
  return channelCount;
}
2018-12-13 15:56:12 -08:00
2022-05-24 02:02:31 -07:00
// Drains the planner's per-peer send/recv queues into `plan`, walking the p2p
// schedule round by round until no p2p task is left or the budget is exhausted.
//
// Returns ncclSuccess both when all tasks were scheduled and when the budget
// ran out (remaining tasks stay queued), ncclInvalidUsage for an unmatched
// send/recv to self, or the error reported by addP2pToPlan().
static ncclResult_t scheduleP2pTasksToPlan(
    struct ncclComm* comm, struct ncclKernelPlan* plan, struct ncclKernelPlanBudget* budget
  ) {
  const int rankCount = comm->nRanks;
  struct ncclKernelPlanner::Peer* peerTable = comm->planner.peers;

  // P2p kernels need at least NCCL_MAX_NTHREADS threads per block.
  if (plan->threadPerBlock < NCCL_MAX_NTHREADS) plan->threadPerBlock = NCCL_MAX_NTHREADS;
  if (!plan->kernelSpecialized) {
    ncclKernelMatch const& match = ncclKerns[ncclGetKernelIndex(comm)];
    plan->kernelFn = match.kernelFn;
    plan->kernelSpecialized = match.specialized;
  }

  // Compute how much to split operations.
  // Upper bound: all channels available per peer.
  int perOpMax = comm->p2pnChannelsPerPeer;
  // Lower bound: halve until all ranks fit in the p2p channels, but keep at
  // least one channel per operation.
  int perOpMin = perOpMax;
  for (; perOpMin > 1 && perOpMin*rankCount > comm->p2pnChannels; perOpMin /= 2) {}

  while (comm->planner.nTasksP2p != 0) {
    for (int round = 0; round < rankCount; round++) {
      const int sendRank = comm->p2pSchedule[round].sendRank;
      const int recvRank = comm->p2pSchedule[round].recvRank;
      struct ncclTaskP2p* send = ncclIntruQueueHead(&peerTable[sendRank].sendQueue);
      struct ncclTaskP2p* recv = ncclIntruQueueHead(&peerTable[recvRank].recvQueue);
      if (send == nullptr && recv == nullptr) continue;

      const bool toSelf = (sendRank == comm->rank);
      if (toSelf) {
        // At least one side is present here, so a missing side means a mismatch.
        if (recv == nullptr) {
          WARN("Trying to send to self without a matching recv");
          return ncclInvalidUsage;
        }
        if (send == nullptr) {
          WARN("Trying to recv to self without a matching send");
          return ncclInvalidUsage;
        }
      }

      if (toSelf && send->buff == recv->buff) {
        // Skip send to self in-place (we don't need to support this).
        ncclIntruQueueDequeue(&peerTable[sendRank].sendQueue);
        ncclIntruQueueDequeue(&peerTable[recvRank].recvQueue);
        ncclMemoryPoolFree(&comm->memPool_ncclTaskP2p, send);
        ncclMemoryPoolFree(&comm->memPool_ncclTaskP2p, recv);
        comm->planner.nTasksP2p -= 2;
        continue;
      }

      // Ensure room for worst case of one new batch per channel.
      if (!testBudget(budget, plan->nWorkBatches + perOpMax, plan->workBytes + sizeof(struct ncclDevWorkP2p))) {
        return ncclSuccess;
      }

      // A missing side is passed on as a no-op: -1 bytes and a null buffer.
      ssize_t sendBytes = -1;
      void* sendBuff = nullptr;
      if (send != nullptr) {
        sendBytes = send->bytes;
        sendBuff = send->buff;
      }
      ssize_t recvBytes = -1;
      void* recvBuff = nullptr;
      if (recv != nullptr) {
        recvBytes = recv->bytes;
        recvBuff = recv->buff;
      }

      struct ncclTaskP2p* p2pTasks[2] = { recv, send };
      NCCLCHECK(addP2pToPlan(comm, plan, perOpMin, perOpMax, round, sendRank, sendBuff, sendBytes, recvRank, recvBuff, recvBytes, send ? send->opCount : 0, recv ? recv->opCount : 0, p2pTasks));

      // Hand the scheduled tasks over from the planner to the plan.
      if (send != nullptr) {
        ncclIntruQueueDequeue(&peerTable[sendRank].sendQueue);
        ncclIntruQueueEnqueue(&plan->p2pTaskQueue, send);
        comm->planner.nTasksP2p -= 1;
      }
      if (recv != nullptr) {
        ncclIntruQueueDequeue(&peerTable[recvRank].recvQueue);
        ncclIntruQueueEnqueue(&plan->p2pTaskQueue, recv);
        comm->planner.nTasksP2p -= 1;
      }
    }
  }
  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
// Spin until its safe to increase comm->workFifoProduced to desiredProduced.
static void waitWorkFifoAvailable ( struct ncclComm * comm , uint32_t desiredProduced ) {
bool hasRoom = ( desiredProduced - comm -> workFifoConsumedLeast ) <= comm -> workFifoBytes ;
if ( hasRoom ) return ;
while ( true ) {
// We have to poll for notifications from device.
uint32_t * consumedLive = comm -> workFifoConsumed ;
uint32_t consumed [ MAXCHANNELS ];
for ( int c = 0 ; c < MAXCHANNELS ; c ++ ) {
consumed [ c ] = __atomic_load_n ( & consumedLive [ c ], __ATOMIC_RELAXED );
}
// Compiler-only fence to prevent fusion of loops to encourage dense loads.
__atomic_signal_fence ( __ATOMIC_SEQ_CST );
uint32_t produced = comm -> workFifoProduced ;
uint32_t consumedLeast = produced ;
for ( int c = 0 ; c < MAXCHANNELS ; c ++ ) {
// consumedLeast is min over all non-quiesced channels
if ( consumed [ c ] != comm -> channels [ c ]. workFifoProduced ) {
if (( produced - consumedLeast ) < ( produced - consumed [ c ])) {
consumedLeast = consumed [ c ];
}
2022-05-24 02:02:31 -07:00
}
2024-06-11 01:28:01 -07:00
}
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
// Compiler only fence to prevent fusion of loops to encourage dense stores.
__atomic_signal_fence ( __ATOMIC_SEQ_CST );
2022-05-24 02:02:31 -07:00
2024-06-11 01:28:01 -07:00
for ( int c = 0 ; c < MAXCHANNELS ; c ++ ) {
// Advance counter on quiesced channels so they don't lag behind
// too far where they could get lost in 32-bit wraparound.
if ( consumed [ c ] == comm -> channels [ c ]. workFifoProduced ) {
comm -> channels [ c ]. workFifoProduced = consumedLeast ;
__atomic_store_n ( & consumedLive [ c ], consumedLeast , __ATOMIC_RELAXED );
2022-05-24 02:02:31 -07:00
}
}
2024-06-11 01:28:01 -07:00
comm -> workFifoConsumedLeast = consumedLeast ;
hasRoom = ( desiredProduced - comm -> workFifoConsumedLeast ) <= comm -> workFifoBytes ;
if ( hasRoom ) break ;
sched_yield ();
2018-12-13 15:56:12 -08:00
}
2022-05-24 02:02:31 -07:00
}
2024-09-10 05:57:10 -07:00
namespace {
  // Event callback record queued by uploadWork() for persistent plans. `base`
  // must stay the first member: the record is passed around as a
  // ncclCommEventCallback* and cast back in uploadWork_cleanup_fn().
  struct uploadWork_cleanup_t {
    struct ncclCommEventCallback base;
    void *hostBuf; // host staging buffer to release once the upload completed
  };

  // Releases everything owned by an uploadWork_cleanup_t: the host staging
  // buffer, the callback record itself and the completion event.
  // Returns ncclSuccess, or the error produced by CUDACHECK when destroying
  // the event fails (host memory is released in either case).
  ncclResult_t uploadWork_cleanup_fn(
      struct ncclComm* comm, struct ncclCommEventCallback* cb
    ) {
    struct uploadWork_cleanup_t* me = (struct uploadWork_cleanup_t*)cb;
    cudaEvent_t event = me->base.event;
    free(me->hostBuf);
    // The record was heap-allocated by uploadWork() and is not referenced after
    // this callback runs; free it before CUDACHECK can return early.
    free(me);
    CUDACHECK(cudaEventDestroy(event));
    return ncclSuccess;
  }
}
2022-05-24 02:02:31 -07:00
// Copies the plan's work structs into the storage selected by
// plan->workStorageType and rebases the batches' work offsets accordingly:
//   ncclDevWorkStorageTypeArgs:       work follows the kernel args and batches.
//   ncclDevWorkStorageTypeFifo:       work is written into the comm's work fifo
//                                     (waiting for room first).
//   ncclDevWorkStorageTypePersistent: work is staged in a host buffer and copied
//                                     asynchronously into a dedicated device
//                                     buffer; uploadWork_cleanup_fn is queued to
//                                     release the staging buffer and event.
//
// Returns ncclSuccess, ncclInternalError for an unknown storage type,
// ncclSystemError when the host staging buffer cannot be allocated, or the
// first error reported by a CUDA/NCCL call on the persistent path.
static ncclResult_t uploadWork(struct ncclComm* comm, struct ncclKernelPlan* plan) {
  size_t workBytes = plan->workBytes;
  size_t batchBytes = plan->nWorkBatches*sizeof(struct ncclDevWorkBatch);
  void* fifoBufHost;
  uint32_t fifoCursor, fifoMask;

  switch (plan->workStorageType) {
  case ncclDevWorkStorageTypeArgs:
    plan->kernelArgs->workBuf = nullptr;
    fifoBufHost = (void*)plan->kernelArgs;
    fifoCursor = sizeof(ncclDevKernelArgs) + batchBytes;
    fifoMask = ~0u;
    break;
  case ncclDevWorkStorageTypeFifo:
    fifoBufHost = comm->workFifoBuf;
    fifoCursor = comm->workFifoProduced;
    fifoMask = comm->workFifoBytes-1;
    waitWorkFifoAvailable(comm, fifoCursor + workBytes);
    plan->kernelArgs->workBuf = comm->workFifoBufDev;
    break;
  case ncclDevWorkStorageTypePersistent:
    // We rely on 16-byte alignment
    #if __cplusplus >= 201103L
    fifoBufHost = aligned_alloc(16, ROUNDUP(workBytes, 16));
    #else
    static_assert(16 <= alignof(max_align_t), "We rely on 16-byte alignment.");
    fifoBufHost = malloc(workBytes);
    #endif
    // The copy loop below writes through this pointer; fail cleanly instead of
    // dereferencing NULL. A NULL result for a zero-byte request is tolerated
    // because nothing is written in that case.
    if (fifoBufHost == nullptr && workBytes != 0) {
      WARN("%s: failed to allocate %zu bytes of host memory for persistent work", __func__, workBytes);
      return ncclSystemError;
    }
    fifoCursor = 0;
    fifoMask = ~0u;
    break;
  default:
    return ncclInternalError;
  }
  plan->kernelArgs->workMask = fifoMask;

  // Batches were placed after kernelArgs by finishPlan(). Only thing left to
  // do is translate the work offset from zero based (in plan) to:
  //  ncclDevWorkStorageTypeArgs: offset from beginning of kernel args
  //  ncclDevWorkStorageTypeFifo: offset from base of fifo
  //  ncclDevWorkStorageTypePersistent: no translation since our dedicated buffer will also begin at zero.
  struct ncclDevWorkBatch* batchZero = (struct ncclDevWorkBatch*)(plan->kernelArgs+1);
  for (int b = 0; b < plan->nWorkBatches; b++) {
    batchZero[b].offsetBase += fifoCursor;
  }

  // Write the channel-shared work structs, 16 bytes at a time; the mask wraps
  // the cursor for the fifo and is all-ones for the linear buffers.
  struct ncclWorkList* workNode = ncclIntruQueueHead(&plan->workQueue);
  while (workNode != nullptr) {
    char* dst = (char*)fifoBufHost;
    char* src = (char*)(workNode+1);
    for (int n = workNode->size; n != 0; n -= 16) {
      memcpy(
        __builtin_assume_aligned(dst + (fifoCursor & fifoMask), 16),
        __builtin_assume_aligned(src, 16),
        16
      );
      fifoCursor += 16;
      src += 16;
    }
    workNode = workNode->next;
  }

  switch (plan->workStorageType) {
  case ncclDevWorkStorageTypeFifo:
    comm->workFifoProduced = fifoCursor;
    if (comm->workFifoBufGdrHandle != nullptr) wc_store_fence();
    break;
  case ncclDevWorkStorageTypePersistent:
    { ncclResult_t result = ncclSuccess;
      struct uploadWork_cleanup_t* cleanup = nullptr;
      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
      void* fifoBufDev = nullptr;
      CUDACHECKGOTO(cudaThreadExchangeStreamCaptureMode(&mode), result, fail);

      // Acquire deviceStream to gain access to deviceStream.cudaStream. Since the
      // user's graph will be launched later, and it also acquires the deviceStream,
      // it will observe this upload.
      NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->deviceStream), result, fail);

      CUDACHECKGOTO(cudaMallocAsync(&fifoBufDev, workBytes, comm->memPool, comm->sharedRes->deviceStream.cudaStream), result, fail);
      plan->workBufPersistent = fifoBufDev;
      plan->kernelArgs->workBuf = fifoBufDev;

      // coverity[uninit_use_in_call:FALSE] => fifoBufHost is never NULL
      CUDACHECKGOTO(cudaMemcpyAsync(fifoBufDev, fifoBufHost, workBytes, cudaMemcpyDefault, comm->sharedRes->deviceStream.cudaStream), result, fail);
      cudaEvent_t memcpyDone;
      CUDACHECKGOTO(cudaEventCreateWithFlags(&memcpyDone, cudaEventDisableTiming), result, fail);
      CUDACHECKGOTO(cudaEventRecord(memcpyDone, comm->sharedRes->deviceStream.cudaStream), result, fail);

      // From here on the queued callback owns the host staging buffer.
      NCCLCHECKGOTO(ncclCalloc(&cleanup, 1), result, fail);
      cleanup->base.fn = uploadWork_cleanup_fn;
      cleanup->base.event = memcpyDone;
      cleanup->hostBuf = fifoBufHost;
      ncclIntruQueueEnqueue(&comm->eventCallbackQueue, (struct ncclCommEventCallback *)cleanup);

      NCCLCHECKGOTO(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->deviceStream), result, fail);
      NCCLCHECKGOTO(ncclCommPollEventCallbacks(comm), result, fail);

    finish_scope:
      // Restore the caller's stream capture mode if the exchange changed it.
      if (mode != cudaStreamCaptureModeRelaxed) (void)cudaThreadExchangeStreamCaptureMode(&mode);
      return result;
    fail:
      // Without a queued callback nobody else will release the staging buffer.
      if (!cleanup) free(fifoBufHost);
      goto finish_scope;
    } break;
  default: break;
  }
  return ncclSuccess;
}
2022-05-24 02:02:31 -07:00
// Hands every proxy op queued on `plan` to ncclProxySaveOp().
//
// Op counts stored in the plan are zero-based (ignoring the bottom tag bit,
// which marks p2p ops). Each op's count is temporarily translated to the tip
// of the communicator's history for the save call and is then put back, so
// the same op can be translated again by the next uploadProxyOps() call.
//
// Side effects: advances comm->sharedRes->collOpCount and comm->collOpCount by
// plan->collOpCount, and (on success) each channel's p2pOpCount by the number
// of p2p ops the plan placed on that channel.
//
// Returns ncclSuccess, or the first error reported by ncclProxySaveOp().
static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan* plan) {
  uint64_t collOpCount = comm->sharedRes->collOpCount;
  uint64_t p2pOpBump[MAXCHANNELS] = {/*0...*/};
  // Advance comm's collOpCount by number of colls in this plan.
  comm->sharedRes->collOpCount += plan->collOpCount;
  comm->collOpCount += plan->collOpCount;

  struct ncclProxyOp* op = ncclIntruQueueHead(&plan->proxyOpQueue);
  while (op != nullptr) {
    // Attach profiler state to the op.
    // NOTE(review): ops with coll <= ncclFuncAllReduce read the coll task, all
    // others read the p2p task; this relies on the enum ordering - confirm.
    op->profilerContext = comm->profilerContext;
    op->eActivationMask = op->coll <= ncclFuncAllReduce ? op->task.coll->eActivationMask : op->task.p2p->eActivationMask;
    op->taskEventHandle = op->coll <= ncclFuncAllReduce ? op->task.coll->eventHandle : op->task.p2p->eventHandle;
    ncclProfilerAddPidToProxyOp(op);

    uint64_t oldId = op->opCount;
    // Ignoring the bottom tag bit, opCount's are zero-based within plan so
    // translate them to the tip of the comm's history.
    if (oldId & 1) { // p2p
      // opCount is monotonic increasing within a plan's channel so just
      // remember last value to compute max.
      p2pOpBump[op->channelId] = (oldId>>1) + 1; // +1 to ensure next plan doesn't collide
      op->opCount = (comm->sharedRes->p2pOpCount[op->channelId]<<1) + oldId;
    } else { // coll
      op->opCount = (collOpCount<<1) + oldId;
    }

    ncclResult_t saveResult = ncclProxySaveOp(comm, op, nullptr);
    // Restore for next uploadProxyOps(). This must happen before the error
    // check: otherwise a failed save would leave the translated count in the
    // op and the next upload of this plan would translate it a second time.
    op->opCount = oldId;
    NCCLCHECK(saveResult);

    op = op->enqNext;
  }

  for (int c=0; c < MAXCHANNELS; c++) {
    // Advance channel's p2pOpCount by number of p2p's in this plan channel.
    comm->sharedRes->p2pOpCount[c] += p2pOpBump[c];
  }
  return ncclSuccess;
}
2022-05-24 02:02:31 -07:00
// Host-side work for one plan: opens the profiler group/task events, uploads
// the plan's proxy ops, starts the proxy, then closes the profiler events.
// Any failing step returns its error immediately. A non-persistent plan is
// finally handed to the communicator's callback queue for reclamation.
static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelPlan* plan) {
  NCCLCHECK(ncclProfilerStartGroupEvent(plan));
  NCCLCHECK(ncclProfilerStartTaskEvents(plan));

  NCCLCHECK(uploadProxyOps(comm, plan));
  NCCLCHECK(ncclProxyStart(comm));

  NCCLCHECK(ncclProfilerStopTaskEvents(plan));
  NCCLCHECK(ncclProfilerStopGroupEvent(plan));

  // Persistent plans are not reclaimed from here.
  if (plan->persistent) return ncclSuccess;

  // Notify main thread of our reclaiming. This will reclaim plan concurrently.
  ncclIntruQueueMpscEnqueue(&comm->callbackQueue, &plan->reclaimer);
  return ncclSuccess;
}
2021-04-12 16:00:11 -07:00
2022-09-09 01:20:52 +00:00
// Host callback that runs hostStreamPlanTask() for the plan passed in `plan_`
// (a `struct ncclKernelPlan*`). Failures are only logged, since the callback
// has no way to return a status. For a non-persistent plan the
// communicator's noncapturedRefs count is decremented once the task is done.
static void HIPRT_CB hostStreamPlanCallback(void* plan_) {
  NVTX3_FUNC_RANGE_IN(nccl_domain);
  struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_;
  // Snapshot everything needed after the task. For non-persistent plans
  // hostStreamPlanTask() enqueues the plan's reclaimer, after which the plan
  // may be reclaimed concurrently, so `plan` must not be read again.
  struct ncclComm* comm = plan->comm;
  bool const persistent = plan->persistent;

  ncclResult_t result = hostStreamPlanTask(comm, plan);
  if (result != ncclSuccess) {
    WARN("hostStreamPlanCallback() failed : %s", ncclGetErrorString(result));
  }
  if (!persistent) ncclAtomicRefCountDecrement(&comm->noncapturedRefs);
}
// Callback that releases everything owned by a kernel plan.
//
// `me` is the plan's `reclaimer` member and is cast back to the plan. For a
// persistent plan the communicator's persistentRefs count is dropped and the
// persistent work buffer is freed. Then every proxy op goes back to its memory
// pool, the plan's cleanup callbacks are run, and the plan struct itself goes
// back to its memory pool.
//
// Reclamation always runs to completion, even when a step fails, so that a
// failing free or cleanup callback does not leak the remaining resources.
// Returns the work-buffer free error if there was one, otherwise the error of
// the last failing cleanup callback, otherwise ncclSuccess.
static ncclResult_t reclaimPlan(struct ncclComm* comm, struct ncclCommCallback* me) {
  struct ncclKernelPlan* plan = (struct ncclKernelPlan*)me; // cast from first member `reclaim`

  ncclResult_t freeResult = ncclSuccess;
  if (plan->persistent) {
    comm->persistentRefs -= 1;
    // Remember a failure but keep going: the host-side resources below still
    // have to be returned.
    freeResult = ncclCudaFree(plan->workBufPersistent);
  }

  // Free proxy ops
  struct ncclProxyOp* q = ncclIntruQueueHead(&plan->proxyOpQueue);
  while (q != nullptr) {
    struct ncclProxyOp* q1 = q->enqNext; // read the link before q is released
    if (q->ringAlgo && q->ringAlgo->decRefCount() == 0) delete q->ringAlgo;
    ncclMemoryPoolFree(&comm->memPool_ncclProxyOp, q);
    q = q1;
  }

  // Run other free callbacks
  ncclResult_t result = ncclSuccess;
  while (!ncclIntruQueueEmpty(&plan->cleanupQueue)) {
    struct ncclCommCallback* cb = ncclIntruQueueDequeue(&plan->cleanupQueue);
    ncclResult_t res1 = cb->fn(comm, cb); // Expect to reclaim memory of cb
    if (res1 != ncclSuccess) result = res1;
  }

  // Free plan struct. Done before reporting errors so the plan is never leaked.
  ncclMemoryPoolFree(&comm->memPool_ncclKernelPlan, plan);

  NCCLCHECK(freeResult);
  NCCLCHECK(result);
  return ncclSuccess;
}
2022-05-24 02:02:31 -07:00
// Destructor callback for a list of persistent plans. `plans_` is the head of
// a `struct ncclKernelPlan` list linked through `next`; every plan's reclaimer
// is pushed onto the callback queue of the head plan's communicator.
static void persistentDestructor(void* plans_) {
  struct ncclKernelPlan* head = (struct ncclKernelPlan*)plans_;
  struct ncclComm* owner = head->comm;
  struct ncclKernelPlan* following = nullptr;
  for (struct ncclKernelPlan* cur = head; cur != nullptr; cur = following) {
    // Read the link first: once enqueued the plan is no longer ours.
    following = cur->next;
    ncclIntruQueueMpscEnqueue(&owner->callbackQueue, &cur->reclaimer);
  }
}
2021-04-12 16:00:11 -07:00
2022-05-24 02:02:31 -07:00
// Turns the planner's pending collective and p2p tasks into kernel plans and
// sets up the stream dependencies needed before those plans are launched.
//
// Returns ncclSuccess when there is nothing to do or everything was set up,
// otherwise the first error encountered.
ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
  ncclResult_t result = ncclSuccess;
  struct ncclKernelPlanner* planner = &comm->planner;
  bool persistent = ncclCudaGraphValid(planner->capturingGraph);
  planner->persistent = persistent;
  int queuedPlans = 0;

  if (planner->nTasksColl + planner->nTasksP2p != 0) {
    // Keep cutting plans until every pending task has been scheduled.
    do {
      memset(&planner->wipPlan, 0, sizeof(planner->wipPlan));

      struct ncclKernelPlan* plan = ncclMemoryPoolAlloc<struct ncclKernelPlan>(&comm->memPool_ncclKernelPlan, &comm->memPermanent);
      plan->comm = comm;
      plan->reclaimer.fn = reclaimPlan;
      plan->persistent = persistent;
      // finishPlan() promotes ncclDevWorkStorageType[Fifo|Persistent]->Args if the work can fit.
      if (persistent) {
        plan->workStorageType = ncclDevWorkStorageTypePersistent;
      } else {
        plan->workStorageType = ncclDevWorkStorageTypeFifo;
      }

      struct ncclKernelPlanBudget budget;
      budget.inArgsBytes = comm->workArgsBytes - sizeof(struct ncclDevKernelArgs);
      // Non-persistent kernels fill up at most half of our fifo per kernel.
      if (persistent) {
        budget.outArgsBytes = (1<<30);
      } else {
        budget.outArgsBytes = comm->workFifoBytes/2;
      }

      // Drain coll tasks first. This is essential since we partition tasks based
      // on the work budget and p2p work isn't collective. If we were to drain p2p
      // first, the place where we cut the kernel could vary by rank which would
      // cause the "shortest channel first" channel picker to have divergent results.
      if (planner->nTasksColl != 0) {
        NCCLCHECKGOTO(scheduleCollTasksToPlan(comm, plan, &budget), result, failure);
      }
      // And only drain p2p tasks once colls are depleted.
      if (planner->nTasksColl == 0 && planner->nTasksP2p != 0) {
        NCCLCHECKGOTO(scheduleP2pTasksToPlan(comm, plan, &budget), result, failure);
      }

      finishPlan(comm, plan);
      // Only plans that actually carry work are queued for launch.
      if (plan->workBytes != 0) {
        ncclIntruQueueEnqueue(&planner->planQueue, plan);
        queuedPlans += 1;
      }
    } while (planner->nTasksColl + planner->nTasksP2p != 0);

    struct ncclKernelPlan* firstPlan = ncclIntruQueueHead(&planner->planQueue);
    planner->unlaunchedPlansHead = firstPlan;

    if (queuedPlans == 0) return ncclSuccess;

    // Semantically we want these dependencies for the kernels launched:
    //   1. Launch host task on hostStream.
    //   2. Launch kernel, depends on all of {deviceStream, hostStream, userStream[i]...}
    //   3. {deviceStream, userStream[i]...} depend on kernel.
    // We achieve this by:
    //   1. userStream[0] waits on deviceStream
    //   2. deviceStream waits on each of userStream[1...]
    //   3. host task launch on hostStream
    //   4. userStream[0] waits on hostStream
    //   5. kernel launch on userStream[0]
    //   6. deviceStream waits on userStream[0]
    //   7. userStream[1...] each waits on deviceStream
    // The two-level fan-in fan-out is because ncclStrongStreamWaitStream() requires
    // at least one of the two streams to be strong-stream.
    cudaStream_t launchStream = planner->streams->stream;
    NCCLCHECKGOTO(ncclStrongStreamAcquire(planner->capturingGraph, &comm->sharedRes->deviceStream), result, failure);

    if (planner->numStreams != 1 || persistent) {
      // Create dependency for device stream on user streams. First from extra user
      // streams to deviceStream. Then deviceStream to first user stream.
      for (struct ncclCudaStreamList* other = planner->streams->next; other != nullptr; other = other->next) {
        NCCLCHECKGOTO(ncclStrongStreamWaitStream(planner->capturingGraph, &comm->sharedRes->deviceStream, other->stream), result, failure);
      }
      NCCLCHECKGOTO(ncclStrongStreamWaitStream(planner->capturingGraph, launchStream, &comm->sharedRes->deviceStream), result, failure);
    } else if (launchStream != comm->lastStream && comm->lastStream != nullptr && !persistent) {
      // Stream changed from last call, create dependency against last NCCL kernel launch
      CUDACHECK(hipStreamWaitEvent(launchStream, comm->doneEvent, 0));
    }

    if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking || __atomic_load_n(&comm->noncapturedRefs, __ATOMIC_ACQUIRE)) {
      // We have to launch host tasks to push proxy args. We are careful to only
      // do this if necessary since host tasks impose a high performance cost in CUDA.
      bool hostStreamHeld = false;
      for (struct ncclKernelPlan* p = firstPlan; p != nullptr; p = p->next) {
        if (!p->hasProxyOps) continue;
        // Acquire the host stream lazily, on the first plan that needs it.
        if (!hostStreamHeld) {
          hostStreamHeld = true;
          NCCLCHECKGOTO(ncclStrongStreamAcquire(planner->capturingGraph, &comm->sharedRes->hostStream), result, failure);
        }
        if (!persistent) ncclAtomicRefCountIncrement(&comm->noncapturedRefs);
        p->isHostCbEnq = true;
        NCCLCHECKGOTO(ncclStrongStreamLaunchHost(planner->capturingGraph, &comm->sharedRes->hostStream, hostStreamPlanCallback, p), result, failure);
      }
      if (hostStreamHeld) {
        // Make to-be-launched kernels dependent on just-launched host stream tasks.
        NCCLCHECKGOTO(ncclStrongStreamWaitStream(planner->capturingGraph, launchStream, &comm->sharedRes->hostStream), result, failure);
        NCCLCHECKGOTO(ncclStrongStreamRelease(planner->capturingGraph, &comm->sharedRes->hostStream), result, failure);
      }
    }

    if (persistent) {
      // The graph keeps the plans alive; they are reclaimed by its destructor.
      comm->persistentRefs += queuedPlans;
      NCCLCHECKGOTO(ncclCudaGraphAddDestructor(planner->capturingGraph, persistentDestructor, (void*)firstPlan), result, failure);
    }
  }

failure:
  return result;
}
2021-04-12 16:00:11 -07:00
2022-05-24 02:02:31 -07:00
// Pre-launch hook: uploads the plan's work via uploadWork().
//
// This code is called after we've checked in to the intra-process barrier
// but before launching the kernel. We are not allowed to call CUDA unless the
// kernel launch is captured.
ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, struct ncclKernelPlan* plan) {
  NCCLCHECK(uploadWork(comm, plan));

  return ncclSuccess;
}
2021-04-12 16:00:11 -07:00
2022-11-29 04:27:46 -08:00
#if CUDART_VERSION >= 12000
// NCCL uses the "Remote" Mem Sync domain by default
NCCL_PARAM ( MemSyncDomain , "MEM_SYNC_DOMAIN" , cudaLaunchMemSyncDomainRemote );
2022-09-27 02:31:13 -07:00
#endif
2022-05-24 02:02:31 -07:00
ncclResult_t ncclLaunchKernel ( struct ncclComm * comm , struct ncclKernelPlan * plan ) {
2024-06-11 01:28:01 -07:00
struct ncclKernelPlanner * planner = & comm -> planner ;
2025-01-23 11:48:18 -06:00
int nChannels = 0 ;
for ( int i = 0 ; i < MAXCHANNELS / 64 ; i ++ )
nChannels += countOneBits ( plan -> channelMask . masks [ i ]);
2024-06-11 01:28:01 -07:00
void * sym = plan -> kernelFn ;
dim3 grid = {( unsigned ) nChannels , 1 , 1 };
2022-05-24 02:02:31 -07:00
dim3 block = {( unsigned ) plan -> threadPerBlock , 1 , 1 };
2024-06-11 01:28:01 -07:00
int smem = ncclShmemDynamicSize ( comm -> cudaArch );
cudaStream_t launchStream = planner -> streams -> stream ;
2025-01-23 11:48:18 -06:00
void * extra [] = { plan -> kernelArgs , & plan -> kernelArgsSize };
if ( planner -> numStreams == 1 && ! plan -> persistent ) {
CUDACHECK ( hipExtLaunchKernel ( plan -> kernelFn , grid , block , extra , 0 , launchStream , NULL , comm -> doneEvent , 0 ));
comm -> lastStream = planner -> streams -> stream ;
2023-02-04 01:43:38 +00:00
return ncclSuccess ;
}
2022-09-27 02:31:13 -07:00
2025-01-23 11:48:18 -06:00
// CUfunction fn;
// CUDACHECK(cudaGetFuncBySymbol(&fn, sym));
2022-09-27 02:31:13 -07:00
#if CUDART_VERSION >= 11080
int driverVersion ;
NCCLCHECK ( ncclCudaDriverVersion ( & driverVersion ));
2022-11-29 04:27:46 -08:00
if ( driverVersion >= 11080 ) {
int compCap = comm -> compCap ;
2023-04-03 05:32:07 -07:00
unsigned int clusterSize = ( compCap == 90 ) ? comm -> config . cgaClusterSize : 0 ;
2022-09-27 02:31:13 -07:00
2024-06-11 01:28:01 -07:00
CUlaunchConfig launchConfig = { 0 };
CUlaunchAttribute launchAttrs [ 3 ];
2022-11-29 04:27:46 -08:00
int attrs = 0 ;
2022-09-27 02:31:13 -07:00
/* Cooperative Group Array (CGA)
* On sm90 and later we have an extra level of hierarchy where we
* can group together several blocks within the Grid, called
* Thread Block Clusters.
* Clusters enable multiple thread blocks running concurrently
* across multiple SMs to synchronize and collaboratively fetch
* and exchange data. A cluster of blocks are guaranteed to be
* concurrently scheduled onto a group of SMs.
* The maximum value is 8 and it must be divisible into the grid dimensions
*/
2022-11-29 04:27:46 -08:00
if ( clusterSize ) {
2022-11-03 17:42:38 +00:00
// Grid dimension must be divisible by clusterSize
if ( grid . x % clusterSize ) clusterSize = 1 ;
2024-06-11 01:28:01 -07:00
launchAttrs [ attrs ]. id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION ;
launchAttrs [ attrs ++ ]. value . clusterDim = { clusterSize , 1 , 1 };
launchAttrs [ attrs ]. id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE ;
launchAttrs [ attrs ++ ]. value . clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD ;
2022-11-29 04:27:46 -08:00
}
#if CUDART_VERSION >= 12000
if ( compCap >= 90 && driverVersion >= 12000 ) {
// Set the NCCL Mem Sync domain on CUDA 12.0 and later (sm90)
2024-06-11 01:28:01 -07:00
launchAttrs [ attrs ]. id = CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN ;
launchAttrs [ attrs ++ ]. value . memSyncDomain = ( CUlaunchMemSyncDomain ) ncclParamMemSyncDomain ();
2022-11-03 17:42:38 +00:00
}
#endif
2024-06-11 01:28:01 -07:00
launchConfig . gridDimX = grid . x ;
launchConfig . gridDimY = grid . y ;
launchConfig . gridDimZ = grid . z ;
launchConfig . blockDimX = block . x ;
launchConfig . blockDimY = block . y ;
launchConfig . blockDimZ = block . z ;
launchConfig . sharedMemBytes = smem ;
2022-09-27 02:31:13 -07:00
launchConfig . attrs = launchAttrs ;
2022-11-29 04:27:46 -08:00
launchConfig . numAttrs = attrs ;
2024-06-11 01:28:01 -07:00
launchConfig . hStream = launchStream ;
2022-09-27 02:31:13 -07:00
2024-06-11 01:28:01 -07:00
//CUDACHECK(cudaLaunchKernelExC(&launchConfig, fnAddr, args));
CUCHECK ( cuLaunchKernelEx ( & launchConfig , fn , nullptr , extra ));
2022-09-27 02:31:13 -07:00
return ncclSuccess ;
2022-09-09 01:20:52 +00:00
}
2022-09-27 02:31:13 -07:00
#endif
// Standard kernel launch
2025-01-23 11:48:18 -06:00
//cuLaunchKernel(sym, grid.x, grid.y, grid.z, block.x, block.y, block.z, smem, launchStream, nullptr, extra);
CUDACHECK ( cudaLaunchKernel ( sym , grid , block , extra , smem , launchStream ));
2022-05-24 02:02:31 -07:00
return ncclSuccess ;
}
2021-04-12 16:00:11 -07:00
2022-05-24 02:02:31 -07:00
// Post-launch hook: makes sure the plan's proxy ops are pushed and that a
// non-persistent plan is eventually reclaimed.
ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKernelPlan* plan) {
  bool const viaHostStream = plan->persistent || ncclCudaLaunchBlocking || plan->isHostCbEnq;

  if (!viaHostStream) {
    // We are not using the host stream for proxy ops and reclamation
    // submission, so run the host task directly.
    NCCLCHECK(hostStreamPlanTask(comm, plan));
    return ncclSuccess;
  }

  // We are using the host stream for proxy ops and reclamation submission.
  // Only plans with proxy ops have a callback pushed by ncclLaunchPrepare.
  // Since non-persistent plans also require reclamation, we have to do it
  // here.
  if (!plan->persistent && !plan->hasProxyOps) {
    ncclIntruQueueMpscEnqueue(&comm->callbackQueue, &plan->reclaimer);
  }
  return ncclSuccess;
}
2022-05-24 02:02:31 -07:00
// Completes a launch round: empties the planner's plan queue, re-establishes
// the stream dependencies that follow the kernel launch, resets the planner's
// stream list and releases the device stream.
//
// Every step is attempted even if an earlier one failed (each check simply
// jumps to the statement right after it). The returned value is the status of
// the most recently executed step, or ncclSuccess when no plan was queued.
ncclResult_t ncclLaunchFinish(struct ncclComm* comm) {
  ncclResult_t result = ncclSuccess;
  struct ncclKernelPlanner* planner = &comm->planner;
  bool persistent = ncclCudaGraphValid(planner->capturingGraph);

  if (ncclIntruQueueEmpty(&planner->planQueue)) return result;

  // Reset queue to empty without destroying plans since those will be sent
  // back to us for reclaiming via callbackQueue.
  ncclIntruQueueConstruct(&planner->planQueue);

  cudaStream_t launchStream = planner->streams->stream; // First user stream gets launch
  // Stream fan-out is only needed for captured launches or multiple user streams.
  bool const fanOut = persistent || planner->numStreams != 1;

  // Create dependency for deviceStream on launchStream. We know that deviceStream
  // hasn't been modified since launchStream waited on it (in ncclLaunchPrepare),
  // so we can say that launchStream subsumes it.
  if (fanOut) {
    NCCLCHECKGOTO(ncclStrongStreamWaitStream(planner->capturingGraph, &comm->sharedRes->deviceStream, launchStream, /*b_subsumes_a=*/true), result, afterDeviceWait);
  }
afterDeviceWait:
  // Create dependency for other user streams (skip launch stream) on deviceStream.
  // Again, the user streams haven't been touched since deviceStream waited on them
  // so we can say they are subsumed by deviceStream.
  struct ncclCudaStreamList* other = planner->streams->next;
  planner->streams = nullptr; // Reset comm->planner.streams to empty.
  while (other != nullptr && fanOut) {
    NCCLCHECKGOTO(ncclStrongStreamWaitStream(planner->capturingGraph, other->stream, &comm->sharedRes->deviceStream, /*b_subsumes_a=*/true), result, nextUserStream);
  nextUserStream:
    other = other->next;
  }
  planner->numStreams = 0;

  // Release device stream as acquired in ncclLaunchPrepare()
  NCCLCHECKGOTO(ncclStrongStreamRelease(planner->capturingGraph, &comm->sharedRes->deviceStream), result, released);
released:
  return result;
}
2018-12-13 15:56:12 -08:00
/*****************************************************************************/
/* Enqueueing system : computation of kernel and proxy operations parameters */
/*****************************************************************************/
2024-06-11 01:28:01 -07:00
// Computes whether CollNet can be used for the task `info`.
//
// *collNetSupport starts as comm->collNetSupport. For AllReduce, Reduce and
// ReduceScatter it is additionally masked with the communicator's support
// matrix entry for the task's reduction op and datatype. Always returns
// ncclSuccess.
static inline ncclResult_t getCollNetSupport(
    struct ncclComm* comm, struct ncclTaskColl* info, int* collNetSupport
  ) {
  // Translate ncclAvg and PreMulSum: the network only has to do a plain sum.
  bool const scaledSum = (info->opDev.op == ncclDevPreMulSum) || (info->opDev.op == ncclDevSumPostDiv);
  ncclRedOp_t netOp = scaledSum ? ncclSum : info->opHost;

  *collNetSupport = comm->collNetSupport;

  bool const reduces = (info->func == ncclFuncAllReduce) ||
                       (info->func == ncclFuncReduce) ||
                       (info->func == ncclFuncReduceScatter);
  if (reduces) {
    *collNetSupport &= comm->collNetSupportMatrix[netOp][info->datatype];
  }
  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
// Marks every (algorithm, protocol) entry of the cost table as ignored.
// `collCostTable` is really the address of a contiguous
// float[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] array passed as float**.
static void initCollCostTable(float** collCostTable) {
  float* cells = (float*)collCostTable;
  int const nCells = NCCL_NUM_ALGORITHMS * NCCL_NUM_PROTOCOLS;
  for (int i = 0; i < nCells; i++) {
    cells[i] = NCCL_ALGO_PROTO_IGNORE;
  }
}
2022-01-07 06:39:55 -08:00
// numPipeOps: number of pipelined ops. Can be greater than 1 in aggregation mode. Used to adjust latency.
2024-06-11 01:28:01 -07:00
static ncclResult_t updateCollCostTable (
struct ncclComm * comm , struct ncclTaskColl * info , size_t nBytes ,
int collNetSupport , int nvlsSupport , int numPipeOps ,
2024-12-18 08:26:06 -08:00
float ** collCostTable ) {
2024-06-11 01:28:01 -07:00
float ( * table )[ NCCL_NUM_PROTOCOLS ] = ( float ( * )[ NCCL_NUM_PROTOCOLS ]) collCostTable ;
2025-01-23 11:48:18 -06:00
if ( comm -> nRanks == 1 || info -> func == ncclFuncAllToAllPivot ) {
2024-06-11 01:28:01 -07:00
table [ NCCL_ALGO_RING ][ NCCL_PROTO_SIMPLE ] = 0.0 ;
return ncclSuccess ;
2021-09-08 13:56:25 -07:00
}
2023-09-26 05:47:28 -07:00
2024-06-11 01:28:01 -07:00
for ( int a = 0 ; a < NCCL_NUM_ALGORITHMS ; a ++ ) {
if (( a == NCCL_ALGO_COLLNET_DIRECT || a == NCCL_ALGO_COLLNET_CHAIN ) && collNetSupport != 1 ) continue ;
2024-09-10 05:57:10 -07:00
// CollNetDirect is only supported for up to 8 local GPUs
if ( a == NCCL_ALGO_COLLNET_DIRECT && comm -> maxLocalRanks > NCCL_MAX_DIRECT_ARITY + 1 ) continue ;
2024-06-11 01:28:01 -07:00
if (( a == NCCL_ALGO_NVLS || a == NCCL_ALGO_NVLS_TREE ) && nvlsSupport != 1 && info -> func != ncclFuncAllGather ) continue ;
if ( a == NCCL_ALGO_NVLS && collNetSupport != 1 && comm -> nNodes > 1 ) continue ;
/* now we only support single-node NVLS allgather and reducescatter */
if ( a == NCCL_ALGO_NVLS && ( info -> func == ncclFuncAllGather || info -> func == ncclFuncReduceScatter ) && comm -> nNodes > 1 ) continue ;
2024-09-10 05:57:10 -07:00
/* Tree reduceScatter doesn't support scaling yet */
if ( a == NCCL_ALGO_PAT && info -> func == ncclFuncReduceScatter
&& ( info -> opDev . op == ncclDevPreMulSum || info -> opDev . op == ncclDevSumPostDiv )) continue ;
2024-06-11 01:28:01 -07:00
for ( int p = 0 ; p < NCCL_NUM_PROTOCOLS ; p ++ ) {
2025-03-16 15:10:05 -07:00
if ( p == NCCL_PROTO_LL128 && ! ( comm -> topo -> type & RCCL_TOPO_XGMI_ALL )) {
table [ a ][ p ] = NCCL_ALGO_PROTO_IGNORE ;
continue ;
}
2024-12-18 08:26:06 -08:00
NCCLCHECK ( ncclTopoGetAlgoTime ( comm , info -> func , a , p , nBytes , numPipeOps , & table [ a ][ p ]));
// Relegate fp8 reduction trees of sufficient depth that they incur precision loss
// to be least preferred.
if ( info -> datatype == ncclFloat8e4m3 || info -> datatype == ncclFloat8e5m2 ) {
if ( a == NCCL_ALGO_RING && comm -> nRanks > 8 ) {
table [ a ][ p ] *= 1024.0 ; // Any factor large enough to act as a partition between lossy and non-lossy algos.
2024-06-11 01:28:01 -07:00
}
2023-09-26 05:47:28 -07:00
}
2021-09-08 13:56:25 -07:00
}
2018-12-13 15:56:12 -08:00
}
2019-11-19 14:57:39 -08:00
return ncclSuccess ;
}
2024-06-11 01:28:01 -07:00
// Picks the cheapest (algorithm, protocol) pair from the cost table and then
// derives the channel and warp counts for the task.
//
// Writes info->algorithm, info->protocol, info->nMaxChannels and info->nWarps.
// If `simInfo` is non-null it receives the estimated time of the chosen pair.
// Returns ncclInvalidUsage when nothing is selectable and NCCL_ALGO or
// NCCL_PROTO is set in the environment, ncclInternalError when nothing is
// selectable otherwise, and ncclSuccess in all other cases.
static ncclResult_t topoGetAlgoInfo(
    struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes,
    float** collCostTable, ncclSimInfo_t* simInfo
  ) {
  float (*costs)[NCCL_NUM_PROTOCOLS] = (float (*)[NCCL_NUM_PROTOCOLS])collCostTable;

  // Scan the table for the smallest non-negative, non-ignored time.
  float bestTime = 3600000000.0;
  int bestAlgo = info->algorithm = NCCL_ALGO_UNDEF;
  int bestProto = info->protocol = NCCL_PROTO_UNDEF;
  for (int a = 0; a < NCCL_NUM_ALGORITHMS; a++) {
    for (int p = 0; p < NCCL_NUM_PROTOCOLS; p++) {
      float const candidate = costs[a][p];
      if (candidate == NCCL_ALGO_PROTO_IGNORE) continue;
      if (candidate >= 0.0 && candidate < bestTime) {
        bestAlgo = a;
        bestProto = p;
        bestTime = candidate;
      }
    }
  }

  info->algorithm = bestAlgo;
  info->protocol = bestProto;
  float const estimate = bestTime;

  // Yes, we are first assigning and then testing if protocol is sane, but that's OK in this case.
  // coverity[check_after_sink]
  if (info->algorithm == NCCL_ALGO_UNDEF || info->protocol == NCCL_PROTO_UNDEF) {
    // Mention the user's overrides in the message, since they are the likely cause.
    char ncclAlgoEnvStr[1024] = "";
    char ncclProtoEnvStr[1024] = "";
    char* algoEnv = getenv("NCCL_ALGO");
    if (algoEnv) {
      snprintf(ncclAlgoEnvStr, 1023, " NCCL_ALGO was set to %s.", algoEnv);
    }
    char* protoEnv = getenv("NCCL_PROTO");
    if (protoEnv) {
      snprintf(ncclProtoEnvStr, 1023, " NCCL_PROTO was set to %s.", protoEnv);
    }
    WARN("Error : no algorithm/protocol available for function %s with datatype %s.%s%s", ncclFuncToString(info->func), ncclDatatypeToString(info->datatype), ncclAlgoEnvStr, ncclProtoEnvStr);
    return (algoEnv || protoEnv) ? ncclInvalidUsage : ncclInternalError;
  }

  rcclUpdateCollectiveProtocol(comm, nBytes, info);
  if (simInfo) simInfo->estimatedTime = estimate;
  TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", nBytes, info->algorithm, info->protocol, estimate);

  int nc = comm->nChannels;
  int nt = comm->maxThreads[info->algorithm][info->protocol];
  int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol];

  if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) {
    // CollNet channel tuning
    int ncSwitch = 16;
    bool belowThreshold = true;
    while (ncSwitch >= 1 && belowThreshold) {
      while ((belowThreshold = nBytes < nc*nt*comm->channels[0].collnetDirect.nHeads*threadThreshold) && nc > ncSwitch) {
        if (nc == ncSwitch+ncSwitch/2) threadThreshold /= 2;
        nc--;
      }
      ncSwitch /= 2;
    }
  } else if (info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) {
    // NVLS should not need more than 16 channels to get peak BW.
    nc = comm->nvlsChannels;
  } else {
    // Ring/Tree channel tuning: shed channels (down to one) while the
    // message is too small to keep them all busy.
    while (nc >= 2 && nBytes < nc*nt*threadThreshold) {
      nc--;
    }
  }

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#else
  if (info->algorithm != NCCL_ALGO_NVLS && info->algorithm != NCCL_ALGO_NVLS_TREE &&
      info->algorithm != NCCL_ALGO_COLLNET_DIRECT) {
    // Halve the thread count while the message is still too small.
    while (nt % 128 == 0 && nBytes < nc*nt*threadThreshold) {
      nt /= 2;
    }
  }
  if (info->protocol == NCCL_PROTO_SIMPLE) {
    if (info->algorithm == NCCL_ALGO_RING) nt += WARP_SIZE; // Extra warp for sync
    // More threads or sync warps needed due to split thread model
    if (info->algorithm == NCCL_ALGO_TREE) nt += 4*WARP_SIZE;
  }
  if (nt/WARP_SIZE < 3) nt = 3*WARP_SIZE;
#endif

  // Every case below starts from the tuned channel count.
  info->nMaxChannels = nc;
  if (info->func == ncclFuncAllReduce && comm->topo->pivotA2ANumBiRings == 3) {
    // -2: environment not inspected yet, 0: no user tuning, 1: user tuning present.
    static int userTuneInput = -2;
    if (userTuneInput == -2) {
      const char* protoStr = getenv("NCCL_PROTO");
      const char* algoStr = getenv("NCCL_ALGO");
      userTuneInput = (protoStr || algoStr) ? 1 : 0;
    }
    // always respect user settings
    if (!userTuneInput) {
      if (nBytes <= 2200008) {
        info->protocol = NCCL_PROTO_LL;
        info->algorithm = NCCL_ALGO_TREE;
        info->nMaxChannels = std::min(24, comm->nChannels);
      } else {
        info->protocol = NCCL_PROTO_SIMPLE;
        info->algorithm = NCCL_ALGO_RING;
      }
    }
  } else if (info->func == ncclFuncAllReduce && comm->topo->treeDefined == 1) {
    info->algorithm = NCCL_ALGO_TREE;
  }

  if (info->algorithm == NCCL_ALGO_TREE) nt = NCCL_MAX_NTHREADS; // Tree now uses all threads always.
  if (info->algorithm == NCCL_ALGO_PAT) nt = NCCL_MAX_NTHREADS;
  info->nWarps = nt/WARP_SIZE;
  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
// Use the default topo-based tuner if tuner plugin is not successful.
// Call the plugin first. Let it set algo+proto, and/or nChannels.
// Then, topoGetAlgoInfo will set algo/proto if not set, then nChannels and nThreads based on algo/proto.
// Finally, nChannels will be overridden by the plugin setting.
2025-04-23 15:44:56 -04:00
// Chooses algorithm, protocol and channel cap for the collective described by `info`.
//
// Sequence:
//   1. Reset info->algorithm / info->protocol to the UNDEF markers.
//   2. Build the per-(algorithm, protocol) cost table.
//   3. If comm->tuner is set, let the plugin inspect/adjust the table and optionally
//      report a channel count.
//   4. topoGetAlgoInfo() finishes the selection from the table.
//   5. A non-zero channel count reported by the plugin replaces info->nMaxChannels.
//
// Any failing step is propagated through NCCLCHECK.
rccl_static ncclResult_t getAlgoInfo(
    struct ncclComm* comm, struct ncclTaskColl* info,
    int collNetSupport, int nvlsSupport, int numPipeOps, ncclSimInfo_t* simInfo/* = NULL*/
  ) {
  info->algorithm = NCCL_ALGO_UNDEF;
  info->protocol = NCCL_PROTO_UNDEF;

  size_t nBytes = ncclTypeSize(info->datatype)*ncclFuncMaxSendRecvCount(info->func, comm->nRanks, info->count);

  // The helpers take the 2-D table through a float** handle.
  float costTable[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS];
  float** costTablePtr = (float**)costTable;
  initCollCostTable(costTablePtr);
  NCCLCHECK(updateCollCostTable(comm, info, nBytes, collNetSupport, nvlsSupport, numPipeOps, costTablePtr));

  // Stays 0 unless the tuner plugin reports a channel count.
  int tunerMaxChannels = 0;
  if (comm->tuner != NULL) {
    size_t typeBytes = ncclTypeSize(info->datatype);
    size_t sendBytes = typeBytes*ncclFuncSendCount(info->func, comm->nRanks, info->count);
    size_t recvBytes = typeBytes*ncclFuncRecvCount(info->func, comm->nRanks, info->count);

    struct ncclReg* sendReg;
    struct ncclReg* recvReg;
    NCCLCHECK(ncclRegFind(comm, info->sendbuff, sendBytes, &sendReg));
    NCCLCHECK(ncclRegFind(comm, info->recvbuff, recvBytes, &recvReg));

    // Buffers count as registered when both lookups succeeded, or when a valid graph is
    // being captured and graph registration is enabled.
    bool bothRegistered = sendReg && recvReg;
    int regBuff = bothRegistered || (ncclCudaGraphValid(comm->planner.capturingGraph) && ncclParamGraphRegister());

    NCCLCHECK(comm->tuner->getCollInfo(
          comm->tunerContext, info->func, nBytes,
          numPipeOps, costTablePtr, NCCL_NUM_ALGORITHMS, NCCL_NUM_PROTOCOLS,
          regBuff, &tunerMaxChannels));
  }

  NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, costTablePtr, simInfo));

  if (tunerMaxChannels != 0) info->nMaxChannels = tunerMaxChannels;
  return ncclSuccess;
}
2018-12-13 15:56:12 -08:00
2024-06-11 01:28:01 -07:00
// Upper bound for the NVLS-tree chunk size, read through ncclParamNvlsTreeMaxChunkSize().
// The default of -2 is treated by calcCollChunking() as "not set": the bound is then
// derived from the node count instead.
NCCL_PARAM ( NvlsTreeMaxChunkSize , "NVLSTREE_MAX_CHUNKSIZE" , - 2 );
2021-06-09 13:24:26 -07:00
2024-06-11 01:28:01 -07:00
// Decides how one collective is cut into chunks and fills in the matching proxy operation.
//   comm, info     : communicator and collective task (func, algorithm, protocol, buffers).
//   nChannels      : number of channels the work is spread over.
//   nBytes         : byte count used by the chunk-size heuristics and the loop count.
//   outChunkSize   : receives the final proxyOp->chunkSize.
//   outDirectFlags : NCCL_P2P_READ or NCCL_P2P_WRITE for COLLNET_DIRECT, 0 otherwise.
//   proxyOp        : zeroed, then filled with step counts, sizes, pattern and, for
//                    net-registered buffers, the user buffers and memory handles.
// Returns ncclInternalError when the collective or pattern is not handled, or when a
// net-registered buffer is combined with an algorithm other than COLLNET_DIRECT, NVLS,
// COLLNET_CHAIN or RING.
static ncclResult_t calcCollChunking (
struct ncclComm * comm , struct ncclTaskColl * info , int nChannels , size_t nBytes ,
/*outputs*/ uint32_t * outChunkSize , uint32_t * outDirectFlags , struct ncclProxyOp * proxyOp
) {
ncclPattern_t pattern ;
size_t grainSize = ncclProtoGrainSize ( info -> protocol );
2018-12-13 15:56:12 -08:00
2024-06-11 01:28:01 -07:00
// Pick the proxy pattern from the collective and the selected algorithm.
switch ( info -> func ) {
case ncclFuncBroadcast :
pattern = info -> algorithm == NCCL_ALGO_TREE ? ncclPatternTreeDown : ncclPatternPipelineFrom ;
break ;
case ncclFuncReduce :
pattern = info -> algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo ;
break ;
case ncclFuncReduceScatter :
2024-09-10 05:57:10 -07:00
pattern =
info -> algorithm == NCCL_ALGO_PAT ? ncclPatternPatUp :
info -> algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
info -> algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect :
ncclPatternRing ;
break ;
2024-06-11 01:28:01 -07:00
case ncclFuncAllGather :
pattern =
2024-09-10 05:57:10 -07:00
info -> algorithm == NCCL_ALGO_PAT ? ncclPatternPatDown :
2024-06-11 01:28:01 -07:00
info -> algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
info -> algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect :
ncclPatternRing ;
break ;
2025-01-23 11:48:18 -06:00
case ncclFuncAllToAllPivot :
pattern = ncclPatternRing ;
break ;
2024-06-11 01:28:01 -07:00
case ncclFuncAllReduce :
pattern =
info -> algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
info -> algorithm == NCCL_ALGO_NVLS_TREE ? ncclPatternNvlsTree :
info -> algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect :
info -> algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain :
info -> algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown :
ncclPatternRingTwice ;
break ;
default :
WARN ( "Unknown pattern for collective %d algorithm %d" , info -> func , info -> algorithm );
return ncclInternalError ;
}
int nstepsPerLoop , nchunksPerLoop ;
2024-12-18 08:26:06 -08:00
size_t loopOffset = 0 ;
2024-06-11 01:28:01 -07:00
// Starting chunk size: one buffer step of the protocol, multiplied by chunkSteps.
// Only Simple/Ring takes chunkSteps and sliceSteps from the task; every other
// combination uses 1 for both.
int stepSize = comm -> buffSizes [ info -> protocol ] / NCCL_STEPS ;
int chunkSteps = ( info -> protocol == NCCL_PROTO_SIMPLE && info -> algorithm == NCCL_ALGO_RING ) ? info -> chunkSteps : 1 ;
int sliceSteps = ( info -> protocol == NCCL_PROTO_SIMPLE && info -> algorithm == NCCL_ALGO_RING ) ? info -> sliceSteps : 1 ;
int chunkSize = stepSize * chunkSteps ;
// LL halves the chunk; LL128 converts it from line elements to data elements.
if ( info -> protocol == NCCL_PROTO_LL ) chunkSize /= 2 ;
if ( info -> protocol == NCCL_PROTO_LL128 ) chunkSize = ( chunkSize / NCCL_LL128_LINEELEMS ) * NCCL_LL128_DATAELEMS ;
2024-02-05 05:06:02 -08:00
2025-01-23 11:48:18 -06:00
// Per-algorithm heuristics. Each loop halves chunkSize while the number of chunks per
// channel is below a threshold and chunkSize is still above the stated floor.
if ( info -> algorithm == NCCL_ALGO_TREE && info -> protocol == NCCL_PROTO_SIMPLE ) {
if ( pattern == ncclPatternTreeUpDown ) {
2024-04-28 19:48:53 -07:00
// Optimize chunkSize / nSteps
2025-01-23 11:48:18 -06:00
while ( nBytes / ( nChannels * chunkSize ) < comm -> channels [ 0 ]. tree . depth * 8 && chunkSize > 131072 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < comm -> channels [ 0 ]. tree . depth * 4 && chunkSize > 65536 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < comm -> channels [ 0 ]. tree . depth && chunkSize > 32768 ) chunkSize /= 2 ;
2024-04-28 19:48:53 -07:00
}
2025-01-23 11:48:18 -06:00
} else if ( info -> algorithm == NCCL_ALGO_RING && info -> protocol == NCCL_PROTO_SIMPLE ) {
if ( pattern == ncclPatternPipelineFrom || pattern == ncclPatternPipelineTo ) {
2024-05-29 07:59:47 -07:00
// Optimize chunkSize / nSteps
2025-01-23 11:48:18 -06:00
while ( nBytes / ( nChannels * chunkSize ) < 64 && chunkSize > 262144 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < 32 && chunkSize > 131072 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < 16 && chunkSize > 65536 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < 8 && chunkSize > 32768 ) chunkSize /= 2 ;
2024-05-29 07:59:47 -07:00
}
2025-01-23 11:48:18 -06:00
} else if ( info -> algorithm == NCCL_ALGO_COLLNET_DIRECT ) {
2020-01-16 16:02:42 -08:00
// Optimize chunkSize / nSteps
2024-06-11 01:28:01 -07:00
// NOTE(review): the second and third loops both use depth * 8 as threshold — confirm intended.
while ( nBytes / ( nChannels * comm -> channels [ 0 ]. collnetDirect . nHeads * chunkSize ) < comm -> channels [ 0 ]. collnetDirect . depth * 64 && chunkSize > 131072 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * comm -> channels [ 0 ]. collnetDirect . nHeads * chunkSize ) < comm -> channels [ 0 ]. collnetDirect . depth * 8 && chunkSize > 65536 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * comm -> channels [ 0 ]. collnetDirect . nHeads * chunkSize ) < comm -> channels [ 0 ]. collnetDirect . depth * 8 && chunkSize > 32768 ) chunkSize /= 2 ;
} else if ( info -> algorithm == NCCL_ALGO_COLLNET_CHAIN ) {
// CollNet chain always sizes its steps from the Simple protocol buffer and caps the chunk at 256 KiB.
stepSize = comm -> buffSizes [ NCCL_PROTO_SIMPLE ] / NCCL_STEPS ;
2024-02-05 05:06:02 -08:00
chunkSize = std :: min ( 256 * 1024 , stepSize * chunkSteps );
2024-06-11 01:28:01 -07:00
while ( nBytes / ( nChannels * chunkSize ) < comm -> channels [ 0 ]. collnetChain . depth * 64 && chunkSize > 131072 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < comm -> channels [ 0 ]. collnetChain . depth * 8 && chunkSize > 65536 ) chunkSize /= 2 ;
while ( nBytes / ( nChannels * chunkSize ) < comm -> channels [ 0 ]. collnetChain . depth && chunkSize > 32768 ) chunkSize /= 2 ;
} else if ( info -> algorithm == NCCL_ALGO_NVLS ) {
// Cap at comm->nvlsChunkSize, or at 32 KiB for multi-node runs whose recorded
// AllReduce/NVLS/Simple bandwidth is below 150.
int maxChunkSize = comm -> nvlsChunkSize ;
if ( comm -> nNodes > 1 && comm -> bandwidths [ ncclFuncAllReduce ][ NCCL_ALGO_NVLS ][ NCCL_PROTO_SIMPLE ] < 150 ) maxChunkSize = 32768 ;
2023-04-03 05:32:07 -07:00
if ( chunkSize > maxChunkSize ) chunkSize = maxChunkSize ;
2024-09-10 05:57:10 -07:00
// Use uint64_t so that concurrentOps*chunkSize*X does not overflow.
// However, nChannels * comm->channels[0].nvls.nHeads should easily fit in 32 bits.
// coverity[overflow_before_widen]
2024-06-11 01:28:01 -07:00
uint64_t concurrentOps = nChannels * comm -> channels [ 0 ]. nvls . nHeads ;
2024-02-05 05:06:02 -08:00
if (( nBytes < ( 64 * ( concurrentOps * chunkSize ))) && ( chunkSize > 65536 )) chunkSize = 65536 ;
if (( nBytes < ( 8 * ( concurrentOps * chunkSize ))) && ( chunkSize > 32768 )) chunkSize = 32768 ;
if (( nBytes < ( 2 * ( concurrentOps * chunkSize ))) && ( chunkSize > 16384 )) chunkSize = 16384 ;
2024-06-11 01:28:01 -07:00
} else if ( info -> algorithm == NCCL_ALGO_NVLS_TREE ) {
2024-09-10 05:57:10 -07:00
// Use uint64_t so that concurrentOps*chunkSize*X does not overflow.
// However, nChannels * comm->channels[0].nvls.nHeads should easily fit in 32 bits.
// coverity[overflow_before_widen]
2024-06-11 01:28:01 -07:00
uint64_t concurrentOps = nChannels * comm -> channels [ 0 ]. nvls . nHeads ;
chunkSize = comm -> nvlsChunkSize ;
2024-03-26 06:08:55 -07:00
int maxChunkSize = ( int ) ncclParamNvlsTreeMaxChunkSize ();
2024-06-11 01:28:01 -07:00
// -2 means the parameter was left at its default: use 64 KiB from 4 nodes up, else no extra cap.
if ( maxChunkSize == - 2 ) maxChunkSize = comm -> nNodes >= 4 ? 65536 : chunkSize ;
2024-02-05 05:06:02 -08:00
chunkSize = std :: min ( chunkSize , maxChunkSize );
if (( nBytes < ( 32 * ( concurrentOps * chunkSize ))) && ( chunkSize > 262144 )) chunkSize = 262144 ;
if (( nBytes < ( 16 * ( concurrentOps * chunkSize ))) && ( chunkSize > 131072 )) chunkSize = 131072 ;
if (( nBytes < ( 4 * ( concurrentOps * chunkSize ))) && ( chunkSize > 65536 )) chunkSize = 65536 ;
if (( nBytes < ( 1 * ( concurrentOps * chunkSize ))) && ( chunkSize > 32768 )) chunkSize = 32768 ;
2024-06-11 01:28:01 -07:00
} else if ( info -> algorithm == NCCL_ALGO_TREE && info -> protocol == NCCL_PROTO_LL128 ) {
int nNodes = comm -> nNodes ;
float ppn = comm -> nRanks / ( float ) nNodes ;
2020-05-12 14:40:18 -07:00
float nstepsLL128 = 1 + log2i ( nNodes ) + 0.1 * ppn ;
2024-09-10 05:57:10 -07:00
// Yes, we are OK with the division on the left side of the < operand being integer.
// coverity[integer_division]
2024-02-05 05:06:02 -08:00
while ( nBytes / ( nChannels * chunkSize ) < nstepsLL128 * 64 / ppn && chunkSize > 131072 ) chunkSize /= 2 ;
2024-09-10 05:57:10 -07:00
// coverity[integer_division]
2024-02-05 05:06:02 -08:00
while ( nBytes / ( nChannels * chunkSize ) < nstepsLL128 * 16 / ppn && chunkSize > 32768 ) chunkSize /= 2 ;
2024-09-10 05:57:10 -07:00
} else if ( info -> func == ncclFuncAllGather && info -> algorithm == NCCL_ALGO_PAT ) {
while ( chunkSize * nChannels * 32 > nBytes && chunkSize > 65536 ) chunkSize /= 2 ;
} else if ( info -> func == ncclFuncReduceScatter && info -> algorithm == NCCL_ALGO_PAT ) {
while ( chunkSize * nChannels * 16 > nBytes && chunkSize > 65536 ) chunkSize /= 2 ;
2024-02-05 05:06:02 -08:00
}
2024-06-11 01:28:01 -07:00
// Compute directFlags of work struct.
if ( info -> algorithm == NCCL_ALGO_COLLNET_DIRECT ) {
// Set direct direction for broadcast-gather (read or write)
2024-12-18 08:26:06 -08:00
* outDirectFlags = ( nBytes / nChannels <= 1024 * 4 ) ? NCCL_P2P_READ : NCCL_P2P_WRITE ;
2024-06-11 01:28:01 -07:00
} else {
* outDirectFlags = 0 ;
}
// Round the chunk size down to the protocol grain, then derive steps/chunks per loop from the pattern
chunkSize = chunkSize / grainSize * grainSize ; // align chunkSize to multiple grainSize
2024-12-18 08:26:06 -08:00
switch ( pattern ) {
case ncclPatternTreeUp :
case ncclPatternTreeDown :
case ncclPatternTreeUpDown :
case ncclPatternPatUp :
case ncclPatternPatDown :
case ncclPatternPipelineFrom :
case ncclPatternPipelineTo :
case ncclPatternCollnetChain :
nstepsPerLoop = nchunksPerLoop = 1 ;
break ;
case ncclPatternNvls :
nstepsPerLoop = 1 ; nchunksPerLoop = comm -> channels [ 0 ]. nvls . nHeads ;
loopOffset = nChannels * chunkSize * comm -> channels [ 0 ]. nvls . headRank ;
break ;
case ncclPatternCollnetDirect :
nstepsPerLoop = 1 ; nchunksPerLoop = comm -> channels [ 0 ]. collnetDirect . nHeads ;
loopOffset = nChannels * chunkSize * comm -> channels [ 0 ]. collnetDirect . headRank ;
break ;
case ncclPatternRing :
nstepsPerLoop = comm -> nRanks - 1 ; nchunksPerLoop = comm -> nRanks ;
break ;
case ncclPatternRingTwice :
nstepsPerLoop = 2 * ( comm -> nRanks - 1 ); nchunksPerLoop = comm -> nRanks ;
break ;
case ncclPatternNvlsTree :
nstepsPerLoop = 1 ; nchunksPerLoop = comm -> channels [ 0 ]. nvls . nHeads ;
break ;
default :
WARN ( "Unknown pattern %d" , pattern );
return ncclInternalError ;
}
// Compute nSteps for proxies
// One loop moves nChannels * nchunksPerLoop chunks; nLoops covers nBytes rounded up.
size_t loopSize = size_t ( nChannels ) * nchunksPerLoop * chunkSize ;
int nLoops = ( int ) DIVUP ( nBytes , loopSize );
2024-06-11 01:28:01 -07:00
memset ( proxyOp , 0 , sizeof ( * proxyOp ));
proxyOp -> nsteps = nstepsPerLoop * nLoops * chunkSteps ;
proxyOp -> sliceSteps = sliceSteps ;
proxyOp -> chunkSteps = chunkSteps ;
proxyOp -> chunkSize = chunkSize ;
2024-12-18 08:26:06 -08:00
proxyOp -> sliceSize = chunkSize / chunkSteps * sliceSteps ;
proxyOp -> loopSize = loopSize ;
proxyOp -> loopOffset = loopOffset ;
2024-06-11 01:28:01 -07:00
proxyOp -> protocol = info -> protocol ;
proxyOp -> dtype = info -> datatype ;
2024-12-18 08:26:06 -08:00
proxyOp -> algorithm = info -> algorithm ;
2024-06-11 01:28:01 -07:00
if ( info -> opDev . op == ncclDevPreMulSum || info -> opDev . op == ncclDevSumPostDiv ) {
proxyOp -> redOp = ncclSum ; // Network sees avg as sum
} else {
proxyOp -> redOp = info -> opHost ;
}
proxyOp -> pattern = pattern ;
proxyOp -> coll = info -> func ;
proxyOp -> root = info -> root ;
2024-12-18 08:26:06 -08:00
proxyOp -> isOneRPN = comm -> isOneRPN ;
2020-09-04 14:35:05 -07:00
// This is used by P2P to reduce the receive buffer size. We don't use it in collectives
// because some protocols need to transmit more than the total size, plus they sometimes
// round up
2024-06-11 01:28:01 -07:00
proxyOp -> nbytes = stepSize * sliceSteps ;
2024-12-18 08:26:06 -08:00
// Net-registered buffers: pass user buffers and memory handles to the proxy and adjust the
// step/loop accounting per algorithm. nbytes is overwritten with nBytes at the end of this branch.
if ( info -> regBufType & NCCL_NET_REG_BUFFER ) {
2024-03-26 06:08:55 -07:00
proxyOp -> reg = 1 ;
2024-12-18 08:26:06 -08:00
if ( info -> algorithm == NCCL_ALGO_COLLNET_DIRECT || info -> algorithm == NCCL_ALGO_NVLS || info -> algorithm == NCCL_ALGO_COLLNET_CHAIN ) {
if ( proxyOp -> isOneRPN ) {
proxyOp -> nsteps = 1 ;
proxyOp -> loopOffset = 0 ;
proxyOp -> sendbuff = ( uint8_t * ) info -> sendbuff ;
proxyOp -> sendMhandle = info -> sendMhandle ;
} else {
if ( info -> func == ncclFuncAllGather || info -> func == ncclFuncReduceScatter ) {
proxyOp -> nbytes = nBytes / nchunksPerLoop ;
proxyOp -> loopSize = proxyOp -> loopSize / nchunksPerLoop ;
proxyOp -> loopOffset = 0 ;
if ( info -> func == ncclFuncAllGather ) {
proxyOp -> sendbuff = ( uint8_t * ) info -> sendbuff ;
proxyOp -> sendMhandle = info -> sendMhandle ;
}
} else {
// Other collectives send from the receive buffer and its handle.
proxyOp -> sendbuff = ( uint8_t * ) info -> recvbuff ;
proxyOp -> sendMhandle = info -> recvMhandle ;
}
}
} else if ( info -> algorithm == NCCL_ALGO_RING ) {
if ( proxyOp -> isOneRPN && info -> func == ncclFuncAllGather ) {
// Chunk and slice become NCCL_MAX_NET_SIZE; loop size and step count are recomputed to match.
proxyOp -> chunkSize = NCCL_MAX_NET_SIZE ;
proxyOp -> sliceSize = NCCL_MAX_NET_SIZE ;
proxyOp -> chunkSteps = 1 ;
proxyOp -> sliceSteps = 1 ;
proxyOp -> loopSize = size_t ( nChannels ) * nchunksPerLoop * proxyOp -> chunkSize ;
proxyOp -> nsteps = DIVUP ( nBytes , proxyOp -> loopSize ) * nstepsPerLoop ;
proxyOp -> loopOffset = 0 ;
}
} else {
WARN ( "Net registration invalid algorithm %s" , ncclAlgoToString ( info -> algorithm ));
return ncclInternalError ;
}
2024-06-11 01:28:01 -07:00
proxyOp -> recvMhandle = info -> recvMhandle ;
proxyOp -> recvbuff = ( uint8_t * ) info -> recvbuff ;
proxyOp -> nbytes = nBytes ;
2024-03-26 06:08:55 -07:00
} else {
proxyOp -> reg = 0 ;
}
2024-06-11 01:28:01 -07:00
// Extra fields consumed by the CollNet-direct proxy.
if ( pattern == ncclPatternCollnetDirect ) {
proxyOp -> specifics . collnetDirect . nNodes = comm -> nNodes ;
proxyOp -> specifics . collnetDirect . node = comm -> node ;
if ( info -> func == ncclFuncAllGather || info -> func == ncclFuncReduceScatter ) {
proxyOp -> specifics . collnetDirect . sizePerRank = info -> count * ncclTypeSize ( info -> datatype );
2024-04-23 13:33:19 -07:00
}
2022-02-21 13:09:47 +08:00
}
2024-09-10 05:57:10 -07:00
// PAT patterns report the per-channel share of nBytes (rounded up).
if ( pattern == ncclPatternPatUp || pattern == ncclPatternPatDown ) {
proxyOp -> nbytes = DIVUP ( nBytes , nChannels );
}
2024-12-18 08:26:06 -08:00
// Report proxyOp->chunkSize, which the registered RING/AllGather path may have replaced.
* outChunkSize = proxyOp -> chunkSize ;
2018-12-13 15:56:12 -08:00
return ncclSuccess ;
}
2021-09-08 13:56:25 -07:00
// Translates a host-side reduction operator into the device-side description `*opFull`.
//   opFull   : output; op, scalarArg, scalarArgIsPtr and proxyOp are filled in.
//   op       : builtin operator (sum, prod, min, max, avg) or a user-created operator handle.
//   datatype : element type the reduction runs on.
//   comm     : communicator; supplies nRanks for averaging and the user-operator table.
// Returns ncclInvalidArgument when the datatype has no positive size, when ncclAvg is
// requested for a datatype this build has no averaging rule for, or when a user-created
// operator was created for a different datatype.
static ncclResult_t hostToDevRedOp(
    ncclDevRedOpFull *opFull, ncclRedOp_t op, ncclDataType_t datatype, ncclComm *comm
  ) {
  // Scratch storage: the averaging scalar is written through the typed member and
  // read back as raw bits through u64.
  union {
    int8_t   i8; uint8_t   u8;
    int32_t i32; uint32_t u32;
    int64_t i64; uint64_t u64;
    __half f16; float f32; double f64;
    #if defined(RCCL_BFLOAT16)
    hip_bfloat16 bf16;
    #endif
    #if defined(RCCL_FLOAT8)
    rccl_float8 f8;
    rccl_bfloat8 bf8;
    #endif
    void *ptr;
  };
  u64 = 0;
  opFull->scalarArgIsPtr = false;
  opFull->proxyOp = op;

  int nbits = 8*ncclTypeSize(datatype);
  if (nbits <= 0) return ncclInvalidArgument;
  uint64_t allBits = uint64_t(-1)>>(64-nbits);   // low nbits set
  uint64_t signBit = allBits^(allBits>>1);       // only the top bit of the element set
  bool datatype_signed = false;

  switch (int(op)) {
  case ncclSum:  opFull->op = ncclDevSum;  break;
  case ncclProd: opFull->op = ncclDevProd; break;
  case ncclMin:
  case ncclMax:
    opFull->op = ncclDevMinMax;
    opFull->scalarArg = 0;
    // The xormask used by ncclFuncMinMax<[u]int> is the XOR of the sign bit
    // for signed (opposed to unsigned) types and all the bits for max (opposed to min).
    if (datatype==ncclInt8 || datatype==ncclInt32 || datatype==ncclInt64) {
      opFull->scalarArg ^= signBit;
    }
    opFull->scalarArg ^= (op == ncclMax) ? allBits : 0;
    break;
  case ncclAvg:
    switch ((int)datatype) {
    case ncclInt8: case ncclInt32: case ncclInt64:
      datatype_signed = true;
      // no break, we want to fall through...
    case ncclUint8: case ncclUint32: case ncclUint64:
      // Integers: sum first, divide afterwards. The scalar packs nRanks with the
      // signedness flag in bit 0.
      opFull->op = ncclDevSumPostDiv;
      u64 = comm->nRanks<<1 | datatype_signed;
      break;
    #if defined(RCCL_FLOAT8)
    case ncclFloat8e4m3:
      opFull->op = ncclDevPreMulSum;
      f8 = static_cast<rccl_float8>(float(1.0/comm->nRanks));
      break;
    case ncclFloat8e5m2:
      opFull->op = ncclDevPreMulSum;
      bf8 = static_cast<rccl_bfloat8>(float(1.0/comm->nRanks));
      break;
    #endif
    case ncclFloat16:
      opFull->op = ncclDevPreMulSum;
      f16 = __float2half(float(1.0/comm->nRanks)); // __double2half not supported pre CUDA 11.x
      break;
    #if defined(RCCL_BFLOAT16)
    case ncclBfloat16:
      opFull->op = ncclDevPreMulSum;
      bf16 = (hip_bfloat16)(float(1.0/comm->nRanks));
      break;
    #endif
    case ncclFloat32:
      opFull->op = ncclDevPreMulSum;
      f32 = float(1.0/comm->nRanks);
      break;
    case ncclFloat64:
      opFull->op = ncclDevPreMulSum;
      f64 = 1.0/comm->nRanks;
      break;
    default:
      // Without this branch a datatype with a valid size but no case above (for
      // instance a type whose support is compiled out) would leave opFull->op
      // unwritten while still reporting success.
      WARN("ncclAvg is not supported for data type %d", int(datatype));
      return ncclInvalidArgument;
    }
    opFull->scalarArgIsPtr = false;
    opFull->scalarArg = u64;
    break;
  default: // user created
    int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps);
    ncclUserRedOp *user = &comm->userRedOps[ix];
    if (datatype != user->datatype) {
      WARN("Data type supplied to user-created ncclRedOp_t does not match type "
           "given to reduction operation");
      return ncclInvalidArgument;
    }
    *opFull = user->opFull;
    break;
  }
  return ncclSuccess;
}
2024-06-11 01:28:01 -07:00
// Converts `info` to a task and adds it to `comm->planner`. The exception is with
2022-05-24 02:02:31 -07:00
// single rank communicators: collectives are launched right away through
// ncclLaunchOneRank() and thus don't need a task.
//
// Send/Recv calls become ncclTaskP2p entries on the per-peer send or receive queue, and the
// channels towards a peer seen for the first time are flagged for pre-connection.
// Other collectives become ncclTaskColl entries inserted into the planner's sorter, keyed
// by their traffic volume. Zero-count collectives are dropped.
// Finally the stream of the call is recorded in the planner's stream list.
// Returns ncclInvalidArgument for FP8 data on devices below the required capability,
// ncclInvalidUsage when streams of one group mix different capture states, or the error
// of any failing helper.
2024-02-05 05:06:02 -08:00
static ncclResult_t taskAppend ( struct ncclComm * comm , struct ncclInfo * info ) {
2024-06-11 01:28:01 -07:00
struct ncclKernelPlanner * planner = & comm -> planner ;
2024-02-05 05:06:02 -08:00
2022-05-24 02:02:31 -07:00
if ( info -> coll == ncclFuncSend || info -> coll == ncclFuncRecv ) {
// For point-to-point calls `root` carries the peer rank.
int peer = info -> root ;
ssize_t nBytes = info -> count * ncclTypeSize ( info -> datatype );
bool isSendNotRecv = info -> coll == ncclFuncSend ;
// Must be in thread local group before tasks can be alloc'd in `comm->memScoped`.
ncclGroupCommJoin ( info -> comm );
struct ncclTaskP2p * p2p = ncclMemoryStackAlloc < struct ncclTaskP2p > ( & comm -> memScoped );
p2p -> buff = ( void * ) info -> recvbuff ;
2024-09-10 05:57:10 -07:00
p2p -> count = info -> count ;
p2p -> datatype = info -> datatype ;
p2p -> root = info -> root ;
2022-05-24 02:02:31 -07:00
p2p -> bytes = nBytes ;
2025-02-03 08:55:27 -08:00
p2p -> opCount = comm -> opCount ;
2022-05-24 02:02:31 -07:00
ncclIntruQueueEnqueue (
2024-06-11 01:28:01 -07:00
isSendNotRecv ? & planner -> peers [ peer ]. sendQueue : & planner -> peers [ peer ]. recvQueue ,
2022-05-24 02:02:31 -07:00
p2p );
2024-06-11 01:28:01 -07:00
planner -> nTasksP2p += 1 ;
2022-05-24 02:02:31 -07:00
// Mark channels that need pre-connect
if ( comm -> rank != peer ) {
2024-06-11 01:28:01 -07:00
// Only the first send (or first receive) towards a peer triggers the connection scan.
if ( ! ( isSendNotRecv ? planner -> peers [ peer ]. sendSeen : planner -> peers [ peer ]. recvSeen )) {
( isSendNotRecv ? planner -> peers [ peer ]. sendSeen : planner -> peers [ peer ]. recvSeen ) = true ;
// Find the schedule round in which this peer is our send (or receive) partner.
// NOTE(review): the search has no upper bound; it relies on `peer` appearing in p2pSchedule.
int round = 0 ;
while ( peer != ( isSendNotRecv ? comm -> p2pSchedule [ round ]. sendRank
: comm -> p2pSchedule [ round ]. recvRank )) {
round += 1 ;
}
uint8_t base = ncclP2pChannelBaseForRound ( comm , round );
2022-05-24 02:02:31 -07:00
for ( int c = 0 ; c < comm -> p2pnChannelsPerPeer ; c ++ ) {
2025-02-26 09:48:03 -05:00
int channelId = ncclP2pChannelForPart ( comm -> p2pnChannels , base , c , comm -> p2pnChannelsPerPeer , comm -> nNodes );
2022-05-24 02:02:31 -07:00
// Each unconnected connector sets bit (channelId % 64) of mask word (channelId / 64)
// and requests a pre-connect pass for the communicator.
if ( isSendNotRecv ) {
2023-04-03 05:32:07 -07:00
if ( comm -> channels [ channelId ]. peers [ peer ] -> send [ 1 ]. connected == 0 ) { // P2P uses only 1 connector
2024-05-15 16:58:28 -05:00
//comm->connectSend[peer] |= (1UL<<channelId);
2025-01-23 11:48:18 -06:00
comm -> connectSend [ peer ]. masks [ channelId / 64 ] |= ( 1UL << ( channelId % 64 ));
2022-05-24 02:02:31 -07:00
ncclGroupCommPreconnect ( comm );
}
2023-06-21 20:54:24 -07:00
if ( comm -> p2pNet && comm -> channels [ channelId ]. peers [ peer ] -> send [ NCCL_CONN_IDX_P2P_NET ]. connected == 0 ) {
2024-05-15 16:58:28 -05:00
//comm->connectSend[peer+comm->nRanks*NCCL_CONN_IDX_P2P_NET] |= (1UL<<channelId);
2025-01-23 11:48:18 -06:00
comm -> connectSend [ peer + comm -> nRanks * NCCL_CONN_IDX_P2P_NET ]. masks [ channelId / 64 ] |= ( 1UL << ( channelId % 64 ));
2022-09-09 01:20:52 +00:00
ncclGroupCommPreconnect ( comm );
}
2022-05-24 02:02:31 -07:00
} else {
2023-04-03 05:32:07 -07:00
if ( comm -> channels [ channelId ]. peers [ peer ] -> recv [ 1 ]. connected == 0 ) { // P2P uses only 1 connector
2024-05-15 16:58:28 -05:00
//comm->connectRecv[peer] |= (1UL<<channelId);
2025-01-23 11:48:18 -06:00
comm -> connectRecv [ peer ]. masks [ channelId / 64 ] |= ( 1UL << ( channelId % 64 ));
2022-05-24 02:02:31 -07:00
ncclGroupCommPreconnect ( comm );
}
2023-06-21 20:54:24 -07:00
if ( comm -> p2pNet && comm -> channels [ channelId ]. peers [ peer ] -> recv [ NCCL_CONN_IDX_P2P_NET ]. connected == 0 ) {
2024-05-15 16:58:28 -05:00
//comm->connectRecv[peer+comm->nRanks*NCCL_CONN_IDX_P2P_NET] |= (1UL<<channelId);
2025-01-23 11:48:18 -06:00
comm -> connectRecv [ peer + comm -> nRanks * NCCL_CONN_IDX_P2P_NET ]. masks [ channelId / 64 ] |= ( 1UL << ( channelId % 64 ));
2022-09-09 01:20:52 +00:00
ncclGroupCommPreconnect ( comm );
}
2022-05-24 02:02:31 -07:00
}
}
}
2020-05-12 14:40:18 -07:00
}
2018-12-13 15:56:12 -08:00
} else {
2024-06-11 01:28:01 -07:00
// Empty collectives can be discarded.
if ( info -> count == 0 ) return ncclSuccess ;
2024-12-18 08:26:06 -08:00
if ( info -> datatype == ncclFloat8e4m3 || info -> datatype == ncclFloat8e5m2 ) {
if ( comm -> minCompCap < 90 ) {
WARN ( "FP8 reduction support begins with sm90 capable devices." );
return ncclInvalidArgument ;
}
}
2022-05-24 02:02:31 -07:00
// Copy reduction op state from op handle into info struct here since the
// op handle may be destroyed before ncclGroupEnd().
2024-06-11 01:28:01 -07:00
struct ncclDevRedOpFull opDev ;
NCCLCHECK ( hostToDevRedOp ( & opDev , info -> op , info -> datatype , comm ));
2022-05-24 02:02:31 -07:00
2023-09-26 05:47:28 -07:00
// Single-rank communicator: launch directly and return without touching the stream list below.
if ( comm -> nRanks == 1 ) {
2024-06-11 01:28:01 -07:00
NCCLCHECK ( ncclLaunchOneRank ( info -> recvbuff , info -> sendbuff , info -> count , opDev , info -> datatype , info -> stream ));
2022-05-24 02:02:31 -07:00
return ncclSuccess ;
} else {
// Must be in thread local group before tasks can be alloc'd in `comm->memScoped`.
ncclGroupCommJoin ( info -> comm );
2024-06-11 01:28:01 -07:00
struct ncclTaskColl * t = ncclMemoryStackAlloc < struct ncclTaskColl > ( & comm -> memScoped );
t -> func = info -> coll ;
t -> sendbuff = info -> sendbuff ;
t -> recvbuff = info -> recvbuff ;
t -> count = info -> count ;
t -> root = info -> root ;
t -> datatype = info -> datatype ;
size_t elementSize = ncclTypeSize ( t -> datatype );
2025-01-23 11:48:18 -06:00
// These collectives do no arithmetic on the data here, so they are re-expressed as a byte
// stream: the count becomes a byte count and the type becomes ncclInt8.
if ( t -> func == ncclFuncAllGather || t -> func == ncclFuncBroadcast || t -> func == ncclFuncAllToAllPivot ) {
2024-06-11 01:28:01 -07:00
t -> count *= elementSize ;
t -> datatype = ncclInt8 ;
elementSize = 1 ;
}
t -> trafficBytes = t -> count * elementSize * ncclFuncTrafficPerByte ( t -> func , comm -> nRanks );
t -> opHost = info -> op ;
t -> opDev = opDev ; // C++ struct assignment
t -> chunkSteps = info -> chunkSteps ;
t -> sliceSteps = info -> sliceSteps ;
2025-02-03 08:55:27 -08:00
t -> opCount = comm -> opCount ;
2024-06-11 01:28:01 -07:00
planner -> nTasksColl += 1 ;
ncclTaskCollSorterInsert ( & planner -> collSorter , t , t -> trafficBytes );
2022-05-24 02:02:31 -07:00
}
2021-09-08 13:56:25 -07:00
}
2020-05-12 14:40:18 -07:00
2024-06-11 01:28:01 -07:00
// Stream bookkeeping: skipped when the stream equals the most recent one and the list is non-empty.
if ( info -> stream != planner -> streamRecent || planner -> streams == nullptr ) {
planner -> streamRecent = info -> stream ;
struct ncclCudaStreamList * l = planner -> streams ;
2022-05-24 02:02:31 -07:00
while ( true ) {
if ( l == nullptr ) { // Got to the end, this must be a new stream.
struct ncclCudaGraph graph ;
2024-09-10 05:57:10 -07:00
NCCLCHECK ( ncclCudaGetCapturingGraph ( & graph , info -> stream ));
2024-06-11 01:28:01 -07:00
if ( planner -> streams != nullptr && ! ncclCudaGraphSame ( planner -> capturingGraph , graph )) {
2022-05-24 02:02:31 -07:00
WARN ( "Streams given to a communicator within a NCCL group must either be all uncaptured or all captured by the same graph." );
return ncclInvalidUsage ;
}
2024-06-11 01:28:01 -07:00
planner -> capturingGraph = graph ; // C++ struct assignment
2022-05-24 02:02:31 -07:00
// Add stream to list
l = ncclMemoryStackAlloc < struct ncclCudaStreamList > ( & comm -> memScoped );
l -> stream = info -> stream ;
2024-06-11 01:28:01 -07:00
l -> next = planner -> streams ;
planner -> streams = l ;
2025-01-23 11:48:18 -06:00
planner -> numStreams ++ ;
2022-05-24 02:02:31 -07:00
break ;
}
if ( l -> stream == info -> stream )
break ; // Already seen stream.
2022-09-27 02:31:13 -07:00
l = l -> next ;
2020-05-12 14:40:18 -07:00
}
2022-05-24 02:02:31 -07:00
}
return ncclSuccess ;
}
2020-05-12 14:40:18 -07:00
2022-05-24 02:02:31 -07:00
// Validates the call described by `info` and queues it on its communicator.
//
// The work is bracketed by ncclGroupStartInternal()/ncclGroupEndInternal(). A failure in
// any validation step or in taskAppend() is recorded on a non-blocking communicator via
// ncclCommSetAsyncError() and handed to ncclGroupErrCheck(). For non-blocking
// communicators the returned value is re-read with ncclCommGetAsyncError() after the
// group has ended, because ending the group may have changed the communicator state.
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
  NCCLCHECK(ncclGroupStartInternal());
  int devOld = -1;  // device to restore before returning; -1 means it was never switched
  ncclResult_t ret = ncclSuccess;
  struct ncclComm* const comm = info->comm;

  NCCLCHECKGOTO(CommCheck(comm, info->opName, "comm"), ret, fail);
  // Check whether communicator is ready to communicate
  NCCLCHECKGOTO(ncclCommEnsureReady(comm), ret, fail);

  // Pointer checking runs on the communicator's device, so switch to it temporarily.
  if (comm->checkPointers) {
    CUDACHECKGOTO(cudaGetDevice(&devOld), ret, fail);
    CUDACHECKGOTO(cudaSetDevice(comm->cudaDev), ret, fail);
  }
  NCCLCHECKGOTO(ArgsCheck(info), ret, fail);

  INFO(NCCL_COLL, "%s: opCount %lx sendbuff %p recvbuff %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d",
       info->opName, comm->opCount, info->sendbuff, info->recvbuff, info->count,
       info->datatype, info->op, info->root, comm, comm->nRanks, info->stream,
       comm->planner.nTasksP2p + comm->planner.nTasksColl,
       comm->localRankToRank[comm->localRank]);
  TRACE_CALL("nccl%s(%" PRIx64 ",%" PRIx64 ",%zu,%d,%d,%d,%p,%p)", info->opName, reinterpret_cast<int64_t>(info->sendbuff), reinterpret_cast<int64_t>(info->recvbuff), info->count, info->datatype, info->op, info->root, comm, info->stream);

  NCCLCHECKGOTO(taskAppend(comm, info), ret, fail);
  goto exit;  // success: skip the failure hook

fail:
  if (comm && !comm->config.blocking) (void)ncclCommSetAsyncError(comm, ret);

exit:
  // NOTE(review): presumably CUDACHECK returns on failure, which would skip the group
  // end below — confirm against the macro definition.
  if (devOld != -1) CUDACHECK(cudaSetDevice(devOld));
  ncclGroupErrCheck(ret);
  NCCLCHECK(ncclGroupEndInternal());
  /* if depth is 1, ncclGroupEndInternal() will trigger group ops. The state can change
   * so we have to check state here. */
  if (comm && !comm->config.blocking) { NCCLCHECK(ncclCommGetAsyncError(comm, &ret)); }
  return ret;
}
NCCL_API ( ncclResult_t , ncclRedOpCreatePreMulSum , ncclRedOp_t * op , void * scalar , ncclDataType_t datatype , ncclScalarResidence_t residence , ncclComm_t comm );
2024-08-22 12:36:07 -05:00
// Creates a user-defined "pre-multiply, then sum" reduction operator on `comm`.
//   op        : output; receives the mangled handle of the new operator.
//   scalar    : the multiplier. With ncclScalarHostImmediate it is read now and its bytes are
//               copied into the operator; otherwise the pointer itself is stored.
//   datatype  : element type of the multiplier; recorded on the operator for later checks.
//   residence : selects between the two scalar-handling modes above.
//   comm      : communicator owning the operator table.
// Returns ncclInternalError when an immediate scalar is requested for a datatype without a
// positive size, or the error of any failing helper. On that validation failure the
// operator table is left untouched.
ncclResult_t ncclRedOpCreatePreMulSum_impl(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm) {
  NCCLCHECK(CommCheck(comm, "ncclRedOpCreatePreMulSum", "comm"));
  /* join init thread before creating PreMulSum op. */
  NCCLCHECK(ncclCommEnsureReady(comm));

  // Validate the scalar size before touching the free list. Doing it after the pop would
  // leave the slot marked as allocated with no handle ever returned to the caller.
  int scalarBytes = 0;
  if (residence == ncclScalarHostImmediate) {
    scalarBytes = ncclTypeSize(datatype);
    if (scalarBytes < 1) return ncclInternalError;
  }

  if (comm->userRedOpFreeHead == comm->userRedOpCapacity) {
    // double capacity and resize
    int cap = 2*comm->userRedOpCapacity;
    if (cap < 4) cap = 4;
    ncclUserRedOp *ops = new ncclUserRedOp[cap];
    if (comm->userRedOpCapacity > 0)
      std::memcpy(ops, comm->userRedOps, comm->userRedOpCapacity*sizeof(ncclUserRedOp));
    // Chain the new slots into the free list; the last one points one past the end,
    // which is the "table full" marker tested above.
    for (int ix=comm->userRedOpCapacity; ix < cap; ix++)
      ops[ix].freeNext = ix + 1;
    delete[] comm->userRedOps;
    comm->userRedOps = ops;
    comm->userRedOpCapacity = cap;
  }

  // pop from free list
  int ix = comm->userRedOpFreeHead;
  ncclUserRedOp *user = &comm->userRedOps[ix];
  comm->userRedOpFreeHead = user->freeNext;
  user->freeNext = -1; // allocated
  user->datatype = datatype;
  user->opFull.op = ncclDevPreMulSum;
  if (residence == ncclScalarHostImmediate) {
    user->opFull.scalarArgIsPtr = false;
    // Clear first: only scalarBytes bytes are copied, and the slot may hold stale data.
    user->opFull.scalarArg = 0;
    std::memcpy(&user->opFull.scalarArg, scalar, scalarBytes);
  } else {
    user->opFull.scalarArgIsPtr = true;
    user->opFull.scalarArg = reinterpret_cast<uint64_t>(scalar);
  }
  *op = ncclRedOp_t(int(ncclNumOps) + ix);
  *op = ncclUserRedOpMangle(comm, *op);
  // ! recording at sink
  NCCLCHECK(Recorder::instance().record(rrRedOpCreatePreMulSum, *op, comm, datatype, residence, scalar));
  TRACE_CALL("ncclRedOpCreatePreMulSum(%d,%p,%d,%d,%p)", *op, scalar, datatype, residence, comm);
  return ncclSuccess;
}
NCCL_API ( ncclResult_t , ncclRedOpDestroy , ncclRedOp_t op , ncclComm_t comm );
2024-08-22 12:36:07 -05:00
ncclResult_t ncclRedOpDestroy_impl ( ncclRedOp_t op , ncclComm_t comm ) {
2025-04-19 00:21:27 -04:00
NCCLCHECK ( Recorder :: instance (). record ( rrRedOpDestroy , op , comm ));
2021-09-08 13:56:25 -07:00
if ( 0 <= int ( op ) && int ( op ) < int ( ncclNumOps )) {
WARN ( "ncclRedOpDestroy : operator is a NCCL builtin." );
return ncclInvalidArgument ;
2018-12-13 15:56:12 -08:00
}
2024-09-10 05:57:10 -07:00
// int(ncclMaxRedOp) < int(op) will always be false due to the sizes of
// the datatypes involved, and that's by design. We keep the check though
// just as a reminder.
// coverity[result_independent_of_operands]
2021-09-08 13:56:25 -07:00
if ( int ( op ) < 0 || int ( ncclMaxRedOp ) < int ( op )) {
WARN ( "ncclRedOpDestroy : operator is garbage." );
return ncclInvalidArgument ;
}
2023-02-27 02:48:21 -08:00
if ( comm == NULL ) {
WARN ( "ncclRedOpDestroy : invalid communicator passed." );
return ncclInvalidArgument ;
}
2021-09-08 13:56:25 -07:00
int ix = int ( ncclUserRedOpMangle ( comm , op )) - int ( ncclNumOps );
if ( comm -> userRedOpCapacity <= ix || comm -> userRedOps [ ix ]. freeNext != - 1 ) {
WARN ( "ncclRedOpDestroy : operator unknown to this communicator." );
return ncclInvalidArgument ;
}
// push to free list
comm -> userRedOps [ ix ]. freeNext = comm -> userRedOpFreeHead ;
comm -> userRedOpFreeHead = ix ;
2022-05-24 02:02:31 -07:00
TRACE_CALL ( "ncclRedOpDestroy(%d,%p)" , op , comm );
2021-09-08 13:56:25 -07:00
return ncclSuccess ;
2024-05-10 07:31:12 -07:00
}