Add MSCCL Support (#658)
* Add MSCCL support * Add alignment and message size checking * Fix nRanks checking, in-place and out-of-place tests and group call handling * Fix hipGraph unit test * Change MSCCL init warning to INFO * Revise license info
Tá an tiomantas seo le fáil i:
tiomanta ag
GitHub
tuismitheoir
b953544a59
tiomantas
adafc0f759
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2019-2020 Advanced Micro Devices, Inc. All rights reserved.
|
||||
# Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
INCLUDE(CheckIncludeFiles)
|
||||
@@ -138,7 +139,8 @@ if (BUILD_ALLREDUCE_ONLY)
|
||||
set(CU_SOURCES
|
||||
src/collectives/device/all_reduce.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/functions.cu)
|
||||
src/collectives/device/functions.cu
|
||||
src/collectives/device/msccl_kernel.cu)
|
||||
else()
|
||||
set(CU_SOURCES
|
||||
src/collectives/device/all_reduce.cu
|
||||
@@ -149,7 +151,8 @@ else()
|
||||
src/collectives/device/reduce_scatter.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/onerank_reduce.cu
|
||||
src/collectives/device/functions.cu)
|
||||
src/collectives/device/functions.cu
|
||||
src/collectives/device/msccl_kernel.cu)
|
||||
endif()
|
||||
|
||||
set(CPP_SOURCES)
|
||||
@@ -223,6 +226,12 @@ set(HEADER_SOURCES
|
||||
src/include/nvtx3/nvToolsExtCudaRt.h
|
||||
src/include/nvtx3/nvToolsExtCuda.h
|
||||
src/include/nvtx3/nvToolsExtOpenCL.h
|
||||
src/include/msccl/msccl_kernel.h
|
||||
src/include/msccl/msccl_lifecycle.h
|
||||
src/include/msccl/msccl_parser.h
|
||||
src/include/msccl/msccl_setup.h
|
||||
src/include/msccl/msccl_status.h
|
||||
src/include/msccl/msccl_struct.h
|
||||
src/graph/rings.h
|
||||
src/graph/rome_models.h
|
||||
src/graph/topo.h
|
||||
@@ -242,6 +251,7 @@ set(API_SOURCES
|
||||
src/collectives/scatter.cc
|
||||
src/collectives/gather.cc
|
||||
src/collectives/sendrecv.cc
|
||||
src/collectives/msccl.cc
|
||||
src/net.cc)
|
||||
foreach(filename ${API_SOURCES})
|
||||
string(REPLACE ".cc"
|
||||
@@ -278,6 +288,10 @@ set(CC_SOURCES
|
||||
src/misc/param.cc
|
||||
src/misc/rocmwrap.cc
|
||||
src/misc/strongstream.cc
|
||||
src/misc/msccl/msccl_lifecycle.cc
|
||||
src/misc/msccl/msccl_parser.cc
|
||||
src/misc/msccl/msccl_setup.cc
|
||||
src/misc/msccl/msccl_status.cc
|
||||
src/transport/coll_net.cc
|
||||
src/transport/net.cc
|
||||
src/transport/net_ib.cc
|
||||
@@ -314,6 +328,7 @@ set(HIPIFY_SOURCES
|
||||
src/collectives/reduce_scatter_api.cpp
|
||||
src/collectives/scatter_api.cpp
|
||||
src/collectives/sendrecv_api.cpp
|
||||
src/collectives/msccl_api.cpp
|
||||
src/debug.cpp
|
||||
src/enqueue.cpp
|
||||
src/graph/xml.cpp
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -7,11 +8,20 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
|
||||
ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
|
||||
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
|
||||
ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "enqueue.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
|
||||
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
|
||||
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
|
||||
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -9,10 +10,18 @@
|
||||
#include "collectives.h"
|
||||
#include "graph/topo.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
ncclComm_t comm, hipStream_t stream);
|
||||
ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
ncclComm_t comm, hipStream_t stream) {
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
|
||||
}
|
||||
|
||||
size_t rankOffset = count * ncclTypeSize(datatype);
|
||||
size_t rankAlign = rankOffset & ((~rankOffset) + 1);
|
||||
// Determine Pivot A2A support now that we know number of channels
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -8,12 +9,20 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllToAllv, const void *sendbuff, const size_t sendcounts[], const size_t sdispls[],
|
||||
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
|
||||
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream);
|
||||
ncclResult_t ncclAllToAllv(const void *sendbuff, const size_t sendcounts[], const size_t sdispls[],
|
||||
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
|
||||
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
|
||||
0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream);
|
||||
}
|
||||
|
||||
int nRanks;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -7,11 +8,19 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
|
||||
ncclComm_t comm, cudaStream_t stream);
|
||||
ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
|
||||
ncclComm_t comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
|
||||
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
|
||||
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
|
||||
|
||||
@@ -0,0 +1,347 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
#include "msccl/msccl_kernel.h"
|
||||
|
||||
__shared__ struct mscclShmemData mscclShmem;
|
||||
|
||||
#define MSCCL_MAX_ITER 65536
|
||||
|
||||
// Flags are a 3-tuple of (workindex, gridoffset_iter, step) compared in lexicographical order:
// a threadblock is ahead of another iff its flag is ahead.
// The tuple is packed into a single uint64_t as
//   workindex * (MSCCL_MAX_ITER * MSCCL_MAX_NUM_STEPS) + gridoffset_iter * MSCCL_MAX_NUM_STEPS + step
// NOTE(review): the ordering presumably relies on gridoffset_iter < MSCCL_MAX_ITER and
// step < MSCCL_MAX_NUM_STEPS -- confirm against the host-side checks.
// Each argument and the whole expansion are parenthesized so the macro stays correct with
// compound arguments (e.g. `a + b`) and when used under a cast or inside a comparison.
// The leading product is widened to uint64_t before multiplying so it cannot wrap in int.
#define COMPUTE_FLAG(__WORKINDEX__,__GRIDOFFSET_ITER__,__STEP__) \
  ((uint64_t)MSCCL_MAX_ITER*MSCCL_MAX_NUM_STEPS*(uint64_t)(__WORKINDEX__) + ((uint64_t)(__GRIDOFFSET_ITER__) * MSCCL_MAX_NUM_STEPS + (uint64_t)(__STEP__)))
|
||||
|
||||
// A copy of the volatile load from prims_ll.
// Reads one element of type U from `src` through an integer access whose width is chosen
// from sizeof(U) (1, 2, 4, and 8 bytes for anything else), then returns the bits
// reinterpreted as U via the anonymous union.
// HIP builds use __builtin_nontemporal_load; other builds use ld.volatile.global inline PTX.
template<typename U>
__device__ static U load(U *src) {
  union {
    U elt;
    uint8_t u1;
    uint16_t u2;
    uint32_t u4;
    uint64_t u8;
  };
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  switch (sizeof(U)) {
    case 1:
      u1 = __builtin_nontemporal_load((uint8_t*)src);
      break;
    case 2:
      u2 = __builtin_nontemporal_load((uint16_t*)src);
      break;
    case 4:
      u4 = __builtin_nontemporal_load((uint32_t*)src);
      break;
    default:
      u8 = __builtin_nontemporal_load((uint64_t*)src);
      break;
  }
#else
  switch (sizeof(U)) {
    case 1:
      // The byte is loaded into the 32-bit union member (there is no 8-bit register constraint used here).
      asm("ld.volatile.global.b8 %0,[%1];" : "=r"(u4) : "l"(src));
      break;
    case 2:
      asm("ld.volatile.global.b16 %0,[%1];" : "=h"(u2) : "l"(src));
      break;
    case 4:
      asm("ld.volatile.global.b32 %0,[%1];" : "=r"(u4) : "l"(src));
      break;
    default:
      asm("ld.volatile.global.b64 %0,[%1];" : "=l"(u8) : "l"(src));
      break;
  }
#endif
  return elt;
}
|
||||
|
||||
// Counterpart of load(): writes `val` to `dst` through an integer access whose width is
// chosen from sizeof(U) (1, 2, 4, and 8 bytes for anything else).
// HIP builds use __builtin_nontemporal_store; other builds use st.volatile.global inline PTX.
template<typename U>
__device__ static void store(U *dst, U val) {
  union {
    U elt;
    uint8_t u1;
    uint16_t u2;
    uint32_t u4;
    uint64_t u8;
  };
  // Place the value in the union so its bits can be read back as an integer of matching width.
  elt = val;
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  switch (sizeof(U)) {
    case 1:
      __builtin_nontemporal_store(u1, (uint8_t*)dst);
      break;
    case 2:
      __builtin_nontemporal_store(u2, (uint16_t*)dst);
      break;
    case 4:
      __builtin_nontemporal_store(u4, (uint32_t*)dst);
      break;
    default:
      __builtin_nontemporal_store(u8, (uint64_t*)dst);
      break;
  }
#else
  switch (sizeof(U)) {
    case 1:
      // The byte is passed through the 32-bit union member.
      asm("st.volatile.global.b8 [%0],%1;" :: "l"(dst), "r"(u4));
      break;
    case 2:
      asm("st.volatile.global.b16 [%0],%1;" :: "l"(dst), "h"(u2));
      break;
    case 4:
      asm("st.volatile.global.b32 [%0],%1;" :: "l"(dst), "r"(u4));
      break;
    default:
      asm("st.volatile.global.b64 [%0],%1;" :: "l"(dst), "l"(u8));
      break;
  }
#endif
}
|
||||
|
||||
// Barrier used by the MSCCL interpreter.
//   HIP builds:   calls barrier_by_group() unless nthreads equals WARP_SIZE, in which case
//                 nothing is done.
//   other builds: issues `bar.sync 15, nthreads` (named barrier 15).
// `barrier_next` and `barriers` are accepted but not referenced in this body.
inline __device__ static void barrier(int nthreads, uint64_t* barrier_next, uint64_t* barriers) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  if (nthreads == WARP_SIZE) {
    return;
  }
  barrier_by_group();
#else
  asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(15));
#endif
}
|
||||
|
||||
// Copy 8-byte aligned data, one 64-bit word per calling thread.
// Thread `tid` copies the word at byte offset 8*tid when that offset is below `bytes`;
// threads past the end do nothing. You must call with at least `(bytes+7)/8` threads.
inline __device__ static void copyToShmem8(int tid, void* dst, void const* src, int bytes) {
  const int byteOffset = 8 * tid;
  if (byteOffset >= bytes) {
    return;
  }
  uint64_t *from = (uint64_t*)((char const*)src + byteOffset);
  uint64_t *to = (uint64_t*)((char*)dst + byteOffset);
  *to = *from;
}
|
||||
|
||||
// Cooperative copy from `src` to `dst`, performed in 64-bit words by the calling threads.
//   dst, src : 8-byte aligned buffers, accessed as uint64_t words
//   size     : number of BYTES to copy (the caller in this file passes sizeof(struct ...))
//   tid      : index of the calling thread
//   nthreads : number of threads taking part; thread `tid` handles words tid, tid+nthreads, ...
// `size` is converted to a word count before looping. Previously the byte count was used
// directly as the number of 64-bit words, which copied 8x the requested amount and ran past
// the end of both buffers.
// NOTE(review): size is rounded down to whole words; assumes it is a multiple of 8 -- confirm
// for struct mscclThreadBlock.
__device__ __forceinline__ static void threadBlockCopy(
  uint64_t *dst, uint64_t const *src, uint64_t size, int tid, int nthreads) {
  const uint64_t nWords = size / sizeof(uint64_t);
  // Unsigned 64-bit index: avoids the signed/unsigned comparison against the word count.
  for (uint64_t i = (uint64_t)tid; i < nWords; i += (uint64_t)nthreads) {
    dst[i] = src[i];
  }
}
|
||||
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ __forceinline__ void mscclRunInterpreter(
|
||||
struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork work) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
const int nthreads = NCCL_MAX_NTHREADS;
|
||||
|
||||
// initialize barriers
|
||||
if (tid == 0) {
|
||||
for (auto i = 0; i < NCCL_MAX_GROUPS; i++) {
|
||||
ncclShmem.groups[i].barrier = 0;
|
||||
for (auto j = 0; j < NCCL_MAX_GROUPS; j++) ncclShmem.groups[i].barrier_next[j] = 0;
|
||||
}
|
||||
}
|
||||
uint64_t* mscclBarrierNext = ncclShmem.groups[0].barrier_next;
|
||||
uint64_t* mscclBarriers = &ncclShmem.groups[0].barrier;
|
||||
|
||||
// initialize mscclShmem.mscclTB
|
||||
threadBlockCopy(
|
||||
(uint64_t *)&mscclShmem.mscclTB, (uint64_t *)(algo->mscclTBs + bid),
|
||||
sizeof(struct mscclThreadBlock), tid, nthreads);
|
||||
__synclds(); // publish mscclShmem.mscclTB.channelId
|
||||
|
||||
// initialize ncclShmem and mscclShmem.work
|
||||
int channelId = mscclShmem.mscclTB.channelId;
|
||||
{
|
||||
void *dst, *src;
|
||||
int bytes;
|
||||
// Use first 3 warps to load comm, channel, and work into shmem
|
||||
switch (tid/WARP_SIZE) {
|
||||
case 0:
|
||||
dst = &ncclShmem.comm;
|
||||
src = comm;
|
||||
bytes = sizeof(ncclDevComm);
|
||||
static_assert(sizeof(ncclDevComm) <= sizeof(uint64_t) * WARP_SIZE, "ncclDevComm cannot be loaded by a single warp in one insn.");
|
||||
break;
|
||||
case 1:
|
||||
// Get address of channel without incurring indirect load from ncclDevComm::channels
|
||||
dst = &ncclShmem.channel;
|
||||
src = &((ncclDevCommAndChannels*)comm)->channels[channelId];
|
||||
bytes = sizeof(ncclDevChannel);
|
||||
static_assert(sizeof(ncclDevChannel) <= sizeof(uint64_t) * WARP_SIZE, "ncclDevChannel cannot be loaded by a single warp in one insn.");
|
||||
break;
|
||||
case 2:
|
||||
dst = &mscclShmem.work;
|
||||
src = &work;
|
||||
bytes = sizeof(mscclWork);
|
||||
static_assert(sizeof(mscclWork) <= sizeof(uint64_t) * WARP_SIZE, "mscclWork cannot be loaded by a single warp in one insn.");
|
||||
break;
|
||||
default:
|
||||
bytes = 0;
|
||||
break;
|
||||
}
|
||||
copyToShmem8(tid%WARP_SIZE, dst, src, bytes);
|
||||
}
|
||||
__synclds(); // publish shmem
|
||||
|
||||
// Dereference reduce args if required
|
||||
if (tid == 0 && mscclShmem.work.hasReduce && mscclShmem.work.redOpArgIsPtr) {
|
||||
switch (sizeof(T)) {
|
||||
case 1:
|
||||
mscclShmem.work.redOpArg = *reinterpret_cast<uint8_t*>(mscclShmem.work.redOpArg);
|
||||
break;
|
||||
case 2:
|
||||
mscclShmem.work.redOpArg = *reinterpret_cast<uint16_t*>(mscclShmem.work.redOpArg);
|
||||
break;
|
||||
case 4:
|
||||
mscclShmem.work.redOpArg = *reinterpret_cast<uint32_t*>(mscclShmem.work.redOpArg);
|
||||
break;
|
||||
case 8:
|
||||
mscclShmem.work.redOpArg = *reinterpret_cast<uint64_t*>(mscclShmem.work.redOpArg);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
__synclds(); // publish shmem
|
||||
|
||||
// User pointers for primitives
|
||||
T* thisInput = (T*)mscclShmem.work.sendBuff;
|
||||
T* thisOutput = (T*)mscclShmem.work.recvBuff;
|
||||
T* thisScratch = (T*)mscclShmem.work.scratchBuffer;
|
||||
int recvPeer = mscclShmem.mscclTB.recvPeer;
|
||||
int sendPeer = mscclShmem.mscclTB.sendPeer;
|
||||
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? MSCCL_CHUNKSTEPS : 1));
|
||||
int minChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_LL)
|
||||
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T));
|
||||
if (Proto::Id == NCCL_PROTO_LL128) {
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2;
|
||||
}
|
||||
|
||||
RedOp redFn(mscclShmem.work.redOpArg);
|
||||
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0> prims
|
||||
(tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg);
|
||||
|
||||
const ssize_t sizePerMscclChunk = mscclShmem.work.count / mscclShmem.work.nChunksPerLoop;
|
||||
uint32_t maxAllowedCount = mscclShmem.work.maxAllowedCount;
|
||||
|
||||
// msccl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
|
||||
// this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
|
||||
const int64_t workIndex = mscclShmem.work.workIndex;
|
||||
volatile struct mscclFlag* mscclFlags = mscclShmem.work.syncFlags;
|
||||
for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerMscclChunk; gridOffset += chunkSize, iter++) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, sizePerMscclChunk-gridOffset);
|
||||
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else
|
||||
realChunkSize = min(chunkSize, divUp(sizePerMscclChunk-gridOffset, minChunkSize)*minChunkSize);
|
||||
realChunkSize = int(realChunkSize);
|
||||
int nelem = min(realChunkSize, sizePerMscclChunk-gridOffset);
|
||||
|
||||
ssize_t srcOffset, dstOffset;
|
||||
T *srcPointer, *dstPointer;
|
||||
int step = 0;
|
||||
for (int i = 0; i < mscclShmem.mscclTB.nSteps; i++){
|
||||
struct mscclTransmission* t = &mscclShmem.mscclTB.transmissions[i];
|
||||
// first wait if there is a dependence
|
||||
int16_t numDependencies = t->numDependencies;
|
||||
if (numDependencies > 0){
|
||||
if (tid < numDependencies) {
|
||||
int16_t dependentPointer = t->dependencePointer;
|
||||
int8_t dependentBid = mscclShmem.mscclTB.dependentBid[dependentPointer+tid];
|
||||
int16_t dependentStep = mscclShmem.mscclTB.dependentStep[dependentPointer+tid];
|
||||
uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
|
||||
while ((mscclFlags + dependentBid)->flag < goalFlag);
|
||||
}
|
||||
step += numDependencies-1;
|
||||
barrier(nthreads, mscclBarrierNext, mscclBarriers);
|
||||
}
|
||||
|
||||
srcPointer = (t->srcBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->srcBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
|
||||
dstPointer = (t->dstBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->dstBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
|
||||
prims.setDataPtrs(srcPointer, dstPointer);
|
||||
int count = t->count;
|
||||
for (int c = 0; c < count; c += maxAllowedCount) {
|
||||
srcOffset = gridOffset + (ssize_t) (t->srcOffset+c) * sizePerMscclChunk;
|
||||
dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk;
|
||||
int thisCount = min(maxAllowedCount, count - c);
|
||||
int thisNelem = nelem * thisCount;
|
||||
if (t->type == MSCCL_SEND)
|
||||
prims.sendWithBarrier(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
|
||||
else if (t->type == MSCCL_RECV)
|
||||
prims.recv(dstOffset, thisNelem);
|
||||
else if (t->type == MSCCL_REDUCE) {
|
||||
int numReductions = t->numReductions;
|
||||
if (thisNelem < nthreads){
|
||||
if (tid < thisNelem){
|
||||
dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk;
|
||||
T* dstIndex = dstPointer + dstOffset + tid;
|
||||
T o = load(dstIndex);
|
||||
for (int r = 0; r < numReductions; r++){
|
||||
srcOffset = gridOffset + (ssize_t) (mscclShmem.mscclTB.reductionSrcOffsets[t->reductionPointer+r]+c) * sizePerMscclChunk;
|
||||
T t = load(srcPointer + srcOffset + tid);
|
||||
o = redFn(t,o);
|
||||
}
|
||||
store(dstIndex, o);
|
||||
}
|
||||
barrier(nthreads, mscclBarrierNext, mscclBarriers);
|
||||
} else {
|
||||
T* srcs[MSCCL_MAX_REDUCE_FUSION+1]; // +1 is for SIMPLE protocol as dst is added in the list of srcs
|
||||
dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk;
|
||||
T* dst = dstPointer + dstOffset;
|
||||
for (int r = 0; r < numReductions; r++) {
|
||||
srcOffset = gridOffset + (ssize_t) (mscclShmem.mscclTB.reductionSrcOffsets[t->reductionPointer+r]+c) * sizePerMscclChunk;
|
||||
srcs[r] = srcPointer + srcOffset;
|
||||
}
|
||||
prims.reduce(srcs, numReductions, &dst, 1, thisNelem);
|
||||
}
|
||||
if (c == 0) step += (numReductions-1); // only advance step once!
|
||||
} else if (t->type == MSCCL_RECV_COPY_SEND)
|
||||
prims.recvCopySend(dstOffset, thisNelem);
|
||||
else if (t->type == MSCCL_RECV_REDUCE_SEND)
|
||||
prims.recvReduceSend(srcOffset, thisNelem);
|
||||
else if (t->type == MSCCL_RECV_REDUCE_COPY_SEND)
|
||||
prims.recvReduceCopySend(srcOffset, dstOffset, thisNelem);
|
||||
else if (t->type == MSCCL_RECV_REDUCE_COPY)
|
||||
prims.recvReduceCopy(srcOffset, dstOffset, thisNelem);
|
||||
else if (t->type == MSCCL_LOCAL_COPY)
|
||||
prims.localCopy(srcPointer+srcOffset, dstPointer+dstOffset, thisNelem);
|
||||
else
|
||||
return;
|
||||
}
|
||||
if (t->hasDependence && tid == nthreads-1){
|
||||
mscclFlags[bid].flag = (uint64_t) COMPUTE_FLAG(workIndex, iter, step);
|
||||
}
|
||||
step++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL>(comm, algo, work); \
|
||||
} \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128>(comm, algo, work); \
|
||||
} \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>>(comm, algo, work); \
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t)
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv)
|
||||
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC()
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -363,6 +364,20 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
// Unpack the 64-bit word `val` into EltPerLine elements of T and write them to `dst`
// with store(). Element 0 is always written; element i > 0 is written only when i < eltN.
__device__ void mscclStoreData(T *dst, uint64_t val, int eltN) {
  union {
    uint64_t u8;
    T elt[EltPerLine];
  };
  u8 = val;
  #pragma unroll
  for (int idx = 0; idx < EltPerLine; idx++) {
    const bool skip = (idx != 0) && (idx >= eltN);
    if (!skip) {
      store(dst+idx, elt[idx]);
    }
  }
}
|
||||
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
@@ -464,6 +479,69 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
|
||||
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
T *srcElts = srcs[0];
|
||||
T *dstElts = dsts[0];
|
||||
nelem -= tid*EltPerLine;
|
||||
srcElts += tid*EltPerLine;
|
||||
dstElts += tid*EltPerLine;
|
||||
if (MULTISRCS){
|
||||
for (int i = 1; i < nsrcs; i++){
|
||||
srcs[i] += tid*EltPerLine;
|
||||
}
|
||||
}
|
||||
if (MULTIDSTS){
|
||||
for (int i = 1; i < ndsts; i++){
|
||||
dsts[i] += tid*EltPerLine;
|
||||
}
|
||||
}
|
||||
int offset = tid;
|
||||
int eltPerTrip = nthreads*EltPerLine;
|
||||
while (nelem > 0) {
|
||||
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
|
||||
|
||||
DataLoader dl;
|
||||
uint64_t data;
|
||||
dl.loadBegin(srcElts, eltInLine);
|
||||
srcElts += eltPerTrip;
|
||||
data = dl.loadFinish();
|
||||
if (REDUCE) {
|
||||
uint64_t dataD;
|
||||
dl.loadBegin(dstElts, eltInLine);
|
||||
dataD = dl.loadFinish();
|
||||
dataD = MULTI<RedOp,T>()(redOp, dataD, data);
|
||||
if (MULTISRCS){
|
||||
for (int i = 1; i < nsrcs; i++){
|
||||
dl.loadBegin(srcs[i], eltInLine);
|
||||
srcs[i] += eltPerTrip;
|
||||
data = dl.loadFinish();
|
||||
dataD = MULTI<RedOp,T>()(redOp, dataD, data);
|
||||
}
|
||||
}
|
||||
mscclStoreData(dstElts, dataD, eltInLine);
|
||||
dstElts += eltPerTrip;
|
||||
}
|
||||
if (COPY){
|
||||
mscclStoreData(dstElts, data, eltInLine);
|
||||
dstElts += eltPerTrip;
|
||||
if (MULTIDSTS){
|
||||
for (int i = 1; i < ndsts; i++){
|
||||
dl.loadBegin(srcs[i], eltInLine);
|
||||
srcs[i] += eltPerTrip;
|
||||
data = dl.loadFinish();
|
||||
mscclStoreData(dsts[i], data, eltInLine);
|
||||
dsts[i] += eltPerTrip;
|
||||
}
|
||||
}
|
||||
}
|
||||
nelem -= eltPerTrip;
|
||||
offset += nthreads;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
|
||||
recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
|
||||
recvStep[i] = conn->step;
|
||||
@@ -662,4 +740,21 @@ private:
|
||||
__device__ void recvSend(int eltN) {
|
||||
return LLGenericOp<1, 1, -1, -1>(-1, -1, eltN, false);
|
||||
}
|
||||
|
||||
// MSCCL primitives
|
||||
__device__ void sendWithBarrier(intptr_t inpIx, int eltN) {
|
||||
send(inpIx, eltN);
|
||||
// This is the only primitive.instruction where there is no barrier at the end, add it
|
||||
barrier();
|
||||
}
|
||||
__device__ void localCopy(T* srcs, T* dsts, int eltN) {
|
||||
return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
|
||||
}
|
||||
// MSCCL local reduce: dispatches to mscclGenericOp with REDUCE=1, COPY=0, MULTIDSTS=0.
// A single source uses the MULTISRCS=0 instantiation (with nsrcs forced to 1); any other
// source count uses MULTISRCS=1 and forwards nsrcs. Only one destination is ever forwarded;
// `ndsts` is not used.
__device__ void reduce(T** srcs, int nsrcs, T** dsts, int ndsts, int eltN) {
  if (nsrcs != 1) {
    return mscclGenericOp<1,0,1,0>(srcs, nsrcs, dsts, 1, eltN);
  }
  return mscclGenericOp<1,0,0,0>(srcs, 1, dsts, 1, eltN);
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -370,6 +371,84 @@ private:
|
||||
if (RECV) postRecv();
|
||||
}
|
||||
|
||||
// MSCCL local reduce/copy between user buffers for the LL128 primitives.
// Template flags:
//   REDUCE    : dsts[0] = redOp(srcs[0], dsts[0]), additionally folding in srcs[1..nsrcs-1]
//               when MULTISRCS is set.
//   COPY      : dsts[0] = srcs[0]; when MULTIDSTS is set, also dsts[i] = srcs[i] for
//               i in [1, ndsts).
//   MULTISRCS : srcs[1..nsrcs-1] are valid and are advanced together with srcs[0].
//   MULTIDSTS : dsts[1..ndsts-1] are valid and are advanced together with dsts[0].
// Work is split in slices of DataEltPerSlice elements: warp `warp` starts at slice `warp`
// and strides by the number of warps. A negative nelem is treated as 0.
// Side effect: the srcs[i] / dsts[i] entries (i >= 1) in the caller's arrays are advanced
// in place. Ends with barrier().
// Changes vs. the previous version: the MULTIDSTS copy loop is now bounded by `ndsts`
// (it indexes dsts[], and every other dsts[] loop here uses ndsts) instead of `nsrcs`,
// and the never-read local `wireOffset` was removed.
// NOTE(review): the MULTIDSTS copy reads srcs[i], so callers must also provide at least
// ndsts sources; srcs[i] is only advanced when MULTISRCS is set -- confirm intended usage.
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
  T const *srcPtr = srcs[0];
  T *dstPtr = dsts[0];
  const int nwarps = nthreads/WARP_SIZE;
  nelem = nelem < 0 ? 0 : nelem;

  // Move every pointer to this warp's first slice.
  nelem -= DataEltPerSlice*warp;
  srcPtr += DataEltPerSlice*warp;
  dstPtr += DataEltPerSlice*warp;
  if (MULTISRCS){
    for (int i = 1; i < nsrcs; i++){
      srcs[i] += DataEltPerSlice*warp;
    }
  }
  if (MULTIDSTS){
    for (int i = 1; i < ndsts; i++){
      dsts[i] += DataEltPerSlice*warp;
    }
  }
  while (nelem > 0) {
    const int eltInSlice = min(nelem, DataEltPerSlice);
    uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
    loadRegsBegin(regs, srcPtr, eltInSlice);
    loadRegsFinish(regs);
    if (REDUCE){
      uint64_t regsD[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
      loadRegsBegin(regsD, dstPtr, eltInSlice);
      loadRegsFinish(regsD);
      #pragma unroll
      for (int u=0; u<NCCL_LL128_SHMEM_ELEMS_PER_THREAD; u+=2) {
        regsD[u] = MULTI<RedOp, T>()(redOp, regs[u], regsD[u]);
        // The odd word is skipped on flag threads.
        if (!flagThread)
          regsD[u+1] = MULTI<RedOp, T>()(redOp, regs[u+1], regsD[u+1]);
      }
      if (MULTISRCS){
        for (int i = 1; i < nsrcs; i++){
          loadRegsBegin(regs, srcs[i], eltInSlice);
          loadRegsFinish(regs);
          for (int u=0; u<NCCL_LL128_SHMEM_ELEMS_PER_THREAD; u+=2) {
            regsD[u] = MULTI<RedOp, T>()(redOp, regs[u], regsD[u]);
            if (!flagThread)
              regsD[u+1] = MULTI<RedOp, T>()(redOp, regs[u+1], regsD[u+1]);
          }
        }
      }
      storeRegs(dstPtr, regsD, eltInSlice);
    }
    if (COPY){
      storeRegs(dstPtr, regs, eltInSlice);
      if (MULTIDSTS){
        // Bounded by ndsts: this loop writes dsts[i].
        for (int i = 1; i < ndsts; i++){
          loadRegsBegin(regs, srcs[i], eltInSlice);
          loadRegsFinish(regs);
          storeRegs(dsts[i], regs, eltInSlice);
        }
      }
    }

    // Advance to this warp's next slice.
    srcPtr += DataEltPerSlice*nwarps;
    dstPtr += DataEltPerSlice*nwarps;
    if (MULTISRCS){
      for (int i = 1; i < nsrcs; i++){
        srcs[i] += DataEltPerSlice*nwarps;
      }
    }
    if (MULTIDSTS){
      for (int i = 1; i < ndsts; i++){
        dsts[i] += DataEltPerSlice*nwarps;
      }
    }
    nelem -= DataEltPerSlice*nwarps;
  }
  barrier();
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
|
||||
recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
|
||||
recvStep[i] = conn->step;
|
||||
@@ -477,4 +556,19 @@ public:
|
||||
// Forwards eltN elements through GenericOp with no user source/destination
// buffer (both buffer selectors and both offsets are -1).
__device__ void recvSend(int eltN) {
  GenericOp<1, 1, -1, -1>(-1, -1, eltN, false);
}
|
||||
|
||||
// MSCCL primitives
|
||||
// MSCCL send primitive. In this protocol class it is a plain alias of send():
// no additional synchronization is issued here.
__device__ void sendWithBarrier(intptr_t inpIx, int eltN) {
  send(inpIx, eltN);
}
|
||||
// Copies eltN elements from `srcs` to `dsts` using the copy-only,
// single-source/single-destination instantiation of mscclGenericOp.
__device__ void localCopy(T* srcs, T* dsts, int eltN) {
  mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
}
|
||||
// MSCCL local reduction entry point.
// A single source uses the single-source instantiation of mscclGenericOp;
// anything else goes through the multi-source one. `ndsts` is not read:
// one destination is always forwarded.
__device__ void reduce(T** srcs, int nsrcs, T** dsts, int ndsts, int eltN) {
  if (nsrcs != 1) {
    mscclGenericOp<1,0,1,0>(srcs, nsrcs, dsts, 1, eltN);
    return;
  }
  mscclGenericOp<1,0,0,0>(srcs, 1, dsts, 1, eltN);
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -9,6 +10,8 @@
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct,
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p>
|
||||
class Primitives<
|
||||
@@ -363,6 +366,35 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
// MSCCL local reduce/copy built on ReduceOrCopyMulti.
//
// Only threads with tid < nworkers do data movement; every thread reaches the
// final barrier().
//
// REDUCE: dsts[0] is appended to the source list (written to srcs[nsrcs], so the
//         caller's array needs room for one extra entry - NOTE(review): confirm at
//         call sites) and the combined sources are reduced into dsts[0].
// COPY:   srcs[0] is copied to dsts[0]; with MULTISRCS each further srcs[i] is
//         copied to the matching dsts[i].
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
  if (nelem < 0) nelem = 0;

  const bool isWorker = tid < nworkers;
  if (isWorker) {
    if (REDUCE) {
      // The destination is itself an operand of the reduction.
      srcs[nsrcs++] = dsts[0];
      if (!MULTISRCS) {
        ReduceOrCopyMulti<Unroll, RedOp, T, 2, 2, 1, 1, 0>
          (tid, nworkers, ncclShmem.redOpArgs, false, 2, (T const**)srcs, 1, (T**)dsts, nelem);
      } else {
        ReduceOrCopyMulti<Unroll, RedOp, T, 3, MSCCL_MAX_REDUCE_FUSION, 1, 1, 0>
          (tid, nworkers, ncclShmem.redOpArgs, false, nsrcs, (T const**)srcs, 1, (T**)dsts, nelem);
      }
    }
    if (COPY) {
      ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, 0>
        (tid, nworkers, ncclShmem.redOpArgs, false, 1, (T const**)srcs, 1, (T**)dsts, nelem);
      if (MULTISRCS) {
        // Pairwise copies for the remaining source/destination entries.
        for (int s = 1; s < nsrcs; ++s) {
          ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, 0>
            (tid, nworkers, ncclShmem.redOpArgs, false, 1, (T const**)&srcs[s], 1, (T**)&dsts[s], nelem);
        }
      }
    }
  }
  barrier();
}
|
||||
|
||||
// Scatter/Gather generic op
|
||||
// skip: my own rank order in the buffer chunks
|
||||
// shift: peer offset to avoid all ranks sending to or receiving from same peer
|
||||
@@ -665,6 +697,12 @@ private:
|
||||
userBuff += delta;
|
||||
}
|
||||
|
||||
// Re-points userBuff for MSCCL according to this thread's role flags.
// When both RoleInput and RoleOutput are set, the output buffer takes
// precedence; when neither is set, userBuff is left unchanged.
__device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
  if (flags & RoleOutput) {
    userBuff = (T*)outputBuf;
  } else if (flags & RoleInput) {
    userBuff = (T*)inputBuf;
  }
}
|
||||
|
||||
// Sends eltN elements read from the Input buffer at offset inpIx.
// No destination buffer is involved (-1 selectors/offsets) and postOp is off.
__device__ __forceinline__ void send(intptr_t inpIx, int eltN) {
  const bool postOp = false;
  genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, -1, eltN, postOp);
}
|
||||
@@ -742,4 +780,19 @@ private:
|
||||
// Receive-and-forward of eltN elements with no user buffer on either side
// (all buffer selectors and offsets are -1); postOp is off.
__device__ __forceinline__ void recvSend(int eltN) {
  const bool postOp = false;
  genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, postOp);
}
|
||||
|
||||
// MSCCL primitives
|
||||
// MSCCL send primitive. Here it simply delegates to send(); this wrapper
// itself issues no extra barrier.
__device__ __forceinline__ void sendWithBarrier(intptr_t inpIx, int eltN) {
  send(inpIx, eltN);
}
|
||||
// Copies eltN elements from `srcs` to `dsts` via the copy-only,
// single-source/single-destination instantiation of mscclGenericOp.
__device__ __forceinline__ void localCopy(T* srcs, T* dsts, int eltN) {
  mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
}
|
||||
// MSCCL local reduction entry point.
// One source -> single-source instantiation; otherwise the multi-source one.
// `ndsts` is not read: exactly one destination is forwarded in both cases.
__device__ __forceinline__ void reduce(T** srcs, int nsrcs, T** dsts, int ndsts, int eltN) {
  if (nsrcs != 1) {
    mscclGenericOp<1,0,1,0>(srcs, nsrcs, dsts, 1, eltN);
    return;
  }
  mscclGenericOp<1,0,0,0>(srcs, 1, dsts, 1, eltN);
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -8,10 +9,18 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclGather, const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream);
|
||||
ncclResult_t ncclGather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream) {
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream);
|
||||
}
|
||||
|
||||
int nRanks;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
size_t rankOffset = sendcount * ncclTypeSize(datatype);
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#include "enqueue.h"
|
||||
#include "msccl/msccl_parser.h"
|
||||
#include "msccl/msccl_setup.h"
|
||||
#include "msccl/msccl_status.h"
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle);
|
||||
// Loads an MSCCL algorithm from an XML file and registers it under a free handle.
//
// mscclAlgoFilePath: path of the XML algorithm description.
// mscclAlgoHandle:   out-parameter; receives the handle of the loaded algorithm on success.
//
// Returns ncclInvalidArgument for null arguments, ncclInvalidUsage when no handle is
// free, the parser/allocator error on failure, and ncclSuccess otherwise.
// On any failure nothing is registered, no handle is consumed and all temporary
// allocations are released.
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle) {
  if (mscclAlgoFilePath == nullptr || mscclAlgoHandle == nullptr) {
    WARN("MSCCL: mscclLoadAlgo called with a null argument");
    return ncclInvalidArgument;
  }

  mscclStatus& status = mscclGetStatus();

  if (status.freeAlgoHandles.size() == 0) {
    WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS);
    return ncclInvalidUsage;
  }
  // Only peek at the handle for now; it is removed from the free list once the
  // algorithm has been loaded successfully, so a failed load does not lose it.
  const mscclAlgoHandle_t handle = *status.freeAlgoHandles.rbegin();

  struct mscclAlgo* hostAlgo = nullptr;
  struct mscclAlgo* devAlgo = nullptr;

  ncclResult_t res = ncclCalloc(&hostAlgo, 1);
  if (res == ncclSuccess) {
    res = mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, status.rank);
  }
  if (res == ncclSuccess) {
    res = ncclCudaCalloc(&devAlgo, 1);
  }
  if (res == ncclSuccess) {
    hipError_t copyErr = hipMemcpy(devAlgo, hostAlgo, sizeof(struct mscclAlgo), hipMemcpyHostToDevice);
    if (copyErr != hipSuccess) {
      WARN("MSCCL: failed to copy algorithm to device: %s", hipGetErrorString(copyErr));
      res = ncclUnhandledCudaError;
    }
  }

  if (res != ncclSuccess) {
    // Release with the same deallocators mscclUnloadAlgo uses for these buffers.
    if (devAlgo != nullptr) (void)hipFree(devAlgo);
    free(hostAlgo);
    return res;
  }

  status.freeAlgoHandles.pop_back();
  status.hostAlgos[handle] = hostAlgo;
  status.devAlgos[handle] = devAlgo;

  // Hand the handle back to the caller.
  *mscclAlgoHandle = handle;

  return ncclSuccess;
}
|
||||
|
||||
NCCL_API(ncclResult_t, mscclRunAlgo,
|
||||
const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
|
||||
void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
|
||||
size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
|
||||
mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream);
|
||||
// Runs a previously loaded MSCCL algorithm.
//
// The handle is looked up without inserting into the registries; an unknown handle
// (or one whose host/device copy is missing) is rejected with ncclInvalidArgument
// instead of passing null algorithm pointers to the setup stages.
// The setup stages then run in order: count, scratch, sync flags, connections,
// proxy, kernel. The first failing stage's result is returned.
// sendCounts/sDisPls/recvCounts/rDisPls/root/peer are accepted for the public
// interface but are not read by this function.
ncclResult_t mscclRunAlgo(
    const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
    void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
    size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
    mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream) {
  mscclStatus& status = mscclGetStatus();

  auto hostIt = status.hostAlgos.find(mscclAlgoHandle);
  auto devIt = status.devAlgos.find(mscclAlgoHandle);
  if (hostIt == status.hostAlgos.end() || devIt == status.devAlgos.end() ||
      hostIt->second == nullptr || devIt->second == nullptr) {
    WARN("MSCCL: mscclRunAlgo called with an algorithm handle that is not loaded");
    return ncclInvalidArgument;
  }
  struct mscclAlgo* hostAlgo = hostIt->second;
  struct mscclAlgo* devAlgo = devIt->second;

  NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));

  NCCLCHECK(mscclSetupScratch(hostAlgo, stream));

  NCCLCHECK(mscclSetupSyncFlags(stream));

  NCCLCHECK(mscclSetupConnections(hostAlgo, comm));

  NCCLCHECK(mscclSetupProxy(hostAlgo, comm));

  NCCLCHECK(mscclSetupKernel(sendBuff, recvBuff, count, dataType, op, hostAlgo, devAlgo, comm, stream));

  return ncclSuccess;
}
|
||||
|
||||
NCCL_API(ncclResult_t, mscclUnloadAlgo, mscclAlgoHandle_t mscclAlgoHandle);
|
||||
// Unloads an MSCCL algorithm and returns its handle to the free list.
//
// A handle that is not currently loaded is rejected with ncclInvalidArgument, so a
// double unload can no longer push a duplicate handle onto the free list.
// Host/device registry entries are always removed and the handle is always recycled
// for a loaded handle; if hipFree fails, that is reported afterwards as
// ncclUnhandledCudaError rather than leaving the bookkeeping half-updated.
ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle) {
  mscclStatus& status = mscclGetStatus();

  auto hostIt = status.hostAlgos.find(mscclAlgoHandle);
  auto devIt = status.devAlgos.find(mscclAlgoHandle);
  if (hostIt == status.hostAlgos.end() && devIt == status.devAlgos.end()) {
    WARN("MSCCL: mscclUnloadAlgo called with an algorithm handle that is not loaded");
    return ncclInvalidArgument;
  }

  if (hostIt != status.hostAlgos.end()) {
    free(hostIt->second);
    status.hostAlgos.erase(hostIt);
  }

  hipError_t freeErr = hipSuccess;
  if (devIt != status.devAlgos.end()) {
    freeErr = hipFree(devIt->second);
    status.devAlgos.erase(devIt);
  }

  status.freeAlgoHandles.push_back(mscclAlgoHandle);

  if (freeErr != hipSuccess) {
    WARN("MSCCL: failed to free device algorithm: %s", hipGetErrorString(freeErr));
    return ncclUnhandledCudaError;
  }
  return ncclSuccess;
}
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -7,11 +8,20 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
|
||||
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncReduce, "Reduce",
|
||||
sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
|
||||
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -7,11 +8,20 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
|
||||
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
|
||||
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
|
||||
REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -8,10 +9,18 @@
|
||||
#include "enqueue.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclScatter, const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, int root,
|
||||
ncclComm_t comm, hipStream_t stream);
|
||||
ncclResult_t ncclScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, int root,
|
||||
ncclComm_t comm, hipStream_t stream) {
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream);
|
||||
}
|
||||
|
||||
int nRanks;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
size_t rankOffset = recvcount * ncclTypeSize(datatype);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -8,11 +9,20 @@
|
||||
#include "collectives.h"
|
||||
#include "argcheck.h" // Need some checks here since we access comm
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
NCCL_API(ncclResult_t, ncclSend, const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream);
|
||||
ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncSend, "Send",
|
||||
NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
|
||||
1, 1 };
|
||||
@@ -28,6 +38,13 @@ NCCL_API(ncclResult_t, ncclRecv, void* recvbuff, size_t count, ncclDataType_t da
|
||||
ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) {
|
||||
NVTX3_FUNC_RANGE_IN(nccl_domain);
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
|
||||
}
|
||||
|
||||
struct ncclInfo info = { ncclFuncRecv, "Recv",
|
||||
NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
|
||||
1, 1 };
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -12,6 +13,8 @@
|
||||
#include "channel.h"
|
||||
#include <assert.h>
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
__thread int ncclGroupDepth = 0; // depth of ncclGroupStart nesting
|
||||
__thread ncclResult_t ncclGroupError = ncclSuccess;
|
||||
__thread struct ncclComm* ncclGroupCommHead = nullptr;
|
||||
@@ -100,6 +103,14 @@ exit:
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Opens one level of group nesting and, when MSCCL is available and is not the
// component issuing this call, notifies MSCCL of the group start.
// The depth counter is incremented before MSCCL is consulted.
ncclResult_t ncclGroupStartInternal() {
  ++ncclGroupDepth;
  const bool forwardToMsccl = mscclAvailable() && !mscclIsCaller();
  if (!forwardToMsccl) return ncclSuccess;
  NCCLCHECK(mscclGroupStart());
  return ncclSuccess;
}
|
||||
|
||||
NCCL_API(ncclResult_t, ncclGroupEnd);
|
||||
ncclResult_t ncclGroupEnd() {
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
@@ -386,6 +397,10 @@ ncclResult_t ncclGroupEndInternal() {
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
NCCLCHECK(mscclGroupEnd());
|
||||
}
|
||||
|
||||
if ((--ncclGroupDepth) > 0) goto exit;
|
||||
|
||||
if ((ret = ncclGroupError) != ncclSuccess) goto fail;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2017, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -67,11 +68,6 @@ extern __thread struct ncclComm* ncclGroupCommHead;
|
||||
extern __thread struct ncclComm* ncclGroupCommPreconnectHead;
|
||||
extern __thread int ncclGroupBlocking;
|
||||
|
||||
// Opens one level of group nesting; always succeeds.
inline ncclResult_t ncclGroupStartInternal() {
  ++ncclGroupDepth;
  return ncclSuccess;
}
|
||||
|
||||
inline ncclResult_t ncclGroupErrCheck(ncclResult_t ret) {
|
||||
if (ncclGroupDepth > 0) {
|
||||
if (ret != ncclSuccess && ret != ncclInProgress) ncclGroupError = ret;
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCL_KERNEL_H_
|
||||
#define MSCCL_KERNEL_H_
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto) mscclKernel_##devredop##_##type##_##proto
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto) \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork work);
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv)
|
||||
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC()
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,35 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCL_LIFECYCLE_H_
|
||||
#define MSCCL_LIFECYCLE_H_
|
||||
|
||||
#include "enqueue.h"
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
bool mscclEnabled();
|
||||
|
||||
void mscclSetIsCallerFlag();
|
||||
void mscclClearIsCallerFlag();
|
||||
bool mscclIsCaller();
|
||||
|
||||
bool mscclAvailable();
|
||||
|
||||
ncclResult_t mscclInit(ncclComm_t comm);
|
||||
|
||||
ncclResult_t mscclGroupStart();
|
||||
|
||||
ncclResult_t mscclEnqueueCheck(
|
||||
const void* sendbuff, const size_t sendcounts[], const size_t sdispls[],
|
||||
void* recvbuff, const size_t recvcounts[], const size_t rdispls[],
|
||||
size_t count, ncclDataType_t datatype, int root, int peer, ncclRedOp_t op,
|
||||
mscclFunc_t mscclFunc, ncclComm_t comm, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclGroupEnd();
|
||||
|
||||
ncclResult_t mscclTeardown();
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,103 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCL_PARSER_H_
|
||||
#define MSCCL_PARSER_H_
|
||||
|
||||
#include "nccl.h"
|
||||
#include "debug.h"
|
||||
#include "checks.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
// A few constraints to make the implementation easy
|
||||
#define MAX_STR_LEN 255
|
||||
#define MAX_ATTR_COUNT 16
|
||||
#define MAX_SUBS 1024
|
||||
#define MAX_NODES 4096
|
||||
|
||||
#define NODE_TYPE_NONE 0
|
||||
#define NODE_TYPE_OPEN 1
|
||||
#define NODE_TYPE_CLOSE 2
|
||||
#define NODE_TYPE_SINGLE 3
|
||||
|
||||
struct mscclXmlNode {
|
||||
char name[MAX_STR_LEN+1];
|
||||
struct {
|
||||
char key[MAX_STR_LEN+1];
|
||||
char value[MAX_STR_LEN+1];
|
||||
} attrs[MAX_ATTR_COUNT+1]; // Need an extra one to consume extra params
|
||||
int nAttrs;
|
||||
int type;
|
||||
struct mscclXmlNode* parent;
|
||||
struct mscclXmlNode* subs[MAX_SUBS];
|
||||
int nSubs;
|
||||
};
|
||||
|
||||
struct mscclXml {
|
||||
struct mscclXmlNode nodes[MAX_NODES];
|
||||
int maxIndex;
|
||||
};
|
||||
|
||||
// Finds the position of attribute `attrName` in `node`.
// *index receives the position of the first matching key, or -1 when there is none.
// Keys are compared over at most MAX_STR_LEN characters. Always returns ncclSuccess.
static ncclResult_t mscclXmlGetAttrIndex(struct mscclXmlNode* node, const char* attrName, int* index) {
  *index = -1;
  const int total = node->nAttrs;
  int pos = 0;
  while (pos < total && strncmp(node->attrs[pos].key, attrName, MAX_STR_LEN) != 0) {
    ++pos;
  }
  if (pos < total) {
    *index = pos;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Looks up attribute `attrName` on `node`.
// *value points at the stored value string, or is NULL when the attribute is absent.
// An absent attribute is not an error.
static ncclResult_t mscclXmlGetAttr(struct mscclXmlNode* node, const char* attrName, const char** value) {
  int pos;
  NCCLCHECK(mscclXmlGetAttrIndex(node, attrName, &pos));
  if (pos == -1) {
    *value = NULL;
  } else {
    *value = node->attrs[pos].value;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Like mscclXmlGetAttr, but a missing attribute is an error:
// a warning is logged and ncclInternalError is returned.
static ncclResult_t mscclXmlGetAttrStr(struct mscclXmlNode* node, const char* attrName, const char** value) {
  NCCLCHECK(mscclXmlGetAttr(node, attrName, value));
  if (*value != NULL) {
    return ncclSuccess;
  }
  WARN("Attribute %s of node %s not found", attrName, node->name);
  return ncclInternalError;
}
|
||||
// Reads a required attribute and parses it with strtol (base auto-detected).
// The parsed long is narrowed to int; no range or syntax validation is done here.
static ncclResult_t mscclXmlGetAttrInt(struct mscclXmlNode* node, const char* attrName, int* value) {
  const char* text;
  NCCLCHECK(mscclXmlGetAttrStr(node, attrName, &text));
  const long parsed = strtol(text, NULL, 0);
  *value = (int)parsed;
  return ncclSuccess;
}
|
||||
|
||||
// Reads a required attribute and parses it with strtoll (base auto-detected).
// No range or syntax validation is done here.
static ncclResult_t mscclXmlGetAttrInt64(struct mscclXmlNode* node, const char* attrName, int64_t* value) {
  const char* text;
  NCCLCHECK(mscclXmlGetAttrStr(node, attrName, &text));
  const long long parsed = strtoll(text, NULL, 0);
  *value = parsed;
  return ncclSuccess;
}
|
||||
|
||||
// Scans nodes [0, maxIndex) for the first node named `tagName`.
// *node receives that node, or NULL when none matches. Always returns ncclSuccess.
static ncclResult_t mscclXmlFindTag(struct mscclXml* xml, const char* tagName, struct mscclXmlNode** node) {
  *node = NULL;
  int i = 0;
  while (i < xml->maxIndex) {
    struct mscclXmlNode* candidate = &xml->nodes[i++];
    if (strcmp(candidate->name, tagName) == 0) {
      *node = candidate;
      break;
    }
  }
  return ncclSuccess;
}
|
||||
|
||||
ncclResult_t mscclGetAlgoFromXmlFile(const char* xmlGraphFile, struct mscclAlgo* algo, int rank);
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,28 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCL_SETUP_H_
|
||||
#define MSCCL_SETUP_H_
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "comm.h"
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupSyncFlags(hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm);
|
||||
|
||||
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType);
|
||||
|
||||
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm);
|
||||
|
||||
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
|
||||
ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
|
||||
ncclComm_t comm, hipStream_t stream);
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCL_STATUS_H_
|
||||
#define MSCCL_STATUS_H_
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
mscclStatus& mscclGetStatus();
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,209 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#ifndef MSCCL_STRUCT_H_
|
||||
#define MSCCL_STRUCT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "devcomm.h"
|
||||
|
||||
#define MSCCL_MAX_NUM_STEPS 256
|
||||
#define MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL 32
|
||||
#define MSCCL_MAX_NUM_THREAD_BLOCKS (MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL * MAXCHANNELS)
|
||||
#define MSCCL_MAX_COUNT 72 // max concurrent number of msccl chunk transmission
|
||||
#define MSCCL_MAX_REDUCE_FUSION 16
|
||||
#define MSCCL_MAX_NUM_ALGOS 1024
|
||||
|
||||
#define MSCCL_SLICESTEPS (NCCL_STEPS/4)
|
||||
#define MSCCL_CHUNKSTEPS (NCCL_STEPS/2)
|
||||
|
||||
#define MSCCL_INPUT_BUFFER 0
|
||||
#define MSCCL_OUTPUT_BUFFER 1
|
||||
#define MSCCL_SCRATCH_BUFFER 2
|
||||
|
||||
#define MSCCL_SEND 0
|
||||
#define MSCCL_RECV 1
|
||||
#define MSCCL_RECV_COPY_SEND 2
|
||||
#define MSCCL_RECV_REDUCE_SEND 3
|
||||
#define MSCCL_RECV_REDUCE_COPY 4
|
||||
#define MSCCL_RECV_REDUCE_COPY_SEND 5
|
||||
#define MSCCL_LOCAL_COPY 6
|
||||
#define MSCCL_REDUCE 7
|
||||
|
||||
// Collective / point-to-point function kinds known to MSCCL.
// mscclNumFuncs is the number of kinds, not a kind itself.
typedef enum {
  mscclFuncReduce        = 0,
  mscclFuncBroadcast     = 1,
  mscclFuncAllReduce     = 2,
  mscclFuncReduceScatter = 3,
  mscclFuncAllGather     = 4,
  mscclFuncSend          = 5,
  mscclFuncRecv          = 6,
  mscclFuncGather        = 7,
  mscclFuncScatter       = 8,
  mscclFuncAllToAll      = 9,
  mscclFuncAllToAllv     = 10,
  mscclNumFuncs          = 11
} mscclFunc_t;
|
||||
|
||||
// One step of an MSCCL thread block's schedule. Intended size: 16 bytes.
struct mscclTransmission {
  // Index of the first dependence of this step.
  int16_t dependencePointer;
  // Number of dependences; the last one is at dependencePointer + numDependencies.
  int16_t numDependencies;
  // Index where this step's reductions start.
  int16_t reductionPointer;
  // Number of reductions sharing the same destination.
  int16_t numReductions;
  int16_t srcOffset;
  int16_t dstOffset;
  // Source and destination buffer selectors (input / output / scratch), 4 bits each.
  uint8_t srcBuffer : 4;
  uint8_t dstBuffer : 4;
  int8_t hasDependence;
  uint8_t type;
  uint8_t count;
};
|
||||
|
||||
static_assert((1ULL << (8*sizeof(mscclTransmission::count))) - 1 > MSCCL_MAX_COUNT, "MSCCL_MAX_COUNT must representable by datatype of count");
|
||||
|
||||
struct mscclThreadBlock {
|
||||
// step is used to index into these arrays
|
||||
struct mscclTransmission transmissions[MSCCL_MAX_NUM_STEPS]; // 4KB
|
||||
int8_t dependentBid[MSCCL_MAX_NUM_STEPS]; // -1 if not dependent on any thread block, 256 bytes
|
||||
int16_t dependentStep[MSCCL_MAX_NUM_STEPS]; // 512 bytes
|
||||
int16_t reductionSrcOffsets[MSCCL_MAX_NUM_STEPS]; // 512 bytes
|
||||
int16_t sendPeer;
|
||||
int16_t recvPeer;
|
||||
uint16_t nSteps;
|
||||
int16_t channelId; // associated channel. -1 indicates a thread block with only local copies
|
||||
}; // 5384 bytes
|
||||
|
||||
static_assert(sizeof(struct mscclThreadBlock) % sizeof(uint64_t) == 0, "Sanity check: sizeof(struct mscclThreadBlock) \% sizeof(uint64_t) != 0");
|
||||
|
||||
// A 64-bit flag padded to 32 bytes; the padding is there to avoid false sharing
// between neighbouring flags.
struct mscclFlag {
  uint64_t flag;
  uint64_t align[3];
};
|
||||
|
||||
struct mscclChannelPeerInfo {
|
||||
int peer;
|
||||
// nTransmissionsOfCount[i]: number of transmissions with count i (in terms of msccl chunks)
|
||||
int nTransmissionsOfCount[MSCCL_MAX_COUNT + 1];
|
||||
int existingCounts[MSCCL_MAX_COUNT + 1];
|
||||
int nExistingCounts;
|
||||
};
|
||||
|
||||
struct mscclChannelInfo {
|
||||
struct mscclChannelPeerInfo sendPeerInfo[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
|
||||
int nSendPeers;
|
||||
struct mscclChannelPeerInfo recvPeerInfo[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
|
||||
int nRecvPeers;
|
||||
};
|
||||
|
||||
struct mscclAlgo {
|
||||
// number of chunks of input/output in each MSCCL algorithm loop
|
||||
int nChunksPerLoop;
|
||||
// the protocol that the algorithm needs to use
|
||||
int protocol;
|
||||
// number of channels needed by MSCCL algorithm
|
||||
int nChannels;
|
||||
// number of ranks required by this algorithm
|
||||
int nRanks;
|
||||
// number of necessary thread blocks
|
||||
int nBlocks;
|
||||
// number of scratch chunks that MSCCL will use
|
||||
int nScratchChunks;
|
||||
// need to times nRanks for all-gather, reduce-scatter and all-to-all
|
||||
int sizeMultiplier;
|
||||
// number of steps per chunk for this algorithm
|
||||
int chunkSteps;
|
||||
// number of steps per slice for this algorithm
|
||||
int sliceSteps;
|
||||
// bid is used as an index into this array
|
||||
struct mscclThreadBlock mscclTBs[MSCCL_MAX_NUM_THREAD_BLOCKS];
|
||||
// used to calculate proxy info
|
||||
struct mscclChannelInfo mscclChannels[MAXCHANNELS];
|
||||
// Whether the algorithm requires reduce operation
|
||||
bool hasReduce;
|
||||
// MSCCL function type
|
||||
mscclFunc_t func;
|
||||
// Min message size allowed for this algorithm.
|
||||
int64_t minBytes;
|
||||
// Max message size allowed for this algorithm, 0 for no limit.
|
||||
int64_t maxBytes;
|
||||
// Whether this algorithm is suitable for in-place.
|
||||
bool inPlace;
|
||||
// Whether this algorithm is suitable for out-of-place.
|
||||
bool outOfPlace;
|
||||
};
|
||||
|
||||
// Group-call state tracked by the MSCCL lifecycle code.
enum mscclGroupStatus {
  mscclNoGroup,            // not inside a group call
  mscclGroupSupportedOp,   // inside a group; entered from mscclNoGroup on group start
  mscclGroupUnsupportedOp  // inside a group whose operations fall back
};
|
||||
|
||||
struct mscclSchedulerParam {
|
||||
const void* sendBuff;
|
||||
const size_t* sendCounts;
|
||||
std::vector<size_t> savedSendCounts;
|
||||
const size_t* sDisPls;
|
||||
std::vector<size_t> savedSDisPls;
|
||||
void* recvBuff;
|
||||
const size_t* recvCounts;
|
||||
std::vector<size_t> savedRecvCounts;
|
||||
const size_t* rDisPls;
|
||||
std::vector<size_t> savedRDisPls;
|
||||
size_t count;
|
||||
ncclDataType_t dataType;
|
||||
int root;
|
||||
int peer;
|
||||
ncclRedOp_t op;
|
||||
mscclFunc_t func;
|
||||
bool scheduled;
|
||||
mscclAlgoHandle_t handle;
|
||||
ncclComm_t comm;
|
||||
hipStream_t stream;
|
||||
};
|
||||
|
||||
struct mscclStatus {
|
||||
std::vector<mscclAlgoHandle_t> freeAlgoHandles;
|
||||
std::map<mscclAlgoHandle_t, mscclAlgo *> hostAlgos;
|
||||
std::map<mscclAlgoHandle_t, mscclAlgo *> devAlgos;
|
||||
struct mscclFlag* syncFlags;
|
||||
void *scratchBuffer;
|
||||
uint64_t scratchBufferSize;
|
||||
size_t nBytes;
|
||||
int stepSize;
|
||||
int chunkSteps;
|
||||
int sliceSteps;
|
||||
int chunkSize;
|
||||
int chunkEffectiveSize;
|
||||
int rank;
|
||||
uint32_t workIndex;
|
||||
uint32_t maxAllowedCount;
|
||||
ncclDataType_t dataType;
|
||||
mscclGroupStatus groupStatus;
|
||||
int groupDepth;
|
||||
std::vector<struct mscclSchedulerParam> savedSchedulerParams;
|
||||
};
|
||||
|
||||
// Arguments of one MSCCL launch. The struct is 16-byte aligned.
struct alignas(16) mscclWork {
  volatile struct mscclFlag *syncFlags;
  void *scratchBuffer;
  const void *sendBuff;
  void *recvBuff;
  size_t count;
  uint64_t redOpArg;
  uint32_t workIndex;
  int nChunksPerLoop;
  uint32_t maxAllowedCount;
  bool hasReduce;
  bool redOpArgIsPtr;
};
|
||||
|
||||
struct mscclShmemData {
|
||||
struct mscclThreadBlock mscclTB;
|
||||
alignas(16) struct mscclWork work;
|
||||
};
|
||||
static_assert(offsetof(struct mscclShmemData, work) % 16 == 0, "mscclShmemData.work needs to be 16B aligned");
|
||||
|
||||
#endif
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -224,4 +225,8 @@ enum ncclProxyMsgType {
|
||||
ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize);
|
||||
ncclResult_t ncclProxyDestroy(struct ncclComm* comm);
|
||||
ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm);
|
||||
|
||||
enum { proxyRecv=0, proxySend=1 };
|
||||
ncclResult_t mscclSaveProxy(struct ncclChannel* channel, int type, int peer, struct ncclProxyOp* op, int connIndex);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -39,6 +40,8 @@
|
||||
//#include <hsa/hsa_ext_amd.h>
|
||||
// [/RCCL]
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
#define STR2(v) #v
|
||||
#define STR(v) STR2(v)
|
||||
|
||||
@@ -608,6 +611,10 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
|
||||
NCCLCHECK(ncclCudaCalloc(&tmpCommAndChans.comm.devProf, MAXCHANNELS*PROFILE_NUM_LAUNCHES), comm->sideStream);
|
||||
#endif
|
||||
|
||||
if (mscclEnabled()) {
|
||||
NCCLCHECK(mscclInit(comm));
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->deviceStream.cudaStream));
|
||||
CUDACHECK(cudaStreamSynchronize(comm->deviceStream.cudaStream));
|
||||
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
|
||||
@@ -1703,6 +1710,10 @@ static ncclResult_t commCleanup(ncclComm_t comm) {
|
||||
NCCLCHECK(NpKit::Shutdown());
|
||||
#endif
|
||||
|
||||
if (mscclEnabled()) {
|
||||
NCCLCHECK(mscclTeardown());
|
||||
}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "checks.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
#include "msccl/msccl_parser.h"
|
||||
#include "msccl/msccl_setup.h"
|
||||
#include "msccl/msccl_status.h"
|
||||
|
||||
RCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 0);
|
||||
static const char* mscclAlgoFilePathEnv = "MSCCL_ALGO_FILE_PATH";
|
||||
static std::atomic<bool> mscclInitialized;
|
||||
static bool mscclSchedulerTriedLoadAlgo = false;
|
||||
|
||||
bool mscclEnabled() {
|
||||
return rcclParamMscclEnabled();
|
||||
}
|
||||
|
||||
// Process-wide marker that is raised while saved calls are being replayed
// through the regular collective entry points.
static bool mscclIsCallerFlag = false;

// Raises the marker.
void mscclSetIsCallerFlag() { mscclIsCallerFlag = true; }

// Lowers the marker.
void mscclClearIsCallerFlag() { mscclIsCallerFlag = false; }

// Reads the marker.
bool mscclIsCaller() { return mscclIsCallerFlag; }
|
||||
|
||||
bool mscclAvailable() {
|
||||
return mscclEnabled() && mscclInitialized.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
// Sets up the MSCCL status for `comm`.
//
// When the communicator has more than one rank in this process, MSCCL is
// marked unavailable and ncclSuccess is returned. Otherwise the status is
// reset, the free-handle list is filled, and the sync flags are allocated on
// the device.
//
// The initialized flag is stored only after every step has succeeded, so that
// mscclAvailable() never reports true for a half-built status or after a
// failed allocation.
//
// Returns ncclSuccess, or the error from the device allocation.
ncclResult_t mscclInit(ncclComm_t comm) {
  if (comm->intraRanks > 1) {
    mscclInitialized.store(false, std::memory_order_release);
    INFO(NCCL_INIT, "MSCCL doesn't support multiple GPUs in one process and is not available");
    return ncclSuccess;
  }

  mscclStatus& status = mscclGetStatus();
  status.scratchBuffer = nullptr;
  status.scratchBufferSize = 0;
  status.rank = comm->rank;
  status.workIndex = 1;

  // Handles are stored in descending order, so handle 0 sits at the back.
  status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS);
  for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) {
    status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1;
  }

  NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));

  status.groupStatus = mscclNoGroup;
  status.groupDepth = 0;
  // Let the scheduler attempt to load the algorithm file again.
  mscclSchedulerTriedLoadAlgo = false;

  // Publish only a fully initialized status.
  mscclInitialized.store(true, std::memory_order_release);
  return ncclSuccess;
}
|
||||
|
||||
// Enters one more level of group nesting. If no group was active, the group
// status becomes mscclGroupSupportedOp; any other status is left untouched.
// Always returns ncclSuccess.
ncclResult_t mscclGroupStart() {
  mscclStatus& st = mscclGetStatus();
  ++st.groupDepth;
  const bool wasOutsideGroup = (st.groupStatus == mscclNoGroup);
  if (wasOutsideGroup) {
    st.groupStatus = mscclGroupSupportedOp;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Decides whether the call described by `param` can run on the loaded MSCCL
// algorithm.
//
// On the first call after initialization the algorithm file named by the
// MSCCL_ALGO_FILE_PATH environment variable (if set) is loaded. The call is
// scheduled only when the algorithm's function type, rank count, message-size
// limits and in-place/out-of-place support all match.
//
// param: in/out. `scheduled` is always written; `handle` is written only when
//        `scheduled` becomes true.
// Returns ncclSuccess (whether or not the call was scheduled), or the error
// from loading the algorithm file.
static ncclResult_t mscclScheduler(struct mscclSchedulerParam* param) {
  static bool algoAvailable = false;
  static mscclAlgoHandle_t loadedAlgoHandle;
  static mscclAlgo* loadedHostAlgo = nullptr;

  param->scheduled = false;

  if (!mscclSchedulerTriedLoadAlgo) {
    mscclSchedulerTriedLoadAlgo = true;
    // A new load attempt means the status was (re)initialized: forget whatever
    // was cached from an earlier lifetime so a stale algorithm pointer is
    // never used.
    algoAvailable = false;
    loadedHostAlgo = nullptr;
    const char* mscclAlgoFilePath = getenv(mscclAlgoFilePathEnv);
    if (mscclAlgoFilePath != nullptr) {
      NCCLCHECK(mscclLoadAlgo(mscclAlgoFilePath, &loadedAlgoHandle));
      mscclStatus& status = mscclGetStatus();
      // find() rather than operator[]: a missing handle must not insert a
      // null entry into the map.
      auto it = status.hostAlgos.find(loadedAlgoHandle);
      if (it != status.hostAlgos.end()) {
        loadedHostAlgo = it->second;
      }
      algoAvailable = (loadedHostAlgo != nullptr);
    }
  }
  if (!algoAvailable) {
    return ncclSuccess;
  }

  if (loadedHostAlgo->func != param->func) {
    return ncclSuccess;
  }

  if (loadedHostAlgo->nRanks != param->comm->nRanks) {
    return ncclSuccess;
  }

  // A non-positive chunk count cannot be used as a divisor below.
  if (loadedHostAlgo->nChunksPerLoop <= 0) {
    return ncclSuccess;
  }

  size_t nBytes = param->count * ncclTypeSize(param->dataType) * loadedHostAlgo->sizeMultiplier;
  bool msgSizeIsValid =
    param->count > 0 && (param->count % loadedHostAlgo->nChunksPerLoop) == 0 &&
    nBytes >= loadedHostAlgo->minBytes &&
    (loadedHostAlgo->maxBytes == 0 || nBytes <= loadedHostAlgo->maxBytes);
  if (!msgSizeIsValid) {
    return ncclSuccess;
  }

  // In-place detection depends on the collective: same pointer for the first
  // group, this rank's slot inside the larger buffer for the other two.
  bool isInPlace = false;
  if (param->func == mscclFuncReduce ||
      param->func == mscclFuncBroadcast ||
      param->func == mscclFuncAllReduce ||
      param->func == mscclFuncAllToAll ||
      param->func == mscclFuncAllToAllv) {
    isInPlace = param->sendBuff == param->recvBuff;
  } else if (param->func == mscclFuncAllGather ||
             param->func == mscclFuncGather) {
    isInPlace = (char*)param->sendBuff == (char*)param->recvBuff + param->comm->rank * param->count * ncclTypeSize(param->dataType);
  } else if (param->func == mscclFuncReduceScatter ||
             param->func == mscclFuncScatter) {
    isInPlace = (char*)param->recvBuff == (char*)param->sendBuff + param->comm->rank * param->count * ncclTypeSize(param->dataType);
  }
  bool inPlaceOutOfPlaceIsValid = isInPlace ? loadedHostAlgo->inPlace : loadedHostAlgo->outOfPlace;
  if (!inPlaceOutOfPlaceIsValid) {
    return ncclSuccess;
  }

  param->handle = loadedAlgoHandle;
  param->scheduled = true;
  return ncclSuccess;
}
|
||||
|
||||
// Copies the arguments of a collective call into `param`. Only the fields
// listed here are written; `scheduled`, `handle` and the saved* vectors are
// left as they were. Always returns ncclSuccess.
static ncclResult_t mscclSetSchedulerParam(
  const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
  void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
  size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
  mscclFunc_t func, ncclComm_t comm, hipStream_t stream,
  struct mscclSchedulerParam* param) {
  // Send side.
  param->sendBuff   = sendBuff;
  param->sendCounts = sendCounts;
  param->sDisPls    = sDisPls;

  // Receive side.
  param->recvBuff   = recvBuff;
  param->recvCounts = recvCounts;
  param->rDisPls    = rDisPls;

  // What to do with the data.
  param->count    = count;
  param->dataType = dataType;
  param->op       = op;
  param->func     = func;
  param->root     = root;
  param->peer     = peer;

  // Where to run.
  param->comm   = comm;
  param->stream = stream;

  return ncclSuccess;
}
|
||||
|
||||
// Takes private copies of the four per-rank count/displacement arrays
// (comm->nRanks entries each) and repoints `param` at the copies. Nothing is
// done when sendCounts is null; that pointer is the only one tested.
// Always returns ncclSuccess.
static ncclResult_t mscclSaveCountsAndDispls(struct mscclSchedulerParam* param) {
  if (!param->sendCounts) {
    return ncclSuccess;
  }

  const auto nRanks = param->comm->nRanks;

  param->savedSendCounts.assign(param->sendCounts, param->sendCounts + nRanks);
  param->sendCounts = param->savedSendCounts.data();

  param->savedSDisPls.assign(param->sDisPls, param->sDisPls + nRanks);
  param->sDisPls = param->savedSDisPls.data();

  param->savedRecvCounts.assign(param->recvCounts, param->recvCounts + nRanks);
  param->recvCounts = param->savedRecvCounts.data();

  param->savedRDisPls.assign(param->rDisPls, param->rDisPls + nRanks);
  param->rDisPls = param->savedRDisPls.data();

  return ncclSuccess;
}
|
||||
|
||||
// Runs every saved call through mscclRunAlgo, in order, stopping at the first
// failure.
//
// The saved list is emptied whether or not the run succeeded: a failed batch
// must not stay queued and be replayed by a later, unrelated call.
//
// Returns ncclSuccess, or the first error reported by mscclRunAlgo.
static ncclResult_t mscclRunSavedParams() {
  mscclStatus& status = mscclGetStatus();
  ncclResult_t result = ncclSuccess;
  for (auto& param : status.savedSchedulerParams) {
    result = mscclRunAlgo(
      param.sendBuff, param.sendCounts, param.sDisPls,
      param.recvBuff, param.recvCounts, param.rDisPls,
      param.count, param.dataType, param.root, param.peer, param.op, param.handle, param.comm, param.stream);
    if (result != ncclSuccess) {
      break;
    }
  }
  status.savedSchedulerParams.clear();
  // Report the failure (if any) through the usual checking macro.
  NCCLCHECK(result);
  return ncclSuccess;
}
|
||||
|
||||
// Issues one saved call through the regular collective entry point that
// matches its function type.
// Returns the entry point's error, or ncclInvalidUsage for an unknown type.
static ncclResult_t mscclFallBackOneParam(struct mscclSchedulerParam& param) {
  switch (param.func) {
    case mscclFuncReduce:
      NCCLCHECK(ncclReduce(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.op, param.root, param.comm, param.stream));
      break;
    case mscclFuncBroadcast:
      NCCLCHECK(ncclBroadcast(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.root, param.comm, param.stream));
      break;
    case mscclFuncAllReduce:
      NCCLCHECK(ncclAllReduce(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.op, param.comm, param.stream));
      break;
    case mscclFuncReduceScatter:
      NCCLCHECK(ncclReduceScatter(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.op, param.comm, param.stream));
      break;
    case mscclFuncAllGather:
      NCCLCHECK(ncclAllGather(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.comm, param.stream));
      break;
    case mscclFuncSend:
      NCCLCHECK(ncclSend(param.sendBuff, param.count, param.dataType,
        param.peer, param.comm, param.stream));
      break;
    case mscclFuncRecv:
      NCCLCHECK(ncclRecv(param.recvBuff, param.count, param.dataType,
        param.peer, param.comm, param.stream));
      break;
    case mscclFuncGather:
      NCCLCHECK(ncclGather(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.root, param.comm, param.stream));
      break;
    case mscclFuncScatter:
      NCCLCHECK(ncclScatter(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.root, param.comm, param.stream));
      break;
    case mscclFuncAllToAll:
      NCCLCHECK(ncclAllToAll(param.sendBuff, param.recvBuff, param.count, param.dataType,
        param.comm, param.stream));
      break;
    case mscclFuncAllToAllv:
      NCCLCHECK(ncclAllToAllv(
        param.sendBuff, param.sendCounts, param.sDisPls,
        param.recvBuff, param.recvCounts, param.rDisPls,
        param.dataType, param.comm, param.stream));
      break;
    default:
      WARN("Invalid MSCCL function type in saved parameter");
      return ncclInvalidUsage;
  }
  return ncclSuccess;
}

// Replays every saved call through the regular collective entry points, in
// order, stopping at the first failure.
//
// The caller flag is raised for the duration of the replay. Both the flag and
// the saved list are reset on every exit path, so an error can neither leave
// the flag stuck nor leave calls queued for a later replay.
//
// Returns ncclSuccess, or the first error encountered.
static ncclResult_t mscclFallBackSavedParams() {
  mscclStatus& status = mscclGetStatus();
  ncclResult_t result = ncclSuccess;

  mscclSetIsCallerFlag();
  for (auto& param : status.savedSchedulerParams) {
    result = mscclFallBackOneParam(param);
    if (result != ncclSuccess) {
      break;
    }
  }
  mscclClearIsCallerFlag();
  status.savedSchedulerParams.clear();

  NCCLCHECK(result);
  return ncclSuccess;
}
|
||||
|
||||
// Entry point for a collective call: records the call, then either runs it on
// an MSCCL algorithm, queues it for the end of the current group, or falls
// back to the regular collective.
//
// - Outside a group: a schedulable call is run immediately; anything else
//   falls back.
// - Inside a group that has so far only seen schedulable calls: a schedulable
//   call is queued (with private copies of its count/displacement arrays). The
//   first call that cannot be scheduled switches the group to
//   mscclGroupUnsupportedOp and everything queued so far falls back with it.
// - Inside a group already marked mscclGroupUnsupportedOp: the call falls back.
//
// The scheduler is consulted only when `stream` is not being captured.
//
// Returns ncclSuccess, ncclInvalidUsage for an unknown group status, or the
// first error from the helpers.
ncclResult_t mscclEnqueueCheck(
    const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
    void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
    size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
    mscclFunc_t func, ncclComm_t comm, hipStream_t stream) {
  mscclStatus& status = mscclGetStatus();

  status.savedSchedulerParams.push_back({});
  NCCLCHECK(mscclSetSchedulerParam(
    sendBuff, sendCounts, sDisPls, recvBuff, recvCounts, rDisPls,
    count, dataType, root, peer, op, func, comm, stream,
    &status.savedSchedulerParams.back()));

  switch (status.groupStatus) {
    case mscclNoGroup:
    case mscclGroupSupportedOp: {
      const bool inGroup = (status.groupStatus == mscclGroupSupportedOp);
      hipStreamCaptureStatus captureStatus;
      unsigned long long pid;
      CUDACHECK(hipStreamGetCaptureInfo(stream, &captureStatus, &pid));
      if (captureStatus == hipStreamCaptureStatusNone) {
        NCCLCHECK(mscclScheduler(&status.savedSchedulerParams.back()));
        if (status.savedSchedulerParams.back().scheduled) {
          if (inGroup) {
            // Only save counts and displs when there is a suitable MSCCL
            // algorithm for this call; it runs when the group ends.
            NCCLCHECK(mscclSaveCountsAndDispls(&status.savedSchedulerParams.back()));
          } else {
            NCCLCHECK(mscclRunSavedParams());
          }
          break;
        }
      }
      if (inGroup) {
        // One call in the group has to fall back, so the rest of the group
        // must fall back too; otherwise the group would mix both paths.
        status.groupStatus = mscclGroupUnsupportedOp;
      }
      NCCLCHECK(mscclFallBackSavedParams());
      break;
    }
    case mscclGroupUnsupportedOp:
      NCCLCHECK(mscclFallBackSavedParams());
      break;
    default:
      return ncclInvalidUsage;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Leaves one level of group nesting. When the outermost group closes, the
// queued calls are run (only if the group stayed in mscclGroupSupportedOp) and
// the group status returns to mscclNoGroup.
//
// The status is reset even when running the queued calls fails; otherwise
// later non-group calls would be treated as if a group were still open.
//
// Returns ncclSuccess, or the error from running the queued calls.
ncclResult_t mscclGroupEnd() {
  mscclStatus& status = mscclGetStatus();
  status.groupDepth--;
  if (status.groupDepth != 0) {
    return ncclSuccess;
  }

  ncclResult_t result = ncclSuccess;
  if (status.groupStatus == mscclGroupSupportedOp) {
    result = mscclRunSavedParams();
  }
  status.groupStatus = mscclNoGroup;

  NCCLCHECK(result);
  return ncclSuccess;
}
|
||||
|
||||
// Releases everything owned by the MSCCL status and marks MSCCL as not
// initialized. Does nothing when MSCCL is not initialized.
//
// Every pointer is removed from the status before it is released. If a device
// free fails and this function returns early, the status holds no pointer that
// has already been freed, so a repeated teardown cannot free anything twice.
//
// Returns ncclSuccess, or the error from a failed device free (in which case
// the initialized flag is left set).
ncclResult_t mscclTeardown() {
  if (!mscclInitialized.load(std::memory_order_acquire)) {
    return ncclSuccess;
  }
  mscclStatus& status = mscclGetStatus();

  // Host-side algorithm descriptions.
  for (auto& entry : status.hostAlgos) {
    free(entry.second);
  }
  status.hostAlgos.clear();
  status.freeAlgoHandles.clear();

  // Device-side algorithm descriptions: detach each entry, then free it.
  while (!status.devAlgos.empty()) {
    auto it = status.devAlgos.begin();
    mscclAlgo* devAlgo = it->second;
    status.devAlgos.erase(it);
    CUDACHECK(hipFree(devAlgo));
  }

  void* scratchBuffer = status.scratchBuffer;
  status.scratchBuffer = nullptr;
  status.scratchBufferSize = 0;
  CUDACHECK(hipFree(scratchBuffer));

  struct mscclFlag* syncFlags = status.syncFlags;
  status.syncFlags = nullptr;
  CUDACHECK(hipFree(syncFlags));

  status.workIndex = 1;
  mscclInitialized.store(false, std::memory_order_release);
  return ncclSuccess;
}
|
||||
@@ -0,0 +1,703 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <ctype.h>
|
||||
#include "core.h"
|
||||
#include "collectives.h"
|
||||
#include "msccl/msccl_parser.h"
|
||||
|
||||
// Reads a single byte from `file` into `*c`.
// Returns ncclInternalError (after a warning) when nothing could be read.
ncclResult_t mscclXmlGetChar(FILE* file, char* c) {
  const size_t nRead = fread(c, 1, 1, file);
  if (nRead != 0) {
    return ncclSuccess;
  }
  WARN("XML Parse : Unexpected EOF");
  return ncclInternalError;
}
|
||||
|
||||
// Reads an attribute value from `file`.
//
// The value must start with a single or double quote and runs up to the next
// occurrence of that same quote character. (When INT_OK is defined, an
// unquoted run of digits is accepted as well.)
//
// file:  stream positioned just after the '=' of an attribute.
// value: output buffer; receives the value without quotes, NUL-terminated.
//        At most MAX_STR_LEN-1 characters are stored.
//        NOTE(review): assumes the buffer holds at least MAX_STR_LEN bytes,
//        the same limit the name reader in this file uses -- confirm.
// last:  receives the character that follows the value.
//
// Returns ncclInternalError on EOF, on a missing opening quote, or when the
// value does not fit.
ncclResult_t mscclXmlGetValue(FILE* file, char* value, char* last) {
  char c;
  NCCLCHECK(mscclXmlGetChar(file, &c));
  if (c != '"' && c != '\'') {
#if INT_OK
    int o = 0;
    do {
      if (o == MAX_STR_LEN-1) {
        value[o] = '\0';
        WARN("Error : value %s too long (max %d)", value, MAX_STR_LEN);
        return ncclInternalError;
      }
      value[o++] = c;
      NCCLCHECK(mscclXmlGetChar(file, &c));
    } while (c >= '0' && c <= '9');
    value[o] = '\0';
    *last = c;
    return ncclSuccess;
#else
    WARN("XML Parse : Expected (double) quote.");
    return ncclInternalError;
#endif
  }

  // The value ends at the quote character it started with.
  const char quote = c;
  int o = 0;
  while (1) {
    NCCLCHECK(mscclXmlGetChar(file, &c));
    if (c == quote) break;
    if (o == MAX_STR_LEN-1) {
      // No room left for another character plus the terminator.
      value[o] = '\0';
      WARN("Error : value %s too long (max %d)", value, MAX_STR_LEN);
      return ncclInternalError;
    }
    value[o++] = c;
  }
  value[o] = '\0';
  NCCLCHECK(mscclXmlGetChar(file, last));
  return ncclSuccess;
}
|
||||
|
||||
// Reads a name from `file`, and its value too when the name is followed by '='.
//
// Characters are collected into `name` until one of ' ', '>', '/', '\n' or
// '\r' is read; that delimiter is returned through `last` and replaced by the
// terminator in `name`. If '=' is read first, the name is terminated there and
// the value is read into `value` (which must then be non-NULL); `last` is set
// by the value reader.
//
// Returns ncclInternalError on EOF, on a name that reaches MAX_STR_LEN
// characters, or on '=' when `value` is NULL.
ncclResult_t mscclXmlGetToken(FILE* file, char* name, char* value, char* last) {
  char ch;
  int len = 0;
  for (;;) {
    NCCLCHECK(mscclXmlGetChar(file, &ch));

    if (ch == '=') {
      name[len] = '\0';
      if (value == NULL) {
        WARN("XML Parse : Unexpected value with name %s", name);
        return ncclInternalError;
      }
      return mscclXmlGetValue(file, value, last);
    }

    // The character (delimiter included) is stored before the length check.
    name[len] = ch;
    if (len == MAX_STR_LEN-1) {
      name[len] = '\0';
      WARN("Error : name %s too long (max %d)", name, MAX_STR_LEN);
      return ncclInternalError;
    }
    len++;

    const bool isDelimiter =
      ch == ' ' || ch == '>' || ch == '/' || ch == '\n' || ch == '\r';
    if (isDelimiter) break;
  }

  // Overwrite the stored delimiter with the terminator.
  name[len-1] = '\0';
  *last = ch;
  return ncclSuccess;
}
|
||||
|
||||
// Shift the 3-chars string by one char and append c at the end
|
||||
#define SHIFT_APPEND(s, c) do { s[0]=s[1]; s[1]=s[2]; s[2]=c; } while(0)
|
||||
// Consumes input up to and including the "-->" that closes a comment.
//
// start: characters already read after "<!--" (NUL-terminated).
// next:  the character read right after `start`.
//
// A three-character window holds the most recent characters; the comment ends
// when the window equals "-->".
//
// Returns ncclInternalError if the input ends before the comment is closed.
ncclResult_t mscclXmlSkipComment(FILE* file, char* start, char next) {
  // Start from something neutral with \0 at the end.
  char end[4] = "...";

  // Inject all trailing chars from previous reads. We don't need
  // to check for --> here because there cannot be a > in the name.
  const size_t startLen = strlen(start);
  for (size_t i = 0; i < startLen; i++) SHIFT_APPEND(end, start[i]);
  SHIFT_APPEND(end, next);

  // Stop when we find "-->"
  while (strcmp(end, "-->") != 0) {
    // Exactly one byte is read, so the destination must be a char: reading
    // into a wider integer would leave its other bytes uninitialized.
    char c;
    if (fread(&c, 1, 1, file) != 1) {
      WARN("XML Parse error : unterminated comment");
      return ncclInternalError;
    }
    SHIFT_APPEND(end, c);
  }
  return ncclSuccess;
}
|
||||
|
||||
// Reads the next tag from `file` into `node`.
//
// node->type is NODE_TYPE_NONE when the input ended before any tag started,
// NODE_TYPE_CLOSE for "</name>", NODE_TYPE_SINGLE for "<name ... />" and
// NODE_TYPE_OPEN for "<name ...>". Comments are skipped. For open and single
// tags the attributes are stored in node->attrs and counted in node->nAttrs.
//
// Returns ncclInternalError on malformed input.
ncclResult_t mscclXmlGetNode(FILE* file, struct mscclXmlNode* node) {
  node->type = NODE_TYPE_NONE;

  // Skip spaces and line breaks; running out of input here is not an error.
  char ch;
  do {
    if (fread(&ch, 1, 1, file) == 0) return ncclSuccess;
  } while (ch == ' ' || ch == '\n' || ch == '\r');

  if (ch != '<') {
    WARN("XML Parse error : expecting '<', got '%c'", ch);
    return ncclInternalError;
  }

  // Element name.
  NCCLCHECK(mscclXmlGetToken(file, node->name, NULL, &ch));

  // Comment: skip it and read the tag that follows.
  const bool isComment = (strncmp(node->name, "!--", 3) == 0);
  if (isComment) {
    NCCLCHECK(mscclXmlSkipComment(file, node->name+3, ch));
    return mscclXmlGetNode(file, node);
  }

  // Closing tag: the first read stopped at '/' with an empty name, so the
  // name has to be read again.
  const bool isClosing = (node->name[0] == '\0' && ch == '/');
  if (isClosing) {
    node->type = NODE_TYPE_CLOSE;
    NCCLCHECK(mscclXmlGetToken(file, node->name, NULL, &ch));
    if (ch != '>') {
      WARN("XML Parse error : unexpected trailing %c in closing tag %s", ch, node->name);
      return ncclInternalError;
    }
    return ncclSuccess;
  }

  node->type = NODE_TYPE_OPEN;

  // Attributes. Once MAX_ATTR_COUNT have been stored, further attributes are
  // still consumed (into the slot at index MAX_ATTR_COUNT) but not counted.
  int nAttrs = 0;
  while (ch == ' ') {
    NCCLCHECK(mscclXmlGetToken(file, node->attrs[nAttrs].key, node->attrs[nAttrs].value, &ch));
    if (nAttrs < MAX_ATTR_COUNT) {
      nAttrs++;
    } else {
      INFO(NCCL_GRAPH, "XML Parse : Ignoring extra attributes (max %d)", MAX_ATTR_COUNT);
    }
  }
  node->nAttrs = nAttrs;

  // "/>" marks a self-closing element; read on to the '>'.
  if (ch == '/') {
    node->type = NODE_TYPE_SINGLE;
    char rest[MAX_STR_LEN];
    NCCLCHECK(mscclXmlGetToken(file, rest, NULL, &ch));
  }

  if (ch != '>') {
    WARN("XML Parse : expected >, got '%c'", ch);
    return ncclInternalError;
  }
  return ncclSuccess;
}
|
||||
|
||||
typedef ncclResult_t (*mscclXmlHandlerFunc_t)(FILE*, struct mscclXml*, struct mscclXmlNode*);
|
||||
|
||||
struct mscclXmlHandler {
|
||||
const char * name;
|
||||
mscclXmlHandlerFunc_t func;
|
||||
};
|
||||
|
||||
// Loads the children of `head` (or the top-level elements when `head` is NULL).
//
// Each child tag is read into the next free slot of xml->nodes. A child whose
// name matches one of `handlers` is linked under `head`, the slot is kept, and
// the handler is called to load it. Any other child is parsed with no handlers
// and its slot is reused for the next tag.
//
// Returns when the closing tag of `head` is read, or at end of input when
// `head` is NULL. Returns ncclInternalError on a full node table, an
// unterminated element or a mismatched closing tag.
ncclResult_t mscclXmlLoadSub(FILE* file, struct mscclXml* xml, struct mscclXmlNode* head, struct mscclXmlHandler handlers[], int nHandlers) {
  // A self-closing element has no children.
  if (head && head->type == NODE_TYPE_SINGLE) return ncclSuccess;

  for (;;) {
    if (xml->maxIndex == MAX_NODES) {
      WARN("Error : XML parser is limited to 1024 nodes");
      return ncclInternalError;
    }

    struct mscclXmlNode* child = xml->nodes+xml->maxIndex;
    memset(child, 0, sizeof(struct mscclXmlNode));
    NCCLCHECK(mscclXmlGetNode(file, child));

    // End of input.
    if (child->type == NODE_TYPE_NONE) {
      if (!head) return ncclSuccess;  // All done
      WARN("XML Parse : unterminated %s", head->name);
      return ncclInternalError;
    }

    // Closing tag of the element being loaded.
    if (head && child->type == NODE_TYPE_CLOSE) {
      if (strcmp(child->name, head->name) == 0) return ncclSuccess;
      WARN("XML Mismatch : %s / %s", head->name, child->name);
      return ncclInternalError;
    }

    // Look for a handler registered under this element's name.
    struct mscclXmlHandler* match = NULL;
    for (int h=0; h<nHandlers; h++) {
      if (strcmp(child->name, handlers[h].name) == 0) {
        match = &handlers[h];
        break;
      }
    }

    if (match) {
      if (head) head->subs[head->nSubs++] = child;
      child->parent = head;
      child->nSubs = 0;
      xml->maxIndex++;
      NCCLCHECK(match->func(file, xml, child));
    } else {
      if (nHandlers) INFO(NCCL_GRAPH, "Ignoring element %s", child->name);
      NCCLCHECK(mscclXmlLoadSub(file, xml, child, NULL, 0));
    }
  }
}
|
||||
|
||||
// Loads a <step> element. No child elements are registered for a step, so any
// children are parsed and ignored.
//
// The handler count must be 0 to match the NULL handler table; a non-zero
// count would make the loader read entries from a NULL table as soon as a
// step contained a child element.
ncclResult_t mscclAlgoXmlStep(FILE* file, struct mscclXml* xml, struct mscclXmlNode* head) {
  NCCLCHECK(mscclXmlLoadSub(file, xml, head, NULL, 0));
  return ncclSuccess;
}
|
||||
|
||||
// Loads a <tb> element; only its <step> children are kept.
ncclResult_t mscclAlgoXmlThreadBlock(FILE* file, struct mscclXml* xmlGraph, struct mscclXmlNode* head) {
  struct mscclXmlHandler stepHandlers[] = {
    { "step", mscclAlgoXmlStep }
  };
  const int nStepHandlers = 1;
  NCCLCHECK(mscclXmlLoadSub(file, xmlGraph, head, stepHandlers, nStepHandlers));
  return ncclSuccess;
}
|
||||
|
||||
static int currentRank;
|
||||
|
||||
// Loads a <gpu> element. Its <tb> children are kept only when the element's
// "id" attribute equals currentRank; for any other id the contents are parsed
// and discarded.
ncclResult_t mscclAlgoXmlGpu(FILE* file, struct mscclXml* xmlGraph, struct mscclXmlNode* head) {
  int gpuId;
  NCCLCHECK(mscclXmlGetAttrInt(head, "id", &gpuId));

  if (gpuId != currentRank) {
    NCCLCHECK(mscclXmlLoadSub(file, xmlGraph, head, NULL, 0));
    return ncclSuccess;
  }

  struct mscclXmlHandler tbHandlers[] = {
    { "tb", mscclAlgoXmlThreadBlock }
  };
  NCCLCHECK(mscclXmlLoadSub(file, xmlGraph, head, tbHandlers, 1));
  return ncclSuccess;
}
|
||||
|
||||
// Loads an <algo> element; only its <gpu> children are kept.
ncclResult_t mscclAlgoXmlAlgo(FILE* file, struct mscclXml* xmlGraph, struct mscclXmlNode* head) {
  struct mscclXmlHandler gpuHandlers[] = {
    { "gpu", mscclAlgoXmlGpu }
  };
  const int nGpuHandlers = 1;
  NCCLCHECK(mscclXmlLoadSub(file, xmlGraph, head, gpuHandlers, nGpuHandlers));
  return ncclSuccess;
}
|
||||
|
||||
// Parses the MSCCL algorithm file at `xmlFilePath` into `xml`, keeping only
// the parts that belong to `rank`.
//
// xmlFilePath: path of the XML file.
// xml:         node table to fill; its maxIndex is reset to 0 first.
// rank:        stored in currentRank for the <gpu> handler to compare against.
//
// The file is closed on every path, including parse failures.
//
// Returns ncclSystemError if the file cannot be opened, or the parser's error.
ncclResult_t mscclAlgoXmlLoad(const char* xmlFilePath, struct mscclXml* xml, int rank) {
  currentRank = rank;
  FILE* file = fopen(xmlFilePath, "r");
  if (file == NULL) {
    WARN("Could not open MSCCL XML algorithm file %s : %s", xmlFilePath, strerror(errno));
    return ncclSystemError;
  }
  struct mscclXmlHandler handlers[] = { { "algo", mscclAlgoXmlAlgo } };
  xml->maxIndex = 0;
  // Close the file before checking the result so an error does not leak it.
  ncclResult_t result = mscclXmlLoadSub(file, xml, NULL, handlers, 1);
  fclose(file);
  NCCLCHECK(result);
  return ncclSuccess;
}
|
||||
|
||||
// Translates a buffer tag into its buffer-type constant:
// "i" -> MSCCL_INPUT_BUFFER, "o" -> MSCCL_OUTPUT_BUFFER, "s" -> MSCCL_SCRATCH_BUFFER.
// Any other string yields ncclInvalidUsage and leaves `*output` untouched.
ncclResult_t mscclGetBufferType(const char* str, uint8_t* output) {
  if (strcmp(str, "i") == 0) {
    *output = MSCCL_INPUT_BUFFER;
    return ncclSuccess;
  }
  if (strcmp(str, "o") == 0) {
    *output = MSCCL_OUTPUT_BUFFER;
    return ncclSuccess;
  }
  if (strcmp(str, "s") == 0) {
    *output = MSCCL_SCRATCH_BUFFER;
    return ncclSuccess;
  }
  WARN("type of buffer is not supported: %s", str);
  return ncclInvalidUsage;
}
|
||||
|
||||
// Checks that `offset` is acceptable for the given buffer type: it must be at
// least -1 and smaller than the chunk count of that buffer. Buffer types other
// than input/output/scratch are not checked.
// Returns ncclInvalidUsage (after a warning) when the offset is out of range.
ncclResult_t mscclCheckBufferBounds(int bufferType, int offset, int nInputChunks, int nOutputChunks, int nScratchChunks) {
  const bool tooLow = offset < -1;

  if (bufferType == MSCCL_INPUT_BUFFER) {
    if (tooLow || offset >= nInputChunks) {
      WARN("Incorrect offset set for input buffer: offset: %d maximum allowed: %d", offset, nInputChunks);
      return ncclInvalidUsage;
    }
    return ncclSuccess;
  }

  if (bufferType == MSCCL_OUTPUT_BUFFER) {
    if (tooLow || offset >= nOutputChunks) {
      WARN("Incorrect offset set for output buffer: offset: %d maximum allowed: %d", offset, nOutputChunks);
      return ncclInvalidUsage;
    }
    return ncclSuccess;
  }

  if (bufferType == MSCCL_SCRATCH_BUFFER) {
    if (tooLow || offset >= nScratchChunks) {
      WARN("Incorrect offset set for scratch buffer: offset: %d maximum allowed: %d", offset, nScratchChunks);
      return ncclInvalidUsage;
    }
    return ncclSuccess;
  }

  return ncclSuccess;
}
|
||||
|
||||
// Translates a protocol name into its protocol id:
// "Simple" -> NCCL_PROTO_SIMPLE, "LL128" -> NCCL_PROTO_LL128, "LL" -> NCCL_PROTO_LL.
// Any other name yields ncclInvalidUsage and leaves `*protocolId` untouched.
ncclResult_t mscclProtocolStrToId(const char *protocol, int *protocolId) {
  if (strcmp(protocol, "Simple") == 0) {
    *protocolId = NCCL_PROTO_SIMPLE;
    return ncclSuccess;
  }
  if (strcmp(protocol, "LL128") == 0) {
    *protocolId = NCCL_PROTO_LL128;
    return ncclSuccess;
  }
  if (strcmp(protocol, "LL") == 0) {
    *protocolId = NCCL_PROTO_LL;
    return ncclSuccess;
  }
  WARN("MSCCL: protocol %s is not supported.", protocol);
  return ncclInvalidUsage;
}
|
||||
|
||||
// Fills `algo` from an already loaded MSCCL XML tree, keeping only the <gpu>
// node whose id equals `rank`. Does not take ownership of `xml`; the caller
// releases it, which lets every early return below stay leak-free.
static ncclResult_t mscclParseAlgoFromXml(struct mscclXml* xml, struct mscclAlgo* algo, int rank) {
  // zeroing out all entries.
  memset(algo, 0, sizeof(struct mscclAlgo));
  struct mscclXmlNode* topNode;
  NCCLCHECK(mscclXmlFindTag(xml, "algo", &topNode));

  int nChunksPerLoop;
  NCCLCHECK(mscclXmlGetAttrInt(topNode, "nchunksperloop", &nChunksPerLoop));
  algo->nChunksPerLoop = nChunksPerLoop;

  int nChannels;
  NCCLCHECK(mscclXmlGetAttrInt(topNode, "nchannels", &nChannels));
  algo->nChannels = nChannels;

  int nGpus;
  NCCLCHECK(mscclXmlGetAttrInt(topNode, "ngpus", &nGpus));
  algo->nRanks = nGpus;

  const char* protocol;
  NCCLCHECK(mscclXmlGetAttrStr(topNode, "proto", &protocol));
  NCCLCHECK(mscclProtocolStrToId(protocol, &algo->protocol));

  // Defaults; the collective-specific branches below override them.
  algo->sizeMultiplier = 1;
  algo->chunkSteps = MSCCL_CHUNKSTEPS;
  algo->sliceSteps = MSCCL_SLICESTEPS;
  const char* coll;
  NCCLCHECK(mscclXmlGetAttrStr(topNode, "coll", &coll));
  if (strcmp(coll, "reduce") == 0) {
    algo->chunkSteps = REDUCE_CHUNKSTEPS;
    algo->sliceSteps = REDUCE_SLICESTEPS;
    algo->func = mscclFuncReduce;
  } else if (strcmp(coll, "broadcast") == 0) {
    algo->chunkSteps = BROADCAST_CHUNKSTEPS;
    algo->sliceSteps = BROADCAST_SLICESTEPS;
    algo->func = mscclFuncBroadcast;
  } else if (strcmp(coll, "allreduce") == 0) {
    algo->chunkSteps = ALLREDUCE_CHUNKSTEPS;
    algo->sliceSteps = ALLREDUCE_SLICESTEPS;
    algo->func = mscclFuncAllReduce;
  } else if (strcmp(coll, "reducescatter") == 0) {
    algo->sizeMultiplier = nGpus;
    algo->chunkSteps = REDUCESCATTER_CHUNKSTEPS;
    algo->sliceSteps = REDUCESCATTER_SLICESTEPS;
    algo->func = mscclFuncReduceScatter;
  } else if (strcmp(coll, "allgather") == 0) {
    algo->sizeMultiplier = nGpus;
    algo->chunkSteps = ALLGATHER_CHUNKSTEPS;
    algo->sliceSteps = ALLGATHER_SLICESTEPS;
    algo->func = mscclFuncAllGather;
  } else if (strcmp(coll, "send") == 0) {
    algo->func = mscclFuncSend;
  } else if (strcmp(coll, "recv") == 0) {
    algo->func = mscclFuncRecv;
  } else if (strcmp(coll, "gather") == 0) {
    algo->func = mscclFuncGather;
  } else if (strcmp(coll, "scatter") == 0) {
    algo->func = mscclFuncScatter;
  } else if (strcmp(coll, "alltoall") == 0) {
    algo->sizeMultiplier = nGpus;
    algo->func = mscclFuncAllToAll;
  } else if (strcmp(coll, "alltoallv") == 0) {
    algo->func = mscclFuncAllToAllv;
  } else {
    WARN("MSCCL: unsupported collective: %s", coll);
    return ncclInvalidUsage;
  }

  int64_t minBytes;
  NCCLCHECK(mscclXmlGetAttrInt64(topNode, "minBytes", &minBytes));
  algo->minBytes = minBytes;

  int64_t maxBytes;
  NCCLCHECK(mscclXmlGetAttrInt64(topNode, "maxBytes", &maxBytes));
  algo->maxBytes = maxBytes;

  int inplace;
  NCCLCHECK(mscclXmlGetAttrInt(topNode, "inplace", &inplace));
  algo->inPlace = (bool)inplace;

  int outofplace;
  NCCLCHECK(mscclXmlGetAttrInt(topNode, "outofplace", &outofplace));
  algo->outOfPlace = (bool)outofplace;

  algo->hasReduce = false;

  for (int g=0; g<topNode->nSubs; g++) {
    struct mscclXmlNode* node = topNode->subs[g];
    if (strcmp(node->name, "gpu") == 0) {
      int blockExists[MSCCL_MAX_NUM_THREAD_BLOCKS];
      memset(blockExists, 0, sizeof(int[MSCCL_MAX_NUM_THREAD_BLOCKS]));
      int id, nScratchChunks, nInputChunks, nOutputChunks;
      NCCLCHECK(mscclXmlGetAttrInt(node, "id", &id));
      if (id == rank) {
        NCCLCHECK(mscclXmlGetAttrInt(node, "i_chunks", &nInputChunks));
        NCCLCHECK(mscclXmlGetAttrInt(node, "o_chunks", &nOutputChunks));
        NCCLCHECK(mscclXmlGetAttrInt(node, "s_chunks", &nScratchChunks));
        if (nScratchChunks < 0) {
          WARN("MSCCL: nScratchChunks must be not negative. nScratchChunks: %d", nScratchChunks);
          return ncclInvalidUsage;
        }
        algo->nScratchChunks = nScratchChunks;
        for (int t=0; t<node->nSubs; t++) {
          struct mscclXmlNode* threadBlockNode = node->subs[t];
          if (strcmp(threadBlockNode->name, "tb") == 0) {
            int bid, recvPeer, sendPeer, channelId;
            NCCLCHECK(mscclXmlGetAttrInt(threadBlockNode, "id", &bid));
            NCCLCHECK(mscclXmlGetAttrInt(threadBlockNode, "recv", &recvPeer));
            NCCLCHECK(mscclXmlGetAttrInt(threadBlockNode, "send", &sendPeer));
            NCCLCHECK(mscclXmlGetAttrInt(threadBlockNode, "chan", &channelId));
            if (bid < 0) {
              WARN("MSCCL: bid must be not negative. bid: %d", bid);
              return ncclInvalidUsage;
            }
            if (bid >= MSCCL_MAX_NUM_THREAD_BLOCKS) {
              WARN("MSCCL: too many thread blocks are requested. Max thread blocks: %d", MSCCL_MAX_NUM_THREAD_BLOCKS);
              return ncclInvalidUsage;
            }
            if (blockExists[bid]) {
              WARN("MSCCL: duplicate thread block id %d for MSCCL", bid);
              return ncclInvalidUsage;
            }
            blockExists[bid] = 1;

            if (recvPeer == id || sendPeer == id) {
              WARN("MSCCL: peer (%d,%d) and gpu id (%d) must be different", recvPeer, sendPeer, id);
              return ncclInvalidUsage;
            }
            struct mscclThreadBlock* sTB = &algo->mscclTBs[bid];
            sTB->nSteps = 0;
            // -1 means "no peer"; anything below that is malformed.
            if (recvPeer < -1 || sendPeer < -1) {
              WARN("MSCCL: wrong recvPeer (%d) or sendPeer (%d) in thread block %d on gpu %d", recvPeer, sendPeer, bid, id);
              return ncclInvalidUsage;
            }

            sTB->recvPeer = recvPeer;
            sTB->sendPeer = sendPeer;
            // channelId indexes algo->mscclChannels below, so MAXCHANNELS itself
            // is already out of range (the original test used '>').
            if (channelId < 0 || channelId >= MAXCHANNELS) {
              WARN("MSCCL: threadblock %d on GPU %d has an invalid channel %d", bid, id, channelId);
              return ncclInvalidUsage;
            }
            sTB->channelId = channelId;

            // setting the summary of the msccl algorithm in msccl channels
            mscclChannelInfo* mscclChannel = &algo->mscclChannels[sTB->channelId];

            int numDependencies = 0;
            int oldDependencePointer = 0; // Indicator of where the dependencies started for nop

            int oldReductionDstBuffer = -1; // Indicator of last reduction buffer name; -1 means that last one wasn't a compatible reduction
            int oldReductionDstOffset = -1; // Indicator of last reduction buffer index
            int oldReductionSrcBuffer = -1; // Indicator of last reduction source buffer name
            int numReductions = 0;

            int numTransfers = 0;
            int numStepNodes = 0; // <step> nodes seen so far in this thread block
            for (int st=0; st<threadBlockNode->nSubs; st++) {
              struct mscclXmlNode* stepNode = threadBlockNode->subs[st];
              if (strcmp(stepNode->name, "step") == 0) {
                int s, srcOffset, dstOffset, dependBid, dependStep, hasDependence, count;
                const char* srcBuffer, * dstBuffer, * type;
                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "s", &s));

                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "srcoff", &srcOffset));
                NCCLCHECK(mscclXmlGetAttrStr(stepNode, "srcbuf", &srcBuffer));
                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "dstoff", &dstOffset));
                NCCLCHECK(mscclXmlGetAttrStr(stepNode, "dstbuf", &dstBuffer));

                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "cnt", &count));
                NCCLCHECK(mscclXmlGetAttrStr(stepNode, "type", &type));
                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "depid", &dependBid));
                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "deps", &dependStep));
                NCCLCHECK(mscclXmlGetAttrInt(stepNode, "hasdep", &hasDependence));

                if (s >= MSCCL_MAX_NUM_STEPS){
                  WARN("MSCCL: too many steps are requested. Max number of steps: %d, requested: %d", MSCCL_MAX_NUM_STEPS, s+1);
                  return ncclInternalError;
                }
                if (s < 0){
                  WARN("MSCCL: step must be positive: step %d", s);
                  return ncclInternalError;
                }
                // numTransfers, numDependencies and numReductions each grow by at
                // most one per <step> node and index per-thread-block arrays, so
                // bounding the node count bounds all three. Checking `s` alone is
                // not enough because repeated `s` values are not rejected.
                // NOTE(review): assumes those arrays hold MSCCL_MAX_NUM_STEPS
                // entries - confirm against msccl_struct.h.
                if (numStepNodes >= MSCCL_MAX_NUM_STEPS) {
                  WARN("MSCCL: too many steps are requested. Max number of steps: %d, requested: %d", MSCCL_MAX_NUM_STEPS, numStepNodes+1);
                  return ncclInternalError;
                }
                numStepNodes++;

                int hasSend = 0;
                int hasRecv = 0;
                int checkSrc = 0;
                int checkDst = 0;
                int transferType = -1; // -1 indicate a nop
                if (strcmp(type, "s") == 0) {
                  transferType = MSCCL_SEND;
                  hasSend = 1;
                  checkSrc = 1;
                } else if (strcmp(type, "r") == 0) {
                  transferType = MSCCL_RECV;
                  hasRecv = 1;
                  checkDst = 1;
                } else if (strcmp(type, "rcs") == 0) {
                  transferType = MSCCL_RECV_COPY_SEND;
                  hasSend = 1;
                  hasRecv = 1;
                  checkDst = 1;
                } else if (strcmp(type, "rrs") == 0) {
                  transferType = MSCCL_RECV_REDUCE_SEND;
                  hasSend = 1;
                  hasRecv = 1;
                  checkSrc = 1;
                  algo->hasReduce = true;
                } else if (strcmp(type, "rrc") == 0) {
                  transferType = MSCCL_RECV_REDUCE_COPY;
                  hasRecv = 1;
                  algo->hasReduce = true;
                } else if (strcmp(type, "rrcs") == 0) {
                  transferType = MSCCL_RECV_REDUCE_COPY_SEND;
                  hasRecv = 1;
                  hasSend = 1;
                  checkSrc = 1;
                  checkDst = 1;
                  algo->hasReduce = true;
                } else if (strcmp(type, "cpy") == 0) {
                  transferType = MSCCL_LOCAL_COPY;
                  checkSrc = 1;
                  checkDst = 1;
                } else if (strcmp(type, "re") == 0) {
                  transferType = MSCCL_REDUCE;
                  checkSrc = 1;
                  checkDst = 1;
                  algo->hasReduce = true;
                } else if (strcmp(type, "nop") == 0) {
                  transferType = -1;
                } else {
                  WARN("MSCCL: type of transfer is not supported: %s", type);
                  return ncclInternalError;
                }

                if (dependBid >= 0) {
                  sTB->dependentBid[numDependencies] = dependBid;
                  sTB->dependentStep[numDependencies] = dependStep;
                  numDependencies++;
                }

                uint8_t srcBufferInt = 0;
                uint8_t dstBufferInt = 0;
                NCCLCHECK(mscclGetBufferType(srcBuffer, &srcBufferInt));
                NCCLCHECK(mscclGetBufferType(dstBuffer, &dstBufferInt));

                int continuationOfReductions = 0;
                // Analyze to see if this is in the same list of reductions for them to be chained
                if (transferType == MSCCL_REDUCE) {
                  if (oldReductionDstBuffer == dstBufferInt && oldReductionDstOffset == dstOffset && oldReductionSrcBuffer == srcBufferInt && dependBid == -1) {
                    numTransfers--; // reuse the same transfer
                    continuationOfReductions = 1;
                  } else {
                    oldReductionDstBuffer = -1;
                    oldReductionDstOffset = -1;
                  }
                }

                if (transferType != -1) {
                  struct mscclTransmission* mscclTran = &sTB->transmissions[numTransfers];
                  mscclTran->type = transferType;
                  mscclTran->srcBuffer = srcBufferInt;
                  mscclTran->srcOffset = srcOffset;
                  mscclTran->dstBuffer = dstBufferInt;
                  mscclTran->dstOffset = dstOffset;

                  // count is used as an index into nTransmissionsOfCount below.
                  if (count < 0 || count >= MSCCL_MAX_COUNT){
                    WARN("MSCCL: count (%d) must be positive and less than %d", count, MSCCL_MAX_COUNT);
                    return ncclInternalError;
                  }

                  mscclTran->count = count;

                  if (hasSend) {
                    if (sendPeer < 0) {
                      WARN("MSCCL: there is a send in thread block %d on GPU %d without a sendPeer.", bid, id);
                      return ncclInvalidUsage;
                    }
                    if (mscclChannel->nSendPeers >= MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL) {
                      WARN("MSCCL: too many sends per channel. Max allowed %d", MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL);
                      return ncclInvalidUsage;
                    }

                    struct mscclChannelPeerInfo* sendPeerInfo = &mscclChannel->sendPeerInfo[mscclChannel->nSendPeers];
                    sendPeerInfo->nTransmissionsOfCount[count]++;
                  }
                  if (hasRecv) {
                    if (recvPeer < 0) {
                      WARN("MSCCL: there is a recv in thread block %d on GPU %d without a recvPeer.", bid, id);
                      return ncclInvalidUsage;
                    }
                    if (mscclChannel->nRecvPeers >= MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL) {
                      WARN("MSCCL: too many recvs per channel. Max allowed %d", MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL);
                      return ncclInvalidUsage;
                    }
                    struct mscclChannelPeerInfo* recvPeerInfo = &mscclChannel->recvPeerInfo[mscclChannel->nRecvPeers];
                    recvPeerInfo->nTransmissionsOfCount[count]++;
                  }

                  if (checkSrc) NCCLCHECK(mscclCheckBufferBounds(mscclTran->srcBuffer, mscclTran->srcOffset, nInputChunks, nOutputChunks, nScratchChunks));
                  if (checkDst) NCCLCHECK(mscclCheckBufferBounds(mscclTran->dstBuffer, mscclTran->dstOffset, nInputChunks, nOutputChunks, nScratchChunks));

                  if (!continuationOfReductions) {
                    mscclTran->dependencePointer = oldDependencePointer;
                    mscclTran->numDependencies = numDependencies - oldDependencePointer;
                    if (mscclTran->numDependencies > 0 && dependBid < 0) {
                      WARN("MSCCL: when there is a chain of dependencies, the last reduction must be a part of the first immediate instruction. Detected for GPU %d, thread block %d, and step %d. XML will be ignored.", id, bid, s);
                      return ncclInvalidUsage;
                    }
                    oldDependencePointer = numDependencies;
                  }

                  // reduction related pointers
                  if (transferType != MSCCL_REDUCE) {
                    oldReductionDstBuffer = -1;
                    oldReductionDstOffset = -1;
                    oldReductionSrcBuffer = -1;
                  } else {
                    if (oldReductionDstBuffer == -1) { // if this is the first reduction
                      mscclTran->reductionPointer = numReductions;
                    }
                    sTB->reductionSrcOffsets[numReductions] = mscclTran->srcOffset;
                    numReductions++;
                    mscclTran->numReductions = numReductions - mscclTran->reductionPointer;

                    if (hasDependence || numReductions == MSCCL_MAX_REDUCE_FUSION) {
                      oldReductionDstBuffer = -1;
                      oldReductionDstOffset = -1;
                    } else {
                      oldReductionDstBuffer = mscclTran->dstBuffer;
                      oldReductionDstOffset = mscclTran->dstOffset;
                      oldReductionSrcBuffer = mscclTran->srcBuffer;
                    }
                  }

                  if (hasDependence != 0 && hasDependence != 1) {
                    WARN("MSCCL: hasDependence needs to be 0 or 1, but it was %d", hasDependence);
                    return ncclInternalError;
                  }
                  mscclTran->hasDependence = hasDependence;

                  numTransfers++;
                  sTB->nSteps = numTransfers;
                }
              }
            }

            // finish up mscclChannel calculation: record which counts were
            // actually used by this thread block's send/recv peer slots.
            for (int c = 0; c < MSCCL_MAX_COUNT; c++) {
              struct mscclChannelPeerInfo* sendInfo = &mscclChannel->sendPeerInfo[mscclChannel->nSendPeers];
              if (sendInfo->nTransmissionsOfCount[c] > 0) {
                sendInfo->existingCounts[sendInfo->nExistingCounts] = c;
                sendInfo->nExistingCounts++;
              }
              struct mscclChannelPeerInfo* recvInfo = &mscclChannel->recvPeerInfo[mscclChannel->nRecvPeers];
              if (recvInfo->nTransmissionsOfCount[c] > 0) {
                recvInfo->existingCounts[recvInfo->nExistingCounts] = c;
                recvInfo->nExistingCounts++;
              }
            }

            // Commit the peer slots filled in above.
            if (sTB->sendPeer >= 0) {
              mscclChannel->sendPeerInfo[mscclChannel->nSendPeers].peer = sTB->sendPeer;
              mscclChannel->nSendPeers++;
            }
            if (sTB->recvPeer >= 0) {
              mscclChannel->recvPeerInfo[mscclChannel->nRecvPeers].peer = sTB->recvPeer;
              mscclChannel->nRecvPeers++;
            }
          }
        }
        // make sure that thread blocks are in order. Something like 0, 2, 3 is not allowed.
        if (blockExists[0] == 1) {
          algo->nBlocks = 1;
        }
        for (int i = 1; i < MSCCL_MAX_NUM_THREAD_BLOCKS; i++) {
          if (blockExists[i] == 1 && blockExists[i-1] == 0) {
            WARN("MSCCL: thread block %d is missing", i);
            return ncclInvalidUsage;
          }
          if (blockExists[i] == 1) {
            algo->nBlocks = i+1;
          }
        }

      }
    }
  }
  return ncclSuccess;
}

// Loads the MSCCL algorithm XML identified by `str` and fills `algo` with the
// schedule of the <gpu> node whose id equals `rank`.
//
// str:  passed through to mscclAlgoXmlLoad (the XML source to load).
// algo: output; zeroed and then populated once the XML has loaded.
// rank: rank whose thread blocks are extracted.
// Returns ncclSuccess, or the first error reported while loading or validating
// the XML. The temporary XML tree is released on every path.
ncclResult_t mscclGetAlgoFromXmlFile(const char* str, struct mscclAlgo* algo, int rank) {
  struct mscclXml* xml;
  NCCLCHECK(ncclCalloc(&xml, 1));
  ncclResult_t ret = mscclAlgoXmlLoad(str, xml, rank);
  if (ret == ncclSuccess) {
    ret = mscclParseAlgoFromXml(xml, algo, rank);
  }
  free(xml);
  return ret;
}
|
||||
@@ -0,0 +1,284 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#include "checks.h"
|
||||
#include "collectives.h"
|
||||
#include "proxy.h"
|
||||
#include "transport.h"
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
#include "msccl/msccl_kernel.h"
|
||||
#include "msccl/msccl_setup.h"
|
||||
#include "msccl/msccl_status.h"
|
||||
|
||||
// Derives the per-launch sizing state (step/chunk sizes, total byte count and
// the largest chunk count one transfer may carry) and stores it in the global
// MSCCL status.
//
// hostAlgo: algorithm description; protocol, chunk/slice steps, sizeMultiplier
//           and nChunksPerLoop are read.
// comm:     communicator whose buffSizes[protocol] determines the step size.
// count:    element count of the call; 0 is accepted.
// dataType: element type, used to convert count into bytes.
// Returns ncclSuccess, or ncclInvalidUsage when nChunksPerLoop is not positive.
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
  // nChunksPerLoop is a divisor below; reject it before touching any state.
  if (hostAlgo->nChunksPerLoop <= 0) {
    WARN("MSCCL: nChunksPerLoop must be positive. nChunksPerLoop: %d", (int)hostAlgo->nChunksPerLoop);
    return ncclInvalidUsage;
  }

  mscclStatus& status = mscclGetStatus();
  status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
  // Only the Simple protocol uses the algorithm's chunk/slice step settings.
  status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1;
  status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1;
  status.chunkSize = status.stepSize * status.chunkSteps;
  // Usable payload per chunk: halved for LL, scaled by DATAELEMS/LINEELEMS for LL128.
  status.chunkEffectiveSize = status.chunkSize;
  if (hostAlgo->protocol == NCCL_PROTO_LL) status.chunkEffectiveSize /= 2;
  if (hostAlgo->protocol == NCCL_PROTO_LL128) status.chunkEffectiveSize = (status.chunkSize / NCCL_LL128_LINEELEMS) * NCCL_LL128_DATAELEMS;
  status.dataType = dataType;
  status.nBytes = count * ncclTypeSize(status.dataType) * hostAlgo->sizeMultiplier;

  size_t bytesPerChunk = DIVUP(status.nBytes, (size_t)(hostAlgo->nChunksPerLoop));
  if (bytesPerChunk == 0) {
    // Zero-byte call: avoid dividing by zero and use the upper limit, which is
    // what the clamp below yields for very small messages anyway.
    status.maxAllowedCount = MSCCL_MAX_COUNT - 1;
  } else {
    // Never below 1, so no separate zero check is needed afterwards.
    status.maxAllowedCount = std::max((uint32_t)1, (uint32_t)(status.chunkEffectiveSize / bytesPerChunk));
  }
  if (status.maxAllowedCount >= MSCCL_MAX_COUNT) {
    status.maxAllowedCount = MSCCL_MAX_COUNT - 1;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Ensures the global MSCCL scratch buffer is large enough for the current
// call, growing (never shrinking) it when needed.
//
// hostAlgo: algorithm description; nScratchChunks and nChunksPerLoop are read.
// stream:   stream that is synchronized before the old buffer is freed.
// Returns ncclSuccess, ncclInvalidUsage when nChunksPerLoop is not positive,
// or the error reported by a failing HIP / allocation call.
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
  // nChunksPerLoop is a divisor below.
  if (hostAlgo->nChunksPerLoop <= 0) {
    WARN("MSCCL: nChunksPerLoop must be positive. nChunksPerLoop: %d", (int)hostAlgo->nChunksPerLoop);
    return ncclInvalidUsage;
  }

  mscclStatus& status = mscclGetStatus();
  size_t sizeNeeded = (status.nBytes * (size_t)(hostAlgo->nScratchChunks)) / (size_t)(hostAlgo->nChunksPerLoop);
  if (sizeNeeded > status.scratchBufferSize){
    // Wait for work queued on the stream before releasing the old buffer.
    CUDACHECK(hipStreamSynchronize(stream));
    CUDACHECK(hipFree(status.scratchBuffer));
    // Drop the stale pointer and size right away: if the allocation below
    // fails, the status must not keep advertising freed memory.
    status.scratchBuffer = nullptr;
    status.scratchBufferSize = 0;
    NCCLCHECK(ncclCudaCalloc((char**)&status.scratchBuffer, sizeNeeded));
    status.scratchBufferSize = sizeNeeded;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Resets the device sync flags and restarts the work index at 1 once the index
// gets within 2 * NCCL_MAX_OPS + 1 of the largest value its storage can hold.
//
// stream: stream on which the asynchronous flag reset is enqueued.
// Returns ncclSuccess, or the error from hipMemsetAsync.
ncclResult_t mscclSetupSyncFlags(hipStream_t stream) {
  mscclStatus& st = mscclGetStatus();

  // NOTE(review): the shift amount equals the bit width of workIndex; if that
  // field is 64 bits wide this shifts a 64-bit value by 64 - confirm its type.
  const auto resetThreshold = (1ULL << (8*sizeof(st.workIndex))) - 2 * NCCL_MAX_OPS - 1;
  if (!(st.workIndex > resetThreshold)) {
    return ncclSuccess;
  }

  CUDACHECK(hipMemsetAsync(st.syncFlags, 0, sizeof(struct mscclFlag) * MSCCL_MAX_NUM_THREAD_BLOCKS, stream));
  st.workIndex = 1; // setting the workIndex back to 1 for next iterations
  return ncclSuccess;
}
|
||||
|
||||
// Registers and establishes the point-to-point connections the algorithm needs
// on each of its channels.
//
// hostAlgo: algorithm description; nChannels and the per-channel send/recv
//           peer lists are read.
// comm:     communicator on which the connections are set up.
// Returns ncclSuccess, ncclInvalidUsage when the communicator has fewer
// channels than the algorithm requires, or the error of a failing transport call.
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
  // Check whether there are enough channels
  if (hostAlgo->nChannels > comm->nChannels) {
    WARN("MSCCL: number of channels available (%d) less than required (%d)", comm->nChannels, hostAlgo->nChannels);
    return ncclInvalidUsage;
  }

  // Flag MSCCL connections
  for (int i = 0; i < hostAlgo->nChannels; i++) {
    struct mscclChannelInfo* mCh = hostAlgo->mscclChannels + i;

    int sendPeers[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
    for (int p = 0; p < mCh->nSendPeers; p++) {
      sendPeers[p] = mCh->sendPeerInfo[p].peer;
    }

    int recvPeers[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
    for (int p = 0; p < mCh->nRecvPeers; p++) {
      recvPeers[p] = mCh->recvPeerInfo[p].peer;
    }

    NCCLCHECK(ncclTransportP2pConnect(comm, i, mCh->nRecvPeers, recvPeers, mCh->nSendPeers, sendPeers, 0 /*connIndex*/));
  }

  // Connect MSCCL connections. The caller flag is cleared before the result is
  // examined so that a failed setup cannot leave it set for later, unrelated
  // transport setups.
  mscclSetIsCallerFlag();
  ncclResult_t setupResult = ncclTransportP2pSetup(comm, NULL, 0);
  mscclClearIsCallerFlag();
  NCCLCHECK(setupResult);

  return ncclSuccess;
}
|
||||
|
||||
// Queues one proxy operation per receiving and per sending peer on every
// channel of the algorithm, then starts the proxy and advances the
// communicator's collective op counter.
//
// hostAlgo: algorithm description (channels, peers, per-count transmission totals).
// comm:     communicator whose channels and collOpCount are used.
// Returns ncclSuccess, or the first error from mscclSaveProxy / ncclProxyStart.
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
  mscclStatus& status = mscclGetStatus();

  // Fields common to every op queued below; channelId and nsteps vary per peer.
  struct ncclProxyOp proxyOp = {};
  proxyOp.connIndex = 0;
  proxyOp.sliceSteps = status.sliceSteps;
  proxyOp.chunkSteps = status.chunkSteps;
  proxyOp.chunkSize = status.chunkSize;
  proxyOp.protocol = hostAlgo->protocol;
  proxyOp.dtype = status.dataType;
  proxyOp.redOp = 0;
  proxyOp.pattern = 0;
  proxyOp.root = 0;
  proxyOp.nbytes = status.stepSize*proxyOp.sliceSteps;
  proxyOp.opCount = comm->collOpCount;

  const size_t bytesPerLoop = (size_t)hostAlgo->nChunksPerLoop*(size_t)status.chunkEffectiveSize;
  const int nLoops = (int)(DIVUP(status.nBytes, bytesPerLoop));
  const int nLoopsChunkSteps = nLoops * status.chunkSteps;

  // Sums, over every count value recorded for a peer, the number of
  // transmissions with that count times the pieces each one is split into.
  auto transfersPerLoop = [&status](const struct mscclChannelPeerInfo* peerInfo) {
    int total = 0;
    for (int j = 0; j < peerInfo->nExistingCounts; j++){
      int c = peerInfo->existingCounts[j];
      int nStepsInCount = DIVUP(c+1, status.maxAllowedCount);
      total += peerInfo->nTransmissionsOfCount[c] * nStepsInCount;
    }
    return total;
  };

  for (int ch = 0; ch < hostAlgo->nChannels; ch++) {
    struct mscclChannelInfo* algoChannel = hostAlgo->mscclChannels + ch;
    struct ncclChannel* commChannel = comm->channels + ch;
    proxyOp.channelId = ch;

    // Receives first, then sends, for each channel.
    for (int i = 0; i < algoChannel->nRecvPeers; i++){
      struct mscclChannelPeerInfo* peerInfo = algoChannel->recvPeerInfo + i;
      int nRecvs = transfersPerLoop(peerInfo);
      proxyOp.nsteps = nLoopsChunkSteps * nRecvs;
      if (proxyOp.nsteps > 0) {
        NCCLCHECK(mscclSaveProxy(commChannel, proxyRecv, peerInfo->peer, &proxyOp, 0));
      }
    }

    for (int i = 0; i < algoChannel->nSendPeers; i++){
      struct mscclChannelPeerInfo* peerInfo = &algoChannel->sendPeerInfo[i];
      int nSends = transfersPerLoop(peerInfo);
      proxyOp.nsteps = nLoopsChunkSteps * nSends;
      if (proxyOp.nsteps > 0) {
        NCCLCHECK(mscclSaveProxy(commChannel, proxySend, peerInfo->peer, &proxyOp, 0));
      }
    }
  }

  NCCLCHECK(ncclProxyStart(comm));
  comm->collOpCount++;
  return ncclSuccess;
}
|
||||
|
||||
// Converts a host-side reduction op into the device-side description used at
// kernel launch.
//
// opFull:   output; op, scalarArg and scalarArgIsPtr are always written.
// op:       built-in op (sum/prod/max/min/avg) or a user-created op.
// datatype: element type; selects how ncclAvg is lowered and is checked
//           against the type recorded for a user-created op.
// comm:     communicator providing nRanks and the user op table.
// Returns ncclSuccess, or ncclInvalidArgument when a user-created op was
// registered for a different data type.
static ncclResult_t hostToDevRedOp(
    ncclDevRedOpFull *opFull, ncclRedOp_t op, ncclDataType_t datatype, ncclComm *comm
  ) {
  // Scratch storage for writing the scalar in its native type and reading it
  // back as the raw 64-bit scalarArg.
  union {
    int8_t i8;
    uint8_t u8;
    int32_t i32;
    uint32_t u32;
    int64_t i64;
    uint64_t u64;
    half f16;
    #if defined(RCCL_BFLOAT16)
    rccl_bfloat16 bf16;
    #endif
    float f32;
    double f64;
    void *ptr;
  };
  u64 = 0;
  opFull->scalarArgIsPtr = false;
  // Give scalarArg a defined value for the ops that take no scalar, so callers
  // that copy it unconditionally never read an indeterminate value.
  opFull->scalarArg = 0;
  switch (int(op)) {
  case ncclSum:  opFull->op = ncclDevSum;  break;
  case ncclProd: opFull->op = ncclDevProd; break;
  case ncclMax:  opFull->op = ncclDevMax;  break;
  case ncclMin:  opFull->op = ncclDevMin;  break;
  case ncclAvg:
    // Integers: sum, then divide by nRanks. Floating point: multiply by
    // 1/nRanks, then sum.
    switch ((int)datatype) {
    case ncclInt8:  case ncclInt32:  case ncclInt64:
    case ncclUint8: case ncclUint32: case ncclUint64:
      opFull->op = ncclDevSumPostDiv;
      u64 = comm->nRanks;
      break;
    case ncclFloat16:
      opFull->op = ncclDevPreMulSum;
      f16 = __float2half(float(1.0/comm->nRanks)); // __double2half not supported pre CUDA 11.x
      break;
    #if defined(RCCL_BFLOAT16)
    case ncclBfloat16:
      opFull->op = ncclDevPreMulSum;
      bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks));
      break;
    #endif
    case ncclFloat32:
      opFull->op = ncclDevPreMulSum;
      f32 = float(1.0/comm->nRanks);
      break;
    case ncclFloat64:
      opFull->op = ncclDevPreMulSum;
      f64 = 1.0/comm->nRanks;
      break;
    }
    opFull->scalarArg = u64;
    break;
  default: { // user created
    int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps);
    ncclUserRedOp *user = &comm->userRedOps[ix];
    if (datatype != user->datatype) {
      WARN("Data type supplied to user-created ncclRedOp_t does not match type "
           "given to reduction operation");
      return ncclInvalidArgument;
    }
    *opFull = user->opFull;
    break;
  }
  }
  return ncclSuccess;
}
|
||||
|
||||
// Three empty slots (one per protocol) for an op/type pair that has no kernel.
#define MSCCL_KERNEL_ENTRY_DEVREDOP_NULL() \
  nullptr, \
  nullptr, \
  nullptr

// The three kernel entry points for one op/type pair, in the order LL, LL128,
// Simple.
#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type) \
  (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL), \
  (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128), \
  (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)

// All data types for one reduction op.
// NOTE(review): rccl_bfloat16 is listed unconditionally here, while
// hostToDevRedOp guards its bfloat16 handling with RCCL_BFLOAT16 - confirm the
// two agree in builds where that macro is not defined.
#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop) \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16)

// Same layout as MSCCL_KERNEL_ENTRY_DEVREDOP, but the four floating-point
// slots are left empty (nullptr).
#define MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(devredop) \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_NULL()

// One group of entries per device reduction op.
// NOTE(review): mscclSetupKernel indexes the table by the numeric value of the
// device reduction op, so the row order here has to match that enum. This list
// puts Min before Max, whereas hostToDevRedOp lists ncclMax before ncclMin -
// confirm against the ncclDevRedOp_t declaration.
#define MSCCL_KERNEL_ENTRY() \
  MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \
  MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \
  MSCCL_KERNEL_ENTRY_DEVREDOP(Min), \
  MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \
  MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum), \
  MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv)

// Flat lookup table of kernel entry points. mscclSetupKernel indexes it with
// (op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + protocol. Entries may
// be nullptr (see MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT).
void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
  MSCCL_KERNEL_ENTRY()
};
|
||||
|
||||
// Packs the launch arguments for one MSCCL call and launches the kernel
// selected by reduction op, data type and protocol.
//
// sendBuff, recvBuff: user buffers forwarded to the kernel.
// count:     per-call element count; scaled by hostAlgo->sizeMultiplier.
// dataType:  element type; also selects the kernel.
// op:        reduction op, lowered via hostToDevRedOp.
// hostAlgo:  host copy of the algorithm (nBlocks, protocol, nChunksPerLoop,
//            hasReduce, sizeMultiplier are read).
// devAlgo:   algorithm pointer passed to the kernel as its second argument.
// comm:      communicator; its devComm is passed to the kernel.
// stream:    stream the kernel is launched on.
// Launch shape: hostAlgo->nBlocks blocks of NCCL_MAX_NTHREADS threads, no
// dynamic shared memory.
// Returns ncclSuccess, ncclInvalidUsage when no kernel exists for the op/type
// combination, or the error from op conversion or the launch.
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
    ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
    ncclComm_t comm, hipStream_t stream) {
  mscclStatus& status = mscclGetStatus();
  dim3 grid = {(uint32_t)hostAlgo->nBlocks, 1, 1};
  dim3 block = {NCCL_MAX_NTHREADS, 1, 1};
  // Zero-initialised so scalarArg is defined even for ops that do not set it.
  ncclDevRedOpFull opFull = {};
  NCCLCHECK(hostToDevRedOp(&opFull, op, dataType, comm));

  mscclWork work;
  work.syncFlags = status.syncFlags;
  work.scratchBuffer = status.scratchBuffer;
  work.sendBuff = sendBuff;
  work.recvBuff = recvBuff;
  work.count = count * hostAlgo->sizeMultiplier; // count is sum of all ranks in MSCCL kernel
  work.redOpArg = opFull.scalarArg;
  work.workIndex = status.workIndex;
  work.nChunksPerLoop = hostAlgo->nChunksPerLoop;
  work.maxAllowedCount = status.maxAllowedCount;
  work.hasReduce = hostAlgo->hasReduce;
  work.redOpArgIsPtr = opFull.scalarArgIsPtr;

  void *args[3] = {&comm->devComm, &devAlgo, &work};
  void *func = mscclKernelEntries[(opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol];
  // The table holds nullptr for op/type pairs without a kernel; refuse to
  // launch those instead of handing a null function to the runtime.
  if (func == nullptr) {
    WARN("MSCCL: no kernel available for reduction op %d with data type %d", (int)opFull.op, (int)dataType);
    return ncclInvalidUsage;
  }
  CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, NULL, 0));
  // Only advanced after a successful launch.
  status.workIndex++;
  return ncclSuccess;
}
|
||||
@@ -0,0 +1,11 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT License.
|
||||
************************************************************************/
|
||||
|
||||
#include "msccl/msccl_status.h"
|
||||
|
||||
// Accessor for the single MSCCL status object; it is a function-local static,
// so it is created on the first call and every caller gets a reference to the
// same instance.
mscclStatus& mscclGetStatus() {
  static mscclStatus sharedStatus;
  return sharedStatus;
}
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -498,6 +499,44 @@ ncclResult_t pncclAllToAllv(const void *sendbuff, const size_t sendcounts[],
|
||||
const size_t rdispls[], ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream);
|
||||
/// @endcond
|
||||
|
||||
/*! @brief Opaque handle to MSCCL algorithm */
|
||||
typedef int mscclAlgoHandle_t;
|
||||
|
||||
/*! @brief MSCCL Load Algorithm
|
||||
*
|
||||
* @details Load MSCCL algorithm file specified in mscclAlgoFilePath and return
|
||||
* its handle via mscclAlgoHandle. This API is expected to be called by MSCCL
|
||||
* scheduler instead of end users.
|
||||
*/
|
||||
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle);
|
||||
ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle);
|
||||
|
||||
/*! @brief MSCCL Run Algorithm
|
||||
*
|
||||
* @details Run MSCCL algorithm specified by mscclAlgoHandle. The parameter
|
||||
* list merges all possible parameters required by different operations as this
|
||||
* is a general-purposed API. This API is expected to be called by MSCCL
|
||||
* scheduler instead of end users.
|
||||
*/
|
||||
ncclResult_t mscclRunAlgo(
|
||||
const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
|
||||
void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
|
||||
size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
|
||||
mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream);
|
||||
ncclResult_t pmscclRunAlgo(
|
||||
const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
|
||||
void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
|
||||
size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
|
||||
mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream);
|
||||
|
||||
/*! @brief MSCCL Unload Algorithm
|
||||
*
|
||||
* @details Unload MSCCL algorithm previously loaded using its handle. This API
|
||||
* is expected to be called by MSCCL scheduler instead of end users.
|
||||
*/
|
||||
ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle);
|
||||
ncclResult_t pmscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle);
|
||||
|
||||
/*
|
||||
* Group semantics
|
||||
*
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -16,8 +17,6 @@
|
||||
|
||||
#include <sys/syscall.h>
|
||||
|
||||
enum { proxyRecv=0, proxySend=1 };
|
||||
|
||||
static bool NeedProxy(int type, int pattern, int root, struct ncclRing* ring, int nranks) {
|
||||
if (pattern == ncclPatternRing || pattern == ncclPatternRingTwice) return true;
|
||||
|
||||
@@ -371,6 +370,11 @@ static ncclResult_t SaveProxy(struct ncclChannel* channel, int type, int peer, s
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Exposes the file-local SaveProxy to the MSCCL code: forwards channel, type,
// peer, op and connIndex unchanged and passes nullptr as the final argument.
// Returns ncclSuccess, or the error propagated from SaveProxy by NCCLCHECK.
ncclResult_t mscclSaveProxy(struct ncclChannel* channel, int type, int peer, struct ncclProxyOp* op, int connIndex) {
  NCCLCHECK(SaveProxy(channel, type, peer, op, connIndex, nullptr));
  return ncclSuccess;
}
|
||||
|
||||
// justInquire != nullptr means don't actually do anything, just ascertain need of
|
||||
// ncclProxySaveOp for this op.
|
||||
ncclResult_t ncclProxySaveOp(struct ncclComm* comm, struct ncclProxyOp* op, bool* justInquire) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -18,6 +19,7 @@
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large");
|
||||
|
||||
@@ -175,7 +177,7 @@ struct setupReq {
|
||||
static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
|
||||
struct setupReq req;
|
||||
|
||||
send->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
|
||||
send->conn.shared = req.shared = (graph || mscclAvailable() && mscclIsCaller()) ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
|
||||
req.channelId = channelId;
|
||||
req.connIndex = connIndex;
|
||||
req.curr_hdp_reg = 0;
|
||||
@@ -217,7 +219,7 @@ NCCL_PARAM(GdrCopyFlushEnable, "GDRCOPY_FLUSH_ENABLE", 0);
|
||||
static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) {
|
||||
struct setupReq req;
|
||||
|
||||
recv->conn.shared = req.shared = graph ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
|
||||
recv->conn.shared = req.shared = (graph || mscclAvailable() && mscclIsCaller()) ? 0 : ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : 1;
|
||||
req.channelId = channelId;
|
||||
req.connIndex = connIndex;
|
||||
req.netDev = -1;
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "TestBed.hpp"
|
||||
|
||||
namespace RcclUnitTesting
|
||||
{
|
||||
TEST(AllReduce, MscclSingleCall)
|
||||
{
|
||||
TestBed testBed;
|
||||
|
||||
// Configuration
|
||||
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllReduce};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclInt8, ncclInt32, ncclFloat32};
|
||||
std::vector<ncclRedOp_t> const redOps = {ncclSum, ncclProd};
|
||||
std::vector<int> const roots = {0};
|
||||
std::vector<int> const numElements = {384 * 1024, 384};
|
||||
std::vector<bool> const inPlaceList = {true, false};
|
||||
std::vector<bool> const managedMemList = {true, false};
|
||||
std::vector<bool> const useHipGraphList = {true, false};
|
||||
|
||||
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements, inPlaceList, managedMemList, useHipGraphList);
|
||||
testBed.Finalize();
|
||||
}
|
||||
|
||||
TEST(AllReduce, MscclGroupCall)
|
||||
{
|
||||
TestBed testBed;
|
||||
|
||||
// Configuration
|
||||
ncclFunc_t const funcType = ncclCollAllReduce;
|
||||
std::vector<ncclDataType_t> const& dataTypes = {ncclFloat};
|
||||
std::vector<ncclRedOp_t> const& redOps = {ncclSum};
|
||||
std::vector<int> const numElements = {384 * 1024, 384};
|
||||
bool const inPlace = false;
|
||||
bool const useManagedMem = false;
|
||||
int const numCollPerGroup = numElements.size();
|
||||
|
||||
OptionalColArgs options;
|
||||
// This test runs numCollPerGroup collectives (one per entry in numElements) in the same group call
|
||||
bool isCorrect = true;
|
||||
for (int totalRanks = testBed.ev.minGpus; totalRanks <= testBed.ev.maxGpus && isCorrect; ++totalRanks)
|
||||
for (int isMultiProcess = 0; isMultiProcess <= 1 && isCorrect; ++isMultiProcess)
|
||||
{
|
||||
if (!(testBed.ev.processMask & (1 << isMultiProcess))) continue;
|
||||
|
||||
// Test either single process all GPUs, or 1 process per GPU
|
||||
int const numProcesses = isMultiProcess ? totalRanks : 1;
|
||||
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks), numCollPerGroup);
|
||||
|
||||
for (int redOpIdx = 0; redOpIdx < redOps.size() && isCorrect; ++redOpIdx)
|
||||
{
|
||||
options.redOp = redOps[redOpIdx];
|
||||
for (int dataIdx = 0; dataIdx < dataTypes.size() && isCorrect; ++dataIdx)
|
||||
{
|
||||
if (testBed.ev.showNames)
|
||||
INFO("%s %d-ranks AllReduce %d Grouped Calls (%s-%s)\n",
|
||||
isMultiProcess ? "MP" : "SP",
|
||||
totalRanks, numCollPerGroup,
|
||||
ncclRedOpNames[redOps[redOpIdx]], ncclDataTypeNames[dataTypes[dataIdx]]);
|
||||
|
||||
// Run all element sizes in parallel as single group
|
||||
for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx)
|
||||
{
|
||||
testBed.SetCollectiveArgs(funcType,
|
||||
dataTypes[dataIdx],
|
||||
numElements[collIdx],
|
||||
numElements[collIdx],
|
||||
options,
|
||||
collIdx);
|
||||
}
|
||||
testBed.AllocateMem(inPlace, useManagedMem);
|
||||
testBed.PrepareData();
|
||||
testBed.ExecuteCollectives();
|
||||
testBed.ValidateResults(isCorrect);
|
||||
testBed.DeallocateMem();
|
||||
}
|
||||
}
|
||||
testBed.DestroyComms();
|
||||
}
|
||||
testBed.Finalize();
|
||||
}
|
||||
|
||||
TEST(AllReduce, MscclPreMultScalar)
|
||||
{
|
||||
TestBed testBed;
|
||||
|
||||
// Configuration
|
||||
ncclFunc_t const funcType = ncclCollAllReduce;
|
||||
std::vector<ncclDataType_t> const& dataTypes = {ncclInt32, ncclFloat32, ncclFloat64};
|
||||
ncclRedOp_t const redOp = ncclSum;
|
||||
std::vector<int> const numElements = {384 * 1024, 384};
|
||||
bool const inPlace = false;
|
||||
bool const useManagedMem = false;
|
||||
|
||||
OptionalColArgs options;
|
||||
// Terminate the test as soon as first failure occurs
|
||||
bool isCorrect = true;
|
||||
for (int totalRanks = testBed.ev.minGpus; totalRanks <= testBed.ev.maxGpus && isCorrect; ++totalRanks)
|
||||
for (int isMultiProcess = 0; isMultiProcess <= 1; ++isMultiProcess)
|
||||
{
|
||||
if (!(testBed.ev.processMask & (1 << isMultiProcess))) continue;
|
||||
|
||||
int const numProcesses = isMultiProcess ? totalRanks : 1;
|
||||
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks));
|
||||
|
||||
for (int dataIdx = 0; dataIdx < dataTypes.size() && isCorrect; ++dataIdx)
|
||||
{
|
||||
ncclDataType_t const dataType = dataTypes[dataIdx];
|
||||
|
||||
// Set scalars per rank
|
||||
PtrUnion scalarsPerRank;
|
||||
scalarsPerRank.AllocateCpuMem(totalRanks * DataTypeToBytes(dataType));
|
||||
for (int i = 0; i < totalRanks; i++)
|
||||
{
|
||||
double F = i;
|
||||
scalarsPerRank.Set(dataType, i, i, F);
|
||||
}
|
||||
int const numBytes = totalRanks * DataTypeToBytes(dataType);
|
||||
memcpy(options.scalarTransport.ptr, scalarsPerRank.ptr, numBytes);
|
||||
|
||||
// Test various scalar residence modes
|
||||
for (int scalarMode = 0; scalarMode <= 1 && isCorrect; ++scalarMode)
|
||||
{
|
||||
if (testBed.ev.showNames)
|
||||
INFO("%s %d-ranks AllReduce (custom-scalar Mode %d %s)\n",
|
||||
isMultiProcess ? "MP" : "SP",
|
||||
totalRanks, scalarMode, ncclDataTypeNames[dataType]);
|
||||
|
||||
for (int i = 0; i < numElements.size() && isCorrect; ++i)
|
||||
{
|
||||
options.scalarMode = scalarMode;
|
||||
options.redOp = redOp;
|
||||
testBed.SetCollectiveArgs(funcType, dataType,
|
||||
numElements[i], numElements[i],
|
||||
options);
|
||||
// For performance, only allocate and prepare data on largest size
|
||||
if (i == 0)
|
||||
{
|
||||
testBed.AllocateMem(inPlace, useManagedMem);
|
||||
testBed.PrepareData();
|
||||
}
|
||||
testBed.ExecuteCollectives();
|
||||
testBed.ValidateResults(isCorrect);
|
||||
}
|
||||
testBed.DeallocateMem();
|
||||
}
|
||||
}
|
||||
testBed.DestroyComms();
|
||||
}
|
||||
testBed.Finalize();
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
# Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
# Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
|
||||
cmake_minimum_required(VERSION 2.8.12)
|
||||
|
||||
if(BUILD_TESTS)
|
||||
@@ -48,6 +50,7 @@ if(BUILD_TESTS)
|
||||
AllReduce_ManagedMem.cpp
|
||||
AllReduce_OutOfPlace.cpp
|
||||
AllReduce_PreMultScalar.cpp
|
||||
AllReduce_Msccl.cpp
|
||||
)
|
||||
else()
|
||||
set(TEST_SOURCE_FILES
|
||||
@@ -58,6 +61,7 @@ if(BUILD_TESTS)
|
||||
AllReduce_ManagedMem.cpp
|
||||
AllReduce_OutOfPlace.cpp
|
||||
AllReduce_PreMultScalar.cpp
|
||||
AllReduce_Msccl.cpp
|
||||
#AllGather
|
||||
AllGather_InPlace.cpp
|
||||
AllGather_ManagedMem.cpp
|
||||
|
||||
Tá difríocht comhad cosc orthu toisc go bhfuil sé ró-mhór
Difríocht Luchtaigh
Tá difríocht comhad cosc orthu toisc go bhfuil sé ró-mhór
Difríocht Luchtaigh
Tá difríocht comhad cosc orthu toisc go bhfuil sé ró-mhór
Difríocht Luchtaigh
Tagairt in Eagrán Nua
Cuir bac ar úsáideoir