Merge remote-tracking branch 'nccl/master' into 2.10.3
[ROCm/rccl commit: bf2339f93e]
Этот коммит содержится в:
@@ -16,9 +16,11 @@ __hidden ncclResult_t pluginPtrSupport(int dev, int* supportedTypes) { return nc
|
||||
__hidden ncclResult_t pluginListen(int dev, void* handle, void** listenComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, int type, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginIrecv(void* recvComm, void* data, int size, int type, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginFlush(void* recvComm, void* data, int size) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginDeregMr(void* collComm, void* mhandle) { return ncclInternalError;}
|
||||
__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, void* mhandle, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginFlush(void* recvComm, void* data, int size, void* mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginTest(void* request, int* done, int* size) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCloseSend(void* sendComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCloseRecv(void* recvComm) { return ncclInternalError; }
|
||||
@@ -33,6 +35,8 @@ ncclNet_t NCCL_PLUGIN_SYMBOL = {
|
||||
pluginListen,
|
||||
pluginConnect,
|
||||
pluginAccept,
|
||||
pluginRegMr,
|
||||
pluginDeregMr,
|
||||
pluginIsend,
|
||||
pluginIrecv,
|
||||
pluginFlush,
|
||||
@@ -41,3 +45,36 @@ ncclNet_t NCCL_PLUGIN_SYMBOL = {
|
||||
pluginCloseRecv,
|
||||
pluginCloseListen
|
||||
};
|
||||
|
||||
__hidden ncclResult_t pluginCollNetInit(ncclDebugLogger_t logFunction) { return ncclSuccess; }
|
||||
__hidden ncclResult_t pluginCollNetDevices(int* ndev) { *ndev = 0; return ncclSuccess; }
|
||||
__hidden ncclResult_t pluginCollNetPciPath(int dev, char** path) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetPtrSupport(int dev, int* supportedTypes) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetListen(int dev, void* handle, void** listenComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetConnect(void* handles[], int nranks, int rank, void* listenComm, void** collComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetReduceSupport(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetDeregMr(void* collComm, void* mhandle) { return ncclInternalError;}
|
||||
__hidden ncclResult_t pluginCollNetIallreduce(void* collComm, void* sendData, void* recvData, int count, ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetFlush(void* collComm, void* data, int size, void* mhandle) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetTest(void* request, int* done, int* size) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetCloseColl(void* collComm) { return ncclInternalError; }
|
||||
__hidden ncclResult_t pluginCollNetCloseListen(void* listenComm) { return ncclInternalError; }
|
||||
|
||||
ncclCollNet_t NCCL_COLLNET_PLUGIN_SYMBOL = {
|
||||
"Dummy",
|
||||
pluginCollNetInit,
|
||||
pluginCollNetDevices,
|
||||
pluginCollNetPciPath,
|
||||
pluginCollNetPtrSupport,
|
||||
pluginCollNetListen,
|
||||
pluginCollNetConnect,
|
||||
pluginCollNetReduceSupport,
|
||||
pluginCollNetRegMr,
|
||||
pluginCollNetDeregMr,
|
||||
pluginCollNetIallreduce,
|
||||
pluginCollNetFlush,
|
||||
pluginCollNetTest,
|
||||
pluginCollNetCloseColl,
|
||||
pluginCollNetCloseListen
|
||||
};
|
||||
|
||||
@@ -55,7 +55,7 @@ CXXFLAGS := -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR) -fPIC -fvisi
|
||||
# Maxrregcount needs to be set accordingly to NCCL_MAX_NTHREADS (otherwise it will cause kernel launch errors)
|
||||
# 512 : 120, 640 : 96, 768 : 80, 1024 : 60
|
||||
# We would not have to set this if we used __launch_bounds__, but this only works on kernels, not on functions.
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 -Xptxas -maxrregcount=96 -Xfatbin -compress-all
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xptxas -maxrregcount=96 -Xfatbin -compress-all
|
||||
# Use addprefix so that we can specify more than one path
|
||||
NVLDFLAGS := -L${CUDA_LIB} -lcudart -lrt
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
##### version
|
||||
NCCL_MAJOR := 2
|
||||
NCCL_MINOR := 9
|
||||
NCCL_PATCH := 9
|
||||
NCCL_MINOR := 10
|
||||
NCCL_PATCH := 3
|
||||
NCCL_SUFFIX :=
|
||||
PKG_REVISION := 1
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
include/nccl.h /usr/include
|
||||
include/nccl_net.h /usr/include
|
||||
lib/libnccl.so /usr/lib/${pkg:MultiArch}
|
||||
lib/libnccl_static.a /usr/lib/${pkg:MultiArch}
|
||||
|
||||
@@ -7,7 +7,7 @@ Group: Development/Libraries
|
||||
License: BSD
|
||||
URL: http://developer.nvidia.com/nccl
|
||||
Source0: nccl_${nccl:Major}.${nccl:Minor}.${nccl:Patch}${nccl:Suffix}-${pkg:Revision}+cuda${cuda:Major}.${cuda:Minor}_${pkg:Arch}.txz
|
||||
Prereq: /sbin/ldconfig
|
||||
Requires(pre,preun): /sbin/ldconfig
|
||||
|
||||
%description
|
||||
NCCL (pronounced "Nickel") is a stand-alone library of standard collective
|
||||
@@ -46,6 +46,7 @@ ln -s libnccl.so.${nccl:Major}.${nccl:Minor}.${nccl:Patch} $RPM_BUILD_ROOT/%{_li
|
||||
# devel
|
||||
install -m 755 -d $RPM_BUILD_ROOT/%{_includedir}
|
||||
install -m 644 include/nccl.h $RPM_BUILD_ROOT/%{_includedir}
|
||||
install -m 644 include/nccl_net.h $RPM_BUILD_ROOT/%{_includedir}
|
||||
ln -s libnccl.so.${nccl:Major} $RPM_BUILD_ROOT/%{_libdir}/libnccl.so
|
||||
|
||||
# static
|
||||
@@ -64,6 +65,7 @@ rm -rf $RPM_BUILD_ROOT
|
||||
%doc LICENSE.txt
|
||||
%defattr(-,root,root,-)
|
||||
%{_includedir}/nccl.h
|
||||
%{_includedir}/nccl_net.h
|
||||
%{_libdir}/libnccl.so
|
||||
|
||||
%files static
|
||||
|
||||
@@ -49,7 +49,7 @@ ncclResult_t bootstrapNetInit() {
|
||||
}
|
||||
char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2];
|
||||
sprintf(line, " %s:", bootstrapNetIfName);
|
||||
socketToString(&bootstrapNetIfAddr.sa, line+strlen(line));
|
||||
socketToString(&bootstrapNetIfAddr, line+strlen(line));
|
||||
INFO(NCCL_INIT, "Bootstrap : Using%s", line);
|
||||
bootstrapNetInitDone = 1;
|
||||
}
|
||||
@@ -61,27 +61,27 @@ ncclResult_t bootstrapNetInit() {
|
||||
/* Socket Interface Selection type */
|
||||
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
|
||||
|
||||
static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd) {
|
||||
struct sockaddr_in sockaddr;
|
||||
socklen_t socklen = sizeof(struct sockaddr_in);
|
||||
SYSCHECKVAL(accept(listenFd, (struct sockaddr*)&sockaddr, &socklen), "accept", *recvFd);
|
||||
static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd, union socketAddress *addr) {
|
||||
struct sockaddr *saddr = &addr->sa;
|
||||
socklen_t socklen = sizeof(union socketAddress);
|
||||
SYSCHECKVAL(accept(listenFd, saddr, &socklen), "accept", *recvFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Additional sync functions
|
||||
static ncclResult_t bootstrapNetSend(int fd, void* data, int size) {
|
||||
NCCLCHECK(socketSend(fd, &size, sizeof(int)));
|
||||
NCCLCHECK(socketSend(fd, data, size));
|
||||
static ncclResult_t bootstrapNetSend(int fd, union socketAddress *addr, void* data, int size) {
|
||||
NCCLCHECK(socketSend(fd, addr, &size, sizeof(int)));
|
||||
NCCLCHECK(socketSend(fd, addr, data, size));
|
||||
return ncclSuccess;
|
||||
}
|
||||
static ncclResult_t bootstrapNetRecv(int fd, void* data, int size) {
|
||||
static ncclResult_t bootstrapNetRecv(int fd, union socketAddress *addr, void* data, int size) {
|
||||
int recvSize;
|
||||
NCCLCHECK(socketRecv(fd, &recvSize, sizeof(int)));
|
||||
NCCLCHECK(socketRecv(fd, addr, &recvSize, sizeof(int)));
|
||||
if (recvSize > size) {
|
||||
WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
|
||||
return ncclInternalError;
|
||||
}
|
||||
NCCLCHECK(socketRecv(fd, data, std::min(recvSize, size)));
|
||||
NCCLCHECK(socketRecv(fd, addr, data, std::min(recvSize, size)));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -114,7 +114,6 @@ static void *bootstrapRoot(void* bootstrapRootStruct) { // [RCCL] Modified to in
|
||||
|
||||
ncclResult_t res = ncclSuccess;
|
||||
int nranks = 0, c = 0;
|
||||
|
||||
struct extInfo info;
|
||||
union socketAddress *rankAddresses = NULL;
|
||||
union socketAddress *rankAddressesRoot = NULL; // for initial rank <-> root information exchange
|
||||
@@ -126,8 +125,9 @@ static void *bootstrapRoot(void* bootstrapRootStruct) { // [RCCL] Modified to in
|
||||
/* Receive addresses from all ranks */
|
||||
do {
|
||||
int tmpFd;
|
||||
NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &info, sizeof(info)), res, out);
|
||||
union socketAddress addr;
|
||||
NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd, &addr), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &addr, &info, sizeof(info)), res, out);
|
||||
close(tmpFd);
|
||||
|
||||
if (c == 0) {
|
||||
@@ -165,9 +165,9 @@ static void *bootstrapRoot(void* bootstrapRootStruct) { // [RCCL] Modified to in
|
||||
|
||||
int tmpSendFd;
|
||||
NCCLCHECKGOTO(connectAddress(&tmpSendFd, rankAddressesRoot+r), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddresses+next, sizeof(union socketAddress)), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddressesRoot+r, rankAddresses+next, sizeof(union socketAddress)), res, out);
|
||||
{ // [RCCL] Send the root pid for shared file naming
|
||||
NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, &pid, sizeof(int)), res, out);
|
||||
NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddressesRoot+r, &pid, sizeof(int)), res, out);
|
||||
} // [/RCCL]
|
||||
close(tmpSendFd);
|
||||
}
|
||||
@@ -224,6 +224,7 @@ struct unexConn {
|
||||
int peer;
|
||||
int tag;
|
||||
int fd;
|
||||
union socketAddress addr;
|
||||
struct unexConn* next;
|
||||
};
|
||||
|
||||
@@ -238,6 +239,7 @@ struct extState {
|
||||
int extListenFd;
|
||||
int extRingRecvFd;
|
||||
int extRingSendFd;
|
||||
union socketAddress extRingRecvAddr, extRingSendAddr;
|
||||
union socketAddress* peerCommAddresses;
|
||||
union socketAddress* peerAllocAddresses;
|
||||
struct unexConn* unexpectedConnections;
|
||||
@@ -252,11 +254,11 @@ struct extState {
|
||||
|
||||
#define MAX_SEGMENTS 128
|
||||
|
||||
static ncclResult_t remoteAlloc(void** ptr, int fd) {
|
||||
static ncclResult_t remoteAlloc(void** ptr, int fd, union socketAddress *addr) {
|
||||
size_t size;
|
||||
NCCLCHECK(socketRecv(fd, &size, sizeof(size_t)));
|
||||
NCCLCHECK(socketRecv(fd, addr, &size, sizeof(size_t)));
|
||||
hipIpcMemHandle_t devIpc;
|
||||
NCCLCHECK(ncclCudaCalloc((char**)ptr, size, true));
|
||||
NCCLCHECK(ncclCudaCalloc((char**)ptr, size));
|
||||
hipError_t res = hipIpcGetMemHandle(&devIpc, *ptr);
|
||||
if (res != hipSuccess) {
|
||||
WARN("[Rem Allocator] hipIpcGetMemHandle failed : %s", hipGetErrorString(res));
|
||||
@@ -264,9 +266,9 @@ static ncclResult_t remoteAlloc(void** ptr, int fd) {
|
||||
CUDACHECK(res);
|
||||
}
|
||||
// The CUDA IPC
|
||||
NCCLCHECK(socketSend(fd, &devIpc, sizeof(hipIpcMemHandle_t)));
|
||||
NCCLCHECK(socketSend(fd, addr, &devIpc, sizeof(hipIpcMemHandle_t)));
|
||||
// And the direct pointer
|
||||
NCCLCHECK(socketSend(fd, ptr, sizeof(void*)));
|
||||
NCCLCHECK(socketSend(fd, addr, ptr, sizeof(void*)));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -298,11 +300,12 @@ void* ncclRemoteMemAllocationService(void* args) {
|
||||
}
|
||||
if (pollfds[MAX_SEGMENTS].revents) {
|
||||
int s = 0;
|
||||
union socketAddress addr;
|
||||
while (segments[s] != NULL && s < MAX_SEGMENTS) s++;
|
||||
if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd) != ncclSuccess) {
|
||||
if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd, &addr) != ncclSuccess) {
|
||||
pollfds[s].fd = -1;
|
||||
} else {
|
||||
if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd) != ncclSuccess)) {
|
||||
if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd, &addr) != ncclSuccess)) {
|
||||
WARN("[Rem Allocator] Allocation failed (segment %d, fd %d)", s, pollfds[s].fd);
|
||||
close(pollfds[s].fd);
|
||||
pollfds[s].fd = -1;
|
||||
@@ -337,10 +340,11 @@ ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id,
|
||||
int fd;
|
||||
ncclResult_t res;
|
||||
*id = -1;
|
||||
NCCLCHECK(connectAddress(&fd, state->peerAllocAddresses+rank));
|
||||
NCCLCHECKGOTO(socketSend(fd, &size, sizeof(size_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, ipc, sizeof(hipIpcMemHandle_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, ptr, sizeof(void*)), res, end);
|
||||
union socketAddress *addr = state->peerAllocAddresses+rank;
|
||||
NCCLCHECK(connectAddress(&fd, addr));
|
||||
NCCLCHECKGOTO(socketSend(fd, addr, &size, sizeof(size_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, addr, ipc, sizeof(hipIpcMemHandle_t)), res, end);
|
||||
NCCLCHECKGOTO(socketRecv(fd, addr, ptr, sizeof(void*)), res, end);
|
||||
*id = fd;
|
||||
end:
|
||||
return res;
|
||||
@@ -384,22 +388,22 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS
|
||||
// send info on my listening socket to root
|
||||
union socketAddress* rootAddr = (union socketAddress*)id;
|
||||
NCCLCHECK(connectAddress(&tmpSendFd, rootAddr));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, &info, sizeof(info)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, rootAddr, &info, sizeof(info)));
|
||||
close(tmpSendFd);
|
||||
|
||||
// get info on my "next" rank in the bootstrap ring from root
|
||||
union socketAddress extAddressNext;
|
||||
NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &extAddressNext, sizeof(extAddressNext)));
|
||||
union socketAddress addr;
|
||||
NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd, &addr));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &state->extRingSendAddr, sizeof(state->extRingSendAddr)));
|
||||
{ // [RCCL] Receive PID from root
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, rootPid, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, rootPid, sizeof(int)));
|
||||
} // [/RCCL]
|
||||
close(tmpRecvFd);
|
||||
close(extListenFdRoot);
|
||||
|
||||
NCCLCHECK(connectAddress(&state->extRingSendFd, &extAddressNext));
|
||||
NCCLCHECK(connectAddress(&state->extRingSendFd, &state->extRingSendAddr));
|
||||
// Accept the connect request from the previous rank in the AllGather ring
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd));
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd, &state->extRingRecvAddr));
|
||||
|
||||
// AllGather all listen handlers
|
||||
NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks));
|
||||
@@ -437,9 +441,9 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
|
||||
size_t sslice = (rank - i + nranks) % nranks;
|
||||
|
||||
// Send slice to the right
|
||||
NCCLCHECK(bootstrapNetSend(state->extRingSendFd, data+sslice*size, size));
|
||||
NCCLCHECK(bootstrapNetSend(state->extRingSendFd, &state->extRingSendAddr, data+sslice*size, size));
|
||||
// Recv slice from the left
|
||||
NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, data+rslice*size, size));
|
||||
NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, &state->extRingRecvAddr, data+rslice*size, size));
|
||||
}
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||
@@ -449,21 +453,44 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
|
||||
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) {
|
||||
struct extState* state = (struct extState*)commState;
|
||||
int tmpSendFd;
|
||||
NCCLCHECK(connectAddress(&tmpSendFd, state->peerCommAddresses+peer));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, &state->rank, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, &tag, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, data, size));
|
||||
union socketAddress *addr = state->peerCommAddresses+peer;
|
||||
NCCLCHECK(connectAddress(&tmpSendFd, addr));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, &state->rank, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, &tag, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetSend(tmpSendFd, addr, data, size));
|
||||
close(tmpSendFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd) {
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks) {
|
||||
if (nranks == 1) return ncclSuccess;
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag);
|
||||
|
||||
/* Simple intra process barrier
|
||||
*
|
||||
* Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet,
|
||||
* "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
|
||||
*/
|
||||
int data[1];
|
||||
for (int mask=1; mask<nranks; mask<<=1) {
|
||||
int src = (rank - mask + nranks) % nranks;
|
||||
int dst = (rank + mask) % nranks;
|
||||
NCCLCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data)));
|
||||
NCCLCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data)));
|
||||
}
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d tag %x - DONE", rank, nranks, tag);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd, union socketAddress *addr) {
|
||||
// New unex
|
||||
struct unexConn* unex;
|
||||
NCCLCHECK(ncclCalloc(&unex, 1));
|
||||
unex->peer = peer;
|
||||
unex->tag = tag;
|
||||
unex->fd = fd;
|
||||
unex->addr = *addr;
|
||||
|
||||
// Enqueue
|
||||
struct unexConn* list = state->unexpectedConnections;
|
||||
@@ -476,7 +503,7 @@ ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
int unexpectedDequeue(struct extState* state, int peer, int tag) {
|
||||
int unexpectedDequeue(struct extState* state, int peer, int tag, union socketAddress *addr) {
|
||||
struct unexConn* elem = state->unexpectedConnections;
|
||||
struct unexConn* prev = NULL;
|
||||
while (elem) {
|
||||
@@ -487,6 +514,7 @@ int unexpectedDequeue(struct extState* state, int peer, int tag) {
|
||||
prev->next = elem->next;
|
||||
}
|
||||
int fd = elem->fd;
|
||||
*addr = elem->addr;
|
||||
free(elem);
|
||||
return fd;
|
||||
}
|
||||
@@ -501,27 +529,29 @@ ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int s
|
||||
struct extState* state = (struct extState*)commState;
|
||||
|
||||
int tmpRecvFd;
|
||||
union socketAddress addr;
|
||||
|
||||
// Search unexpected connections first
|
||||
if ((tmpRecvFd = unexpectedDequeue(state, peer, tag)) != -1) {
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
|
||||
if ((tmpRecvFd = unexpectedDequeue(state, peer, tag, &addr)) != -1) {
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, ((char*)data), size));
|
||||
close(tmpRecvFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Then look for new connections
|
||||
while (1) {
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd));
|
||||
union socketAddress addr;
|
||||
NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd, &addr));
|
||||
int newPeer, newTag;
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newPeer, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newTag, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &newPeer, sizeof(int)));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, &newTag, sizeof(int)));
|
||||
if (newPeer == peer && newTag == tag) {
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
|
||||
NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &addr, ((char*)data), size));
|
||||
close(tmpRecvFd);
|
||||
return ncclSuccess;
|
||||
}
|
||||
// Unexpected connection. Save for later.
|
||||
NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, tmpRecvFd));
|
||||
NCCLCHECK(unexpectedEnqueue(state, newPeer, newTag, tmpRecvFd, &addr));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ __device__ void AllReduceCliqueSplitKernel(struct ncclWorkElem* args)
|
||||
size_t const currBlockStop = min(currBlockStart + perBlockN, N);
|
||||
size_t const blockN = currBlockStop - currBlockStart;
|
||||
|
||||
FUNC redOp(FuncTraits<FUNC>().make(args->comm->nRanks));
|
||||
|
||||
if (blockN > 0)
|
||||
{
|
||||
// Prepare input / output subarrays
|
||||
@@ -64,7 +66,7 @@ __device__ void AllReduceCliqueSplitKernel(struct ncclWorkElem* args)
|
||||
// Perform the reduction
|
||||
#define ALL_REDUCE_CLIQUE_UNROLL 1
|
||||
ReduceOrCopyMulti<ALL_REDUCE_CLIQUE_UNROLL, FUNC, T, NUM_RANKS, NUM_RANKS, NUM_RANKS, NUM_RANKS>(
|
||||
threadIdx.x, blockDim.x, NUM_RANKS, srcs, NUM_RANKS, dsts, blockN);
|
||||
threadIdx.x, blockDim.x, redOp, false, false, NUM_RANKS, srcs, NUM_RANKS, dsts, blockN);
|
||||
}
|
||||
|
||||
// Even if there was nothing for this GPU to do, it must participate in a barrier
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# See LICENSE.txt for license information
|
||||
#
|
||||
@@ -32,7 +32,7 @@ all_deps: $(DEPENDFILES)
|
||||
$(RULESFILE) :
|
||||
@printf "Generating %-35s > %s\n" rules $@
|
||||
@mkdir -p $(OBJDIR)
|
||||
@./gen_rules.sh $(OBJDIR) > $@
|
||||
@CUDA_MAJOR=${CUDA_MAJOR} CUDA_MINOR=${CUDA_MINOR} ./gen_rules.sh $(OBJDIR) > $@
|
||||
|
||||
-include $(RULESFILE)
|
||||
|
||||
|
||||
@@ -6,204 +6,95 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem->channel.ring;
|
||||
const int *ringRanks = ring->devUserRanks;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLGATHER_CHUNKSTEPS : 1));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
|
||||
const int nranks = ncclShmem->comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*int(chunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
T *inputBuf = (T*)args->sendbuff;
|
||||
T *outputBuf = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, 0, args->coll.connIndex);
|
||||
|
||||
ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex));
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t chunkOffset = gridOffset + bid*realChunkSize;
|
||||
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
if (thisInput + chunkOffset == thisOutput + offset) { // In place
|
||||
prims.directSend(thisInput+chunkOffset, offset, nelem);
|
||||
} else {
|
||||
prims.directCopySend(thisInput+chunkOffset, thisOutput+offset, offset, nelem);
|
||||
}
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
rankDest = ring->devUserRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
// Final wait/copy.
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset,nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
}
|
||||
};
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
ssize_t chunkOffset = gridOffset + int(bid*realChunkSize);
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ringRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
if (thisInput + chunkOffset == thisOutput + offset) { // In place
|
||||
LLprims.send(thisInput+chunkOffset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: final store
|
||||
rankDest = ring->devUserRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
if (inputBuf + chunkOffset == outputBuf + offset) { // In place
|
||||
prims.directSend(chunkOffset, offset, nelem);
|
||||
} else {
|
||||
prims.directCopySend(chunkOffset, offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin AllGather steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[0];
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ringRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
if (thisInput + chunkOffset == thisOutput + offset) { // In place
|
||||
LLprims.send(thisInput+chunkOffset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: final store
|
||||
rankDest = ring->devUserRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
rankDest = ringRanks[1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
// Final wait/copy.
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_TREE, PROTO, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllGather, NCCL_ALGO_COLLNET, PROTO, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,518 +6,399 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
#include "clique/AllReduceCliqueKernel.h" // [RCCL] AllReduce Clique-based kernel support
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
ncclRing *ring = &ncclShmem->channel.ring;
|
||||
int ringIx = ring->index;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLREDUCE_CHUNKSTEPS : 1));
|
||||
const int nranks = ncclShmem->comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
#ifdef ENABLE_PROFILING
|
||||
auto devProf = comm->devProf;
|
||||
auto devProf = ncclShmem->comm.devProf;
|
||||
uint64_t clk, t0 = 0ULL, ws;
|
||||
if (tid == 0) clk = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
int minChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_LL)
|
||||
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T));
|
||||
if (Proto::Id == NCCL_PROTO_LL128) {
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2;
|
||||
}
|
||||
|
||||
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex));
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto> prims
|
||||
(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
|
||||
ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks));
|
||||
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
auto calcOffset = [&]__device__(int chunk)->ssize_t {
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE)
|
||||
return gridOffset + bid*nranks*realChunkSize + chunk*realChunkSize;
|
||||
else
|
||||
return gridOffset + (chunk*nChannels + bid)*realChunkSize;
|
||||
};
|
||||
auto modRanks = [&]__device__(int r)->int {
|
||||
return r - (r >= nranks ? nranks : 0);
|
||||
};
|
||||
|
||||
/////////////// begin AllReduce steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem;
|
||||
int chunk;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
chunk = ring->devUserRanks[nranks-1];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + nranks-1);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
INIT_COUNTER;
|
||||
prims.send(thisInput+offset, nelem);
|
||||
prims.send(offset, nelem);
|
||||
ACCUMULATE_COUNTER(send);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + nranks-j);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
INIT_COUNTER;
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
ACCUMULATE_COUNTER(recvReduceSend);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data and push to the next GPU
|
||||
chunk = ring->devUserRanks[0];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = ringIx + 0;
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
INIT_COUNTER;
|
||||
prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
|
||||
prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*postOp=*/true);
|
||||
ACCUMULATE_COUNTER(directRecvReduceCopySend);
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + nranks-j);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
INIT_COUNTER;
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
ACCUMULATE_COUNTER(directRecvCopySend);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
chunk = ring->devUserRanks[1];
|
||||
offset = chunkOffset + chunk * realChunkSize;
|
||||
chunk = modRanks(ringIx + 1);
|
||||
offset = calcOffset(chunk);
|
||||
nelem = min(realChunkSize, size-offset);
|
||||
|
||||
// Final wait/copy.
|
||||
INIT_COUNTER;
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
prims.directRecv(offset, nelem);
|
||||
ACCUMULATE_COUNTER(directRecv);
|
||||
}
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0) __atomic_fetch_add(&(devProf->total_cycle), __builtin_amdgcn_s_memrealtime() - clk, __ATOMIC_SEQ_CST);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runTreeUpDown(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclTree* tree = &channel->tree;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
int chunkSize = args->coll.lastChunkSize;
|
||||
const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
ncclTree *tree = &ncclShmem->channel.tree;
|
||||
ssize_t chunkSize = int(
|
||||
Proto::Id == NCCL_PROTO_SIMPLE ? args->coll.lastChunkSize
|
||||
/* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T));
|
||||
const ssize_t minChunkSize = int(
|
||||
Proto::Id == NCCL_PROTO_SIMPLE ? nthreads*8*(sizeof(uint64_t)/sizeof(T))
|
||||
/* LL & LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
|
||||
const ssize_t loopSize = int(nChannels*chunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
if (loopSize > size) {
|
||||
chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
|
||||
}
|
||||
if (loopSize > size)
|
||||
chunkSize = divUp((int)size, int(nChannels*minChunkSize))*int(minChunkSize);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
#if 1
|
||||
if (tid < nthreads) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, 1, 0, FUNC>
|
||||
prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
{ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto> prims
|
||||
(tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
else if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tid < nthreads) {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DEV_ARITY, 1, FUNC>
|
||||
prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
prims.directSend(thisOutput+offset, offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
} else {
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
{ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto> prims
|
||||
(tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directSendFromOutput(offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
else if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
int nthreadsSplit = nthreads/2;
|
||||
if (nthreadsSplit >= 256) nthreadsSplit += 64;
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runTreeSplit(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclTree *tree = &ncclShmem->channel.tree;
|
||||
ssize_t chunkSize = int(
|
||||
Proto::Id != NCCL_PROTO_LL ? args->coll.lastChunkSize
|
||||
: Proto::calcBytePerStep()/sizeof(T));
|
||||
const ssize_t minChunkSize = int(
|
||||
Proto::Id == NCCL_PROTO_SIMPLE ? nthreads*8*(sizeof(uint64_t)/sizeof(T)) :
|
||||
Proto::Id == NCCL_PROTO_LL ? nthreads*(Proto::calcBytePerGrain()/sizeof(T))
|
||||
/* LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))/8);
|
||||
const ssize_t loopSize = int(nChannels*chunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
int nthreadsSplit;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
nthreadsSplit = nthreads/2;
|
||||
if (nthreadsSplit >= 256) nthreadsSplit += 64;
|
||||
} else { // LL & LL128
|
||||
// Receiving from up to 3 sources is more compute intensive than sending
|
||||
// to 3 dests. Use 70% for reduce and 30% for bcast.
|
||||
nthreadsSplit = (nthreads*7/(10*WARP_SIZE))*WARP_SIZE;
|
||||
}
|
||||
|
||||
if (loopSize > size)
|
||||
chunkSize = divUp((int)size, nChannels*int(minChunkSize))*int(minChunkSize);
|
||||
|
||||
if (tree->up == -1) {
|
||||
if (tid < nthreads) {
|
||||
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, NCCL_MAX_DEV_ARITY, 1, FUNC>
|
||||
prims(tid, nthreads, tree->down, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
// Reduce and broadcast. Max number of recv is 3, max number of send is 3
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvReduceCopySend(offset, offset, offset, nelem, /*doPost=*/true);
|
||||
}
|
||||
}
|
||||
else if (tid < nthreadsSplit) {
|
||||
/* Reduce up. Max number of recv is 3, max number of send is 1 (binary tree + local).
|
||||
* Why Direct=1????
|
||||
* Answer: Because despite not performing any direct operations, the ctor
|
||||
* must assume Direct so that it can exchange direct pointers with remote ctors
|
||||
* that are Direct, otherwise it hangs. A cleaner solution would be to seperate
|
||||
* into DirectRecv and DirectSend capabilities, this ctor would have both=0,
|
||||
* but the ctor above for tree roots would be DirectRecv=0 DirectSend=1.
|
||||
*/
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (tid < nthreadsSplit) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, 1, 0, FUNC>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DEV_ARITY, 1, FUNC>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 2);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
prims.directRecv(thisOutput+offset, offset, nelem);
|
||||
} else {
|
||||
prims.directRecvCopySend(thisOutput+offset, offset, nelem);
|
||||
}
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
// Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecv(offset, nelem);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.directRecvCopySend(offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
#define COLLNET_COPY_THREADS 64
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
static constexpr int COLLNET_COPY_THREADS = 64;
|
||||
const int tid = threadIdx.x;
|
||||
//const int nthreads = args->nThreads-3*WARP_SIZE;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclDirect* tree = &channel->collTree;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
int chunkSize = args->coll.lastChunkSize;
|
||||
struct ncclDirect* tree = &ncclShmem->channel.collTree;
|
||||
const ssize_t chunkSize = int(args->coll.lastChunkSize);
|
||||
const ssize_t size = args->coll.count;
|
||||
const ssize_t loopSize = nChannels*tree->nHeads*chunkSize;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
const int hasUp = (tree->up[0] >= 0) ? 1 : 0;
|
||||
const int hasDn = (tree->down[0] >= 0) ? 1 : 0;
|
||||
const int nThreadsScatter = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 2*COLLNET_COPY_THREADS : 0;
|
||||
const int nThreadsGather = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 1*COLLNET_COPY_THREADS : 0;
|
||||
const int nThreadsBcast = (hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 0 : 1*COLLNET_COPY_THREADS;
|
||||
// Gather does not need sync threads, sparing one more warp for reduce
|
||||
const int nThreadsReduce = NCCL_SIMPLE_MAX_NTHREADS - nThreadsScatter - nThreadsGather - nThreadsBcast;
|
||||
const int nThreadsScatter = ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 2*COLLNET_COPY_THREADS : 0);
|
||||
const int nThreadsGather = ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 1*COLLNET_COPY_THREADS : 0);
|
||||
const int nThreadsBcast = ((hasUp && hasDn) ? COLLNET_COPY_THREADS : hasUp ? 0 : 1*COLLNET_COPY_THREADS);
|
||||
const int nThreadsReduce = args->nThreads - nThreadsScatter - nThreadsGather - nThreadsBcast;
|
||||
const int tidStartBcast = nThreadsGather;
|
||||
const int tidStartScatter = tidStartBcast + nThreadsBcast;
|
||||
const int tidStartReduce = tidStartScatter + nThreadsScatter;
|
||||
|
||||
using Proto = ProtoSimple<1, 1>;
|
||||
|
||||
if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) {
|
||||
// Scatter
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 0, NCCL_MAX_DIRECT_ARITY, 0, FUNC>
|
||||
prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 2);
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, 2*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize;
|
||||
int nelem = min(tree->nHeads*chunkSize, size-offset);
|
||||
prims.scatter(thisInput+offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
}
|
||||
} else if (tid >= tidStartReduce && tree->out != -1) {
|
||||
// Reduce, send to network
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DIRECT_ARITY, 1, 0, FUNC>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, NULL, stepSize, channel, comm, ncclShmem->ptrs, 3);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (hasDn) {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
if (hasDn) {
|
||||
// Reduce, send to network
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
} else {
|
||||
// Directly send to network
|
||||
Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
} else if (tid < tidStartBcast && hasUp) {
|
||||
// Gather
|
||||
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DIRECT_ARITY, 0, 0, FUNC>
|
||||
prims(tid, nThreadsGather, tree->up, NULL, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 0>, /*Direct=*/0, Proto>
|
||||
prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize;
|
||||
int nelem = min(tree->nHeads*chunkSize, size-offset);
|
||||
prims.gather(thisOutput+offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
prims.gather(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
}
|
||||
} else if (tid >= tidStartBcast && tid < tidStartScatter && tree->out != -1) {
|
||||
// Recv from network, broadcast
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DIRECT_ARITY, 0, FUNC>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 1);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (hasDn) {
|
||||
prims.recvCopySend(thisOutput+offset, nelem);
|
||||
} else {
|
||||
prims.recv(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*nranks*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
/////////////// begin AllReduce steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem;
|
||||
int chunk;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
chunk = ring->devUserRanks[nranks-1];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data and push to the next GPU
|
||||
chunk = ring->devUserRanks[0];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
|
||||
// k-2 steps: copy to next GPU
|
||||
for (int j=1; j<nranks-1; ++j) {
|
||||
chunk = ring->devUserRanks[nranks-j];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
|
||||
// Make final copy from buffer to dest.
|
||||
chunk = ring->devUserRanks[1];
|
||||
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
|
||||
nelem = min(chunkSize, size-offset);
|
||||
|
||||
// Here we need to copy from buffer to this output.
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclTree* tree = &channel->tree;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
if (loopSize > size) {
|
||||
chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
|
||||
}
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
do {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclLLPrimitives<T, FUNC, NCCL_MAX_DEV_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
}
|
||||
} while(0);
|
||||
|
||||
do {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_DEV_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->up == -1) {
|
||||
LLprims.send(thisOutput+offset, nelem);
|
||||
} else if (tree->down[0] == -1) {
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
}
|
||||
} while(0);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ void run(struct ncclWorkElem* args) { }
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
// [RCCL] RingLL128 is re-purposed as clique-based kernel
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_CLIQUE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, FUNC, T, args);
|
||||
}
|
||||
};
|
||||
// [/RCCL]
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclTree* tree = &channel->tree;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = args->coll.lastChunkSize;
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
if (loopSize > size) {
|
||||
chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
|
||||
}
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
if (tree->up == -1) {
|
||||
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_DEV_ARITY, NCCL_MAX_DEV_ARITY> LLprims(tid, nthreads, tree->down, tree->down, stepSize, channel, comm);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
}
|
||||
} else {
|
||||
if (tid < nthreadsSplit) {
|
||||
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
ncclLL128Primitives<T, FUNC, NCCL_MAX_DEV_ARITY, 1> LLprims(tid, nthreadsSplit, tree->down, &tree->up, stepSize, channel, comm);
|
||||
if (hasDn) {
|
||||
// Recv from network, broadcast
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Up
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
prims.recvCopySend(offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
} else {
|
||||
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_DEV_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, stepSize, channel, comm);
|
||||
// Recv from network (no post thread needed)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
// Down
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (tree->down[0] == -1) {
|
||||
LLprims.recv(thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput+offset, nelem);
|
||||
}
|
||||
prims.recv(offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) { }
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runTreeUpDown<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, RedOp, T, args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, RedOp, T, args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,163 +1,98 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * BROADCAST_CHUNKSTEPS;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem->channel.ring;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1));
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
#ifdef ENABLE_PROFILING
|
||||
auto devProf = ncclShmem->comm.devProf;
|
||||
uint64_t clk, t0 = 0ULL, ws;
|
||||
if (tid == 0) clk = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
T *inputBuf = (T*)args->sendbuff;
|
||||
T *outputBuf = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, 0, args->coll.connIndex);
|
||||
|
||||
ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex));
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
ssize_t offset = gridOffset + int(bid*realChunkSize);
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
|
||||
if (rank == root) {
|
||||
if (thisInput == thisOutput) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
prims.copySend(thisInput+offset, thisOutput+offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
prims.recv(thisOutput+offset, nelem);
|
||||
if (rank == root) {
|
||||
if (inputBuf == outputBuf) {
|
||||
INIT_COUNTER;
|
||||
prims.send(offset, nelem);
|
||||
ACCUMULATE_COUNTER(send);
|
||||
} else {
|
||||
prims.recvCopySend(thisOutput+offset, nelem);
|
||||
INIT_COUNTER;
|
||||
prims.copySend(offset, offset, nelem);
|
||||
ACCUMULATE_COUNTER(copySend);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
INIT_COUNTER;
|
||||
prims.recv(offset, nelem);
|
||||
ACCUMULATE_COUNTER(recv);
|
||||
} else {
|
||||
INIT_COUNTER;
|
||||
prims.recvCopySend(offset, nelem);
|
||||
ACCUMULATE_COUNTER(recvCopySend);
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0) __atomic_fetch_add(&(devProf->total_cycle), __builtin_amdgcn_s_memrealtime() - clk, __ATOMIC_SEQ_CST);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (rank == root) {
|
||||
if (thisInput == thisOutput) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
LLprims.recv(thisOutput + offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput + offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int nextRank = ring->devUserRanks[1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (rank == root) {
|
||||
if (thisInput == thisOutput) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else {
|
||||
LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
|
||||
}
|
||||
} else if (nextRank == root) {
|
||||
LLprims.recv(thisOutput + offset, nelem);
|
||||
} else {
|
||||
LLprims.recvCopySend(thisOutput + offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -14,14 +14,6 @@
|
||||
#define COLL_UNROLL 2
|
||||
#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY
|
||||
|
||||
// Exit If Abort Barrier across CTA: make sure all threads exit consistently
|
||||
// Each thread sets a predicate to true if abort == 1
|
||||
// all CTA's threads enter the barrier and do a popc on their predicates being True
|
||||
// If any of the thread's predicate was True, all the threads call exit()
|
||||
#define exitIfAbortBarrier(abort, abortCount) \
|
||||
if (abort) __atomic_fetch_add(abortCount, 1, __ATOMIC_SEQ_CST); \
|
||||
__syncthreads(); \
|
||||
if (LOAD(abortCount)) { /*asm volatile ("s_endpgm");*/ return false; }
|
||||
#define __syncwarp()
|
||||
|
||||
#define NCCL_FUNC5(func, algo, redop, type) \
|
||||
@@ -63,11 +55,13 @@
|
||||
NCCL_FUNCS3A(func, Sum ), \
|
||||
NCCL_FUNCS3A(func, Prod), \
|
||||
NCCL_FUNCS3A(func, Max ), \
|
||||
NCCL_FUNCS3A(func, Min )
|
||||
NCCL_FUNCS3A(func, Min ), \
|
||||
NCCL_FUNCS3A(func, Avg)
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
// [RCCL] Adding clique-based kernels for AllReduce, in-place of unused RingLL28 kernels
|
||||
@@ -97,7 +91,8 @@
|
||||
NCCL_FUNCS3C(func, Sum ), \
|
||||
NCCL_FUNCS3C(func, Prod), \
|
||||
NCCL_FUNCS3C(func, Max ), \
|
||||
NCCL_FUNCS3C(func, Min )
|
||||
NCCL_FUNCS3C(func, Min ), \
|
||||
NCCL_FUNCS3A(func, Avg)
|
||||
|
||||
// Must be consistent with ncclFunc_t
|
||||
#define NCCL_FUNCS() { \
|
||||
@@ -121,7 +116,7 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
|
||||
NCCL_FUNCS2A(Reduce),
|
||||
NCCL_FUNCS2B(AllGather),
|
||||
NCCL_FUNCS2A(ReduceScatter),
|
||||
NCCL_FUNCS2C(AllReduce),
|
||||
NCCL_FUNCS2A(AllReduce),
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
|
||||
#endif
|
||||
};
|
||||
@@ -143,12 +138,12 @@ struct Caller<f, f + 1>{
|
||||
void call(struct ncclWorkElem* const c) noexcept { ncclFuncs[f](c); }
|
||||
};
|
||||
|
||||
static_assert(FUNC_INDEX_P2P == 1800, "Wrong P2P function index");
|
||||
static_assert(FUNC_INDEX_P2P == 2250, "Wrong P2P function index");
|
||||
|
||||
inline
|
||||
__device__
|
||||
void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
if (c->funcIndex < 360) {
|
||||
if (c->funcIndex < 450) {
|
||||
if (c->funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(c);
|
||||
@@ -159,8 +154,8 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
else if (c->funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t(c);
|
||||
else ncclFunction_Broadcast_COLLNET_SIMPLE_Sum_int8_t(c);
|
||||
}
|
||||
else if (c->funcIndex < 720) Caller<360, 720>::call(c);
|
||||
else if (c->funcIndex < 1080) {
|
||||
else if (c->funcIndex < 900) Caller<450, 900>::call(c);
|
||||
else if (c->funcIndex < 1350) {
|
||||
if (c->funcIndex % 9 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(c);
|
||||
@@ -171,25 +166,10 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
else if (c->funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t(c);
|
||||
else ncclFunction_AllGather_COLLNET_SIMPLE_Sum_int8_t(c);
|
||||
}
|
||||
else if (c->funcIndex < 1800) Caller<1080, 1800>::call(c);
|
||||
else if (c->funcIndex < 2250) Caller<1350, 2250>::call(c);
|
||||
else ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c);
|
||||
}
|
||||
|
||||
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
|
||||
int* d = (int*)dst;
|
||||
int* s = (int*)src;
|
||||
for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
|
||||
}
|
||||
|
||||
static __device__ bool load_coll(struct ncclWork* localWork, struct ncclWork *hostWork, struct ncclWork* workFifo, int tid, struct ncclDevComm* comm, uint32_t* abortCount) {
|
||||
load_parallel(localWork, workFifo, sizeof(struct ncclWork), tid);
|
||||
// Check whether the last operation was aborted and make sure all threads exit
|
||||
int abort = tid == 0 ? LOAD(comm->abortFlag) : 0;
|
||||
exitIfAbortBarrier(abort, abortCount);
|
||||
if (tid == 0) hostWork->elems[0].active = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction {
|
||||
public:
|
||||
@@ -198,42 +178,42 @@ class ncclFunction {
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define traceColl(fIdx) \
|
||||
uint32_t pos = __atomic_fetch_add(comm->collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
|
||||
comm->collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
|
||||
comm->collTrace[pos].opCount = w->op.opCount; \
|
||||
comm->collTrace[pos].bid = bid; \
|
||||
comm->collTrace[pos].funcIndex = fIdx; \
|
||||
uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
|
||||
shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
|
||||
shmem.comm.collTrace[pos].opCount = elems[0].op.opCount; \
|
||||
shmem.comm.collTrace[pos].bid = bid; \
|
||||
shmem.comm.collTrace[pos].funcIndex = fIdx; \
|
||||
if (fIdx == FUNC_INDEX_P2P) { \
|
||||
comm->collTrace[pos].p2p.nThreads = w->p2p.nThreads; \
|
||||
comm->collTrace[pos].p2p.delta = (uint16_t)(w->p2p.delta); \
|
||||
shmem.comm.collTrace[pos].p2p.nThreads = elems[0].p2p.nThreads; \
|
||||
shmem.comm.collTrace[pos].p2p.delta = (uint16_t)(elems[0].p2p.delta); \
|
||||
} else { \
|
||||
comm->collTrace[pos].coll.nThreads = w->nThreads; \
|
||||
comm->collTrace[pos].coll.bid = w->coll.bid; \
|
||||
comm->collTrace[pos].coll.nChannels = w->coll.nChannels; \
|
||||
shmem.comm.collTrace[pos].coll.nThreads = elems[0].nThreads; \
|
||||
shmem.comm.collTrace[pos].coll.bid = elems[0].coll.bid; \
|
||||
shmem.comm.collTrace[pos].coll.nChannels = elems[0].coll.nChannels; \
|
||||
}
|
||||
#define traceKernelLaunch(fIdx) { \
|
||||
traceColl(fIdx); \
|
||||
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (comm->collTrace[pos].data_0)); \
|
||||
comm->collTrace[pos].type = ncclCollTraceKernelLaunchType; \
|
||||
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceKernelLaunchType; \
|
||||
}
|
||||
#define traceCollEnd(fIdx) { \
|
||||
traceColl(fIdx); \
|
||||
comm->collTrace[pos].type = ncclCollTraceCollEndType; \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceCollEndType; \
|
||||
}
|
||||
#define traceAbort(fIdx) { \
|
||||
traceColl(fIdx); \
|
||||
comm->collTrace[pos].type = ncclCollTraceAbortType; \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \
|
||||
}
|
||||
// traceData(int16_t data2, uint32_t data4, uint64_t data8_0, uint64_t data8_1)
|
||||
#define traceData(data2, data4, data8_0, data8_1) { \
|
||||
uint32_t pos = __atomic_fetch_add(comm->collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
|
||||
comm->collTrace[pos].bid = blockIdx.x; \
|
||||
comm->collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
|
||||
comm->collTrace[pos].funcIndex = data2; \
|
||||
comm->collTrace[pos].data_0 = data4; \
|
||||
comm->collTrace[pos].opCount = data8_0; \
|
||||
comm->collTrace[pos].data_1 = data8_1; \
|
||||
comm->collTrace[pos].type = ncclCollTraceDataType; \
|
||||
uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
|
||||
shmem.comm.collTrace[pos].bid = blockIdx.x; \
|
||||
shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
|
||||
shmem.comm.collTrace[pos].funcIndex = data2; \
|
||||
shmem.comm.collTrace[pos].data_0 = data4; \
|
||||
shmem.comm.collTrace[pos].opCount = data8_0; \
|
||||
shmem.comm.collTrace[pos].data_1 = data8_1; \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceDataType; \
|
||||
}
|
||||
#else
|
||||
#define traceKernelLaunch(fIdx)
|
||||
@@ -242,9 +222,61 @@ class ncclFunction {
|
||||
#define traceData(data2, data4, data8_0, data8_1)
|
||||
#endif
|
||||
|
||||
#define MAXWARPS (NCCL_MAX_NTHREADS/WARP_SIZE)
|
||||
__device__ inline bool barrierReduceAny(int bit, uint32_t* abortCount) {
|
||||
if (bit) __atomic_fetch_add(abortCount, 1, __ATOMIC_SEQ_CST); \
|
||||
__syncthreads(); \
|
||||
return LOAD(abortCount) != 0;
|
||||
}
|
||||
|
||||
struct ncclShmemPtrs {
|
||||
template<typename T>
|
||||
__device__ int copyToShmem(T *dst, T const *src, int turn=0) {
|
||||
static_assert(sizeof(uint64_t) <= alignof(T), "Uhoh");
|
||||
uint64_t *d = reinterpret_cast<uint64_t*>(dst);
|
||||
uint64_t const *s = reinterpret_cast<uint64_t const*>(src);
|
||||
int t = threadIdx.x - turn;
|
||||
if (t < 0) t += blockDim.x;
|
||||
int n = sizeof(T)/sizeof(uint64_t);
|
||||
|
||||
int delta = (n + WARP_SIZE-1) & -WARP_SIZE; // round up to warp lane 0
|
||||
if (delta < blockDim.x) {
|
||||
turn += delta;
|
||||
if (turn >= blockDim.x) turn -= blockDim.x;
|
||||
}
|
||||
else
|
||||
turn = 0;
|
||||
|
||||
n -= t;
|
||||
d += t;
|
||||
s += t;
|
||||
#pragma unroll
|
||||
for (int i=0; i < divUp(sizeof(T), WARP_SIZE*sizeof(uint64_t)); i++) {
|
||||
if (n > 0) {
|
||||
*d = *s;
|
||||
d += blockDim.x;
|
||||
s += blockDim.x;
|
||||
n -= blockDim.x;
|
||||
}
|
||||
}
|
||||
return turn;
|
||||
}
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWorkElement {
|
||||
__device__ void run(ncclWorkElem*) {
|
||||
// Put NOT IMPLEMENTED behavior here.
|
||||
}
|
||||
};
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWork {
|
||||
__device__ void run(ncclWork *w) {
|
||||
}
|
||||
};
|
||||
|
||||
#define MAXWARPS (NCCL_MAX_NTHREADS/WARP_SIZE)
|
||||
struct ncclShmemGroup {
|
||||
ncclConnInfo *recvConns[NCCL_MAX_DIRECT_ARITY];
|
||||
ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY];
|
||||
void* srcs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
void* dsts[NCCL_MAX_DIRECT_ARITY+1];
|
||||
uint64_t barrier;
|
||||
@@ -253,20 +285,19 @@ struct ncclShmemPtrs {
|
||||
|
||||
struct ncclShmemData {
|
||||
union {
|
||||
#ifdef ENABLE_LL128
|
||||
volatile uint64_t data[NCCL_LL128_SHMEM_SIZE];
|
||||
#else
|
||||
volatile uint64_t* data;
|
||||
#endif
|
||||
struct ncclShmemPtrs ptrs[NCCL_MAX_GROUPS];
|
||||
uint64_t ll128warp[NCCL_MAX_GROUPS][NCCL_MAX_GROUPS];
|
||||
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
|
||||
};
|
||||
uint32_t sync[MAXWARPS];
|
||||
struct ncclWork localWork;
|
||||
ncclDevComm comm;
|
||||
ncclChannel channel;
|
||||
ncclWork work;
|
||||
};
|
||||
|
||||
extern __device__ struct ncclShmemData *ncclShmem;
|
||||
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL, int FINDEX, bool COLLTRACE>
|
||||
__device__ void ncclKernel(struct ncclWorkElem first) {
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int FnIndex, bool COLLTRACE>
|
||||
__device__ void ncclKernel(ncclWorkElem first) {
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
__shared__ struct ncclShmemData shmem;
|
||||
@@ -275,51 +306,63 @@ __device__ void ncclKernel(struct ncclWorkElem first) {
|
||||
if (tid == 0) {
|
||||
abortCount = 0;
|
||||
for (auto i = 0; i < NCCL_MAX_GROUPS; i++) {
|
||||
shmem.ptrs[i].barrier = 0;
|
||||
for (auto j = 0; j < MAXWARPS; j++) shmem.ptrs[i].barrier_next[j] = 0;
|
||||
shmem.groups[i].barrier = 0;
|
||||
for (auto j = 0; j < MAXWARPS; j++) shmem.groups[i].barrier_next[j] = 0;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto f = ncclFunction<FUNCTION, ALGO, PROTO, REDOP, T, UNROLL>();
|
||||
int turn = copyToShmem(&shmem.comm, first.comm);
|
||||
// get address of channel without incurring indirect load from ncclDevCom::channels
|
||||
ncclChannel *channel = &((ncclDevCommAndChannels*)first.comm)->channels[bid];
|
||||
turn = copyToShmem(&shmem.channel, channel, turn);
|
||||
|
||||
struct ncclDevComm* comm = first.comm;
|
||||
struct ncclChannel* channel = comm->channels+bid;
|
||||
struct ncclWorkElem* w = NULL;
|
||||
// To optimize for latency, (only) the first operation is passed as argument.
|
||||
struct ncclWorkElem* elems = NULL;
|
||||
bool firstLaunch = true;
|
||||
if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) elems = &first;
|
||||
|
||||
if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) w = &first;
|
||||
ncclWork *workFifoHost = channel->workFifo;
|
||||
ncclWork *workFifoDev = channel->workFifoDev;
|
||||
int workFifoIx = channel->index;
|
||||
|
||||
while (1) {
|
||||
if (w == NULL) {
|
||||
w = shmem.localWork.elems;
|
||||
if (elems == NULL) {
|
||||
elems = shmem.work.elems;
|
||||
__syncthreads();
|
||||
if (!load_coll(&shmem.localWork, channel->workFifo+channel->index, channel->workFifoDev+channel->index, tid, comm, &abortCount)) {
|
||||
if (COLLTRACE && tid == 0) traceAbort(0xffff);
|
||||
return;
|
||||
copyToShmem(&shmem.work, &workFifoDev[workFifoIx]);
|
||||
{ // Check whether the last operation was aborted and make sure all threads exit
|
||||
int aborted = tid == 0 ? *shmem.comm.abortFlag : 0;
|
||||
if (barrierReduceAny(aborted, &abortCount)) { // publish ncclShmem->work
|
||||
if (COLLTRACE && tid == 0) traceAbort(0xffff);
|
||||
break;
|
||||
}
|
||||
if (tid == 0)
|
||||
workFifoHost[workFifoIx].elems[0].active = 0;
|
||||
}
|
||||
if (COLLTRACE && tid == 0) {
|
||||
if (firstLaunch) traceKernelLaunch(w->funcIndex);
|
||||
if (!firstLaunch) traceCollEnd(w->funcIndex);
|
||||
if (firstLaunch) traceKernelLaunch(elems->funcIndex);
|
||||
if (!firstLaunch) traceCollEnd(elems->funcIndex);
|
||||
firstLaunch = false;
|
||||
}
|
||||
} else if (COLLTRACE && tid == 0) {
|
||||
traceKernelLaunch(w->funcIndex);
|
||||
traceKernelLaunch(elems->funcIndex);
|
||||
firstLaunch = false;
|
||||
}
|
||||
if (tid < w->nThreads) {
|
||||
if (w->funcIndex == FINDEX) {
|
||||
f.run(w);
|
||||
} else {
|
||||
NCCL_CALL_FUNCTIONS(w);
|
||||
}
|
||||
workFifoIx = (workFifoIx + 1)%NCCL_MAX_OPS;
|
||||
if (tid == 0)
|
||||
channel->index = workFifoIx; // write back to real channel, not shmem shadow
|
||||
if (elems->funcIndex == FnIndex) {
|
||||
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
|
||||
} else {
|
||||
if (tid < elems->nThreads && elems->active != 0)
|
||||
NCCL_CALL_FUNCTIONS(elems);
|
||||
}
|
||||
if (tid == 0) channel->index = (channel->index+1) % NCCL_MAX_OPS;
|
||||
if (w->active == 2) {
|
||||
if (elems->active == 2) {
|
||||
if (COLLTRACE && tid == 0) traceCollEnd(0xffff);
|
||||
return;
|
||||
}
|
||||
w = NULL;
|
||||
elems = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,17 +370,16 @@ __device__ void ncclKernel(struct ncclWorkElem first) {
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \
|
||||
if (first.comm->collTraceThread) \
|
||||
ncclKernel<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL, fIndex, true>(first); \
|
||||
ncclKernel<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(first); \
|
||||
else \
|
||||
ncclKernel<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL, fIndex, false>(first); \
|
||||
ncclKernel<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(first); \
|
||||
}
|
||||
|
||||
// Examples : AllReduce, RING, LL, Sum, uint8
|
||||
/* Functions for aggregation case */
|
||||
#define IMPL_COLL_FUNC(func, algo, proto, redop, type) \
|
||||
__device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \
|
||||
auto f = ncclFunction<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL>(); \
|
||||
f.run(args); \
|
||||
RunWorkElement<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(args); \
|
||||
}
|
||||
|
||||
// Only generate inline kernels for LL
|
||||
@@ -367,7 +409,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, red
|
||||
IMPL_COLL2(func, Sum) \
|
||||
IMPL_COLL2(func, Prod) \
|
||||
IMPL_COLL2(func, Min) \
|
||||
IMPL_COLL2(func, Max)
|
||||
IMPL_COLL2(func, Max) \
|
||||
IMPL_COLL2(func, Avg)
|
||||
|
||||
// [RCCL] Define clique-based implementations (repurposed LL128)
|
||||
#define IMPL_COLL4_CLIQUE(func, algo, redop, type, ncclType) \
|
||||
@@ -396,7 +439,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, red
|
||||
IMPL_COLL2_CLIQUE(func, Sum) \
|
||||
IMPL_COLL2_CLIQUE(func, Prod) \
|
||||
IMPL_COLL2_CLIQUE(func, Min) \
|
||||
IMPL_COLL2_CLIQUE(func, Max)
|
||||
IMPL_COLL2_CLIQUE(func, Max) \
|
||||
IMPL_COLL2_CLIQUE(func, Avg)
|
||||
// [/RCCL]
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
@@ -17,24 +17,26 @@
|
||||
// Define min for ssize_t
|
||||
static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; }
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void loadPtr(void** ptr, T* &v) {
|
||||
v = LOAD(ptr);
|
||||
}
|
||||
|
||||
typedef uint64_t PackType;
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
|
||||
template<class FUNC, typename T>
|
||||
struct MULTI {
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const
|
||||
{
|
||||
return FUNC()(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
template<typename Fn>
|
||||
struct FuncTraits /*{
|
||||
__device__ static Fn make();
|
||||
__device__ static T preOp(Fn, T);
|
||||
__device__ static T postOp(Fn, T);
|
||||
}*/;
|
||||
|
||||
// unpack x and y to elements of type T and apply FUNC to each element
|
||||
template<class FUNC, typename T>
|
||||
struct MULTI {
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const;
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const;
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const;
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const;
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@@ -48,17 +50,39 @@ struct MULTI<FUNC, int8_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
// for char, we do these as vector ops
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll 1
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().preOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll 1
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().postOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@@ -72,17 +96,39 @@ struct MULTI<FUNC, uint8_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
// for char, we do these as vector ops
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll 1
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().preOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint8_t elt[8];
|
||||
} u;
|
||||
u.pack = x;
|
||||
#pragma unroll 1
|
||||
for (int i=0; i < 8; i++)
|
||||
u.elt[i] = FuncTraits<FUNC>().postOp(fn, u.elt[i]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@@ -96,16 +142,36 @@ struct MULTI<FUNC, int32_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().preOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().preOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().postOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().postOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@@ -119,16 +185,36 @@ struct MULTI<FUNC, uint32_t> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().preOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().preOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint32_t elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().postOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().postOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
@@ -136,22 +222,75 @@ struct MULTI<FUNC, half> {
|
||||
static_assert(sizeof(PackType) == 4 * sizeof(half),
|
||||
"PackType must be four times the size of half.");
|
||||
|
||||
struct PackHalf2 {
|
||||
half2 a, b;
|
||||
union Converter {
|
||||
PackType pack;
|
||||
half2 h2[2];
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
struct PackHalf2 cx, cy, cr;
|
||||
cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
|
||||
cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
|
||||
return *(reinterpret_cast<PackType*>(&cr));
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
Converter cx, cy, cr;
|
||||
cx.pack = x;
|
||||
cy.pack = y;
|
||||
cr.h2[0] = fn(cx.h2[0], cy.h2[0]);
|
||||
cr.h2[1] = fn(cx.h2[1], cy.h2[1]);
|
||||
return cr.pack;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().preOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().preOp(fn, c.h2[1]);
|
||||
return c.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().postOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().postOp(fn, c.h2[1]);
|
||||
return c.pack;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, rccl_bfloat16> {
|
||||
static_assert(sizeof(PackType) == 4 * sizeof(rccl_bfloat16),
|
||||
"PackType must be four times the size of rccl_bfloat16.");
|
||||
|
||||
union Converter {
|
||||
PackType pack;
|
||||
rccl_bfloat16 h2[4];
|
||||
};
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
Converter cx, cy, cr;
|
||||
cx.pack = x;
|
||||
cy.pack = y;
|
||||
cr.h2[0] = fn(cx.h2[0], cy.h2[0]);
|
||||
cr.h2[1] = fn(cx.h2[1], cy.h2[1]);
|
||||
cr.h2[2] = fn(cx.h2[2], cy.h2[2]);
|
||||
cr.h2[3] = fn(cx.h2[3], cy.h2[3]);
|
||||
return cr.pack;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().preOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().preOp(fn, c.h2[1]);
|
||||
c.h2[2] = FuncTraits<FUNC>().preOp(fn, c.h2[2]);
|
||||
c.h2[3] = FuncTraits<FUNC>().preOp(fn, c.h2[3]);
|
||||
return c.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
Converter c;
|
||||
c.pack = x;
|
||||
c.h2[0] = FuncTraits<FUNC>().postOp(fn, c.h2[0]);
|
||||
c.h2[1] = FuncTraits<FUNC>().postOp(fn, c.h2[1]);
|
||||
c.h2[2] = FuncTraits<FUNC>().postOp(fn, c.h2[2]);
|
||||
c.h2[3] = FuncTraits<FUNC>().postOp(fn, c.h2[3]);
|
||||
return c.pack;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, float> {
|
||||
static_assert(sizeof(PackType) == 2 * sizeof(float),
|
||||
@@ -163,50 +302,122 @@ struct MULTI<FUNC, float> {
|
||||
};
|
||||
};
|
||||
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
|
||||
cr.a = FUNC()(cx.a, cy.a);
|
||||
cr.b = FUNC()(cx.b, cy.b);
|
||||
cr.a = fn(cx.a, cy.a);
|
||||
cr.b = fn(cx.b, cy.b);
|
||||
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
float elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().preOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().preOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
float elt[2];
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt[0] = FuncTraits<FUNC>().postOp(fn, u.elt[0]);
|
||||
u.elt[1] = FuncTraits<FUNC>().postOp(fn, u.elt[1]);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, double> {
|
||||
static_assert(sizeof(PackType) == sizeof(double),
|
||||
"PackType must be the same size as double.");
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
double rv = FUNC()(__longlong_as_double(x), __longlong_as_double(y));
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
double rv = fn(__longlong_as_double(x), __longlong_as_double(y));
|
||||
return __double_as_longlong(rv);
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
double elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().preOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
double elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().postOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, uint64_t> {
|
||||
static_assert(sizeof(PackType) == sizeof(uint64_t),
|
||||
"PackType must be the same size as uint64_t.");
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
uint64_t rv = FUNC()(x, y);
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
uint64_t rv = fn(x, y);
|
||||
return rv;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().preOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
uint64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().postOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC>
|
||||
struct MULTI<FUNC, int64_t> {
|
||||
static_assert(sizeof(PackType) == sizeof(int64_t),
|
||||
"PackType must be the same size as int64_t.");
|
||||
__device__ PackType operator()(const PackType x, const PackType y) const {
|
||||
int64_t rv = FUNC()((int64_t)x, (int64_t)y);
|
||||
__device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const {
|
||||
int64_t rv = fn((int64_t)x, (int64_t)y);
|
||||
return rv;
|
||||
}
|
||||
__device__ PackType preOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().preOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
__device__ PackType postOp(FUNC fn, PackType x) const {
|
||||
union {
|
||||
PackType pack;
|
||||
int64_t elt;
|
||||
} u;
|
||||
u.pack = x;
|
||||
u.elt = FuncTraits<FUNC>().postOp(fn, u.elt);
|
||||
return u.pack;
|
||||
}
|
||||
};
|
||||
|
||||
#endif //defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
|
||||
template<typename T> inline __device__
|
||||
T vFetch(const volatile T* ptr) {
|
||||
return *ptr;
|
||||
@@ -259,9 +470,17 @@ typedef ulong2 Pack128;
|
||||
|
||||
template<class FUNC, typename T>
|
||||
struct MULTI128 {
|
||||
__device__ void operator()(Pack128& x, Pack128& y) {
|
||||
x.x = MULTI<FUNC, T>()(x.x, y.x);
|
||||
x.y = MULTI<FUNC, T>()(x.y, y.y);
|
||||
__device__ void operator()(FUNC fn, Pack128& x, Pack128 const& y) const {
|
||||
x.x = MULTI<FUNC, T>()(fn, x.x, y.x);
|
||||
x.y = MULTI<FUNC, T>()(fn, x.y, y.y);
|
||||
}
|
||||
__device__ void preOp(FUNC fn, Pack128 &x) const {
|
||||
x.x = MULTI<FUNC, T>().preOp(fn, x.x);
|
||||
x.y = MULTI<FUNC, T>().preOp(fn, x.y);
|
||||
}
|
||||
__device__ void postOp(FUNC fn, Pack128 &x) const {
|
||||
x.x = MULTI<FUNC, T>().postOp(fn, x.x);
|
||||
x.y = MULTI<FUNC, T>().postOp(fn, x.y);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -284,7 +503,8 @@ inline __device__ void Store128(Pack128* p, Pack128& v) {
|
||||
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
|
||||
int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
|
||||
FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem
|
||||
) {
|
||||
const int inc = nw * UNROLL * WARP_SIZE;
|
||||
int offset = w * UNROLL * WARP_SIZE + t;
|
||||
|
||||
@@ -297,22 +517,30 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
T vals[UNROLL];
|
||||
// Load and reduce
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE);
|
||||
if (preOpSrc0) {
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits<FUNC>().preOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i=1; i<MINSRCS; i++) {
|
||||
T vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i=MINSRCS; i<MAXSRCS; i++) {
|
||||
if (i<nsrcs) {
|
||||
T vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits<FUNC>().postOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
// Store
|
||||
#pragma unroll
|
||||
for (int i = 0; i < MINDSTS; i++) {
|
||||
@@ -331,8 +559,9 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
}
|
||||
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
__device__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
|
||||
int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
|
||||
__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
|
||||
FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack
|
||||
) {
|
||||
const int inc = nw * UNROLL * WARP_SIZE;
|
||||
int offset = w * UNROLL * WARP_SIZE + t;
|
||||
|
||||
@@ -345,20 +574,32 @@ __device__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
|
||||
Pack128 vals[UNROLL];
|
||||
// Load and reduce
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
|
||||
if (preOpSrc0) {
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().preOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
#pragma unroll 1
|
||||
for (int i=1; i<MINSRCS; i++) {
|
||||
Pack128 vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(fn, vals[u], vals2[u]);
|
||||
}
|
||||
#pragma unroll 1
|
||||
for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
|
||||
Pack128 vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
|
||||
for (int i=MINSRCS; i<MAXSRCS; i++) {
|
||||
if (i<nsrcs) {
|
||||
Pack128 vals2[UNROLL];
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(fn, vals[u], vals2[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
#pragma unroll 1
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().postOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
// Store
|
||||
#pragma unroll 1
|
||||
for (int i = 0; i < MINDSTS; i++) {
|
||||
for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
|
||||
}
|
||||
@@ -375,23 +616,15 @@ __device__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(int32_t); }
|
||||
__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(uint32_t); }
|
||||
|
||||
#define PACKELEMS (sizeof(Pack128) / sizeof(T))
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
// Multiply UNROLL by 2 if single source/single destination
|
||||
#define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1))
|
||||
#else
|
||||
// Try to limit consecutive load/stores to 8.
|
||||
// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
|
||||
#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
|
||||
#endif
|
||||
|
||||
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
|
||||
int nsrcs, const T** srcs, int ndsts, T** dsts,
|
||||
int N) {
|
||||
__device__ __forceinline__ void ReduceOrCopyMulti(
|
||||
const int tid, const int nthreads, FUNC fn, bool preOpSrc0, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, int N
|
||||
) {
|
||||
int Nrem = N;
|
||||
if (Nrem <= 0) return;
|
||||
|
||||
@@ -417,7 +650,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre
|
||||
int Npack = (Nrem / (PACKELEMS*AUTOUNROLL*WARP_SIZE)) * (AUTOUNROLL*WARP_SIZE); // round down
|
||||
int Nelem = Npack * PACKELEMS;
|
||||
|
||||
ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
@@ -427,7 +661,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre
|
||||
Npack = Nrem / PACKELEMS;
|
||||
Nelem = Npack * PACKELEMS;
|
||||
|
||||
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
@@ -437,14 +672,16 @@ __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthre
|
||||
// unrolled, by-type (mostly for unaligned buffers)
|
||||
int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down
|
||||
|
||||
ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem);
|
||||
ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
offset += Nelem;
|
||||
|
||||
// no unroll, by type. Should finish what's remaining.
|
||||
ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem);
|
||||
ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, preOpSrc0, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem);
|
||||
}
|
||||
|
||||
#endif // COMMON_KERNEL_H_
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# See LICENSE.txt for license information
|
||||
#
|
||||
|
||||
dir=$1
|
||||
|
||||
datatypes="i8 u8 i32 u32 i64 u64 f16 f32 f64"
|
||||
if [ "$CUDA_MAJOR" -ge 11 ]
|
||||
then
|
||||
datatypes+=" bf16"
|
||||
fi
|
||||
|
||||
targets="GENOBJS := \\\\\n"
|
||||
|
||||
for base in sendrecv all_reduce all_gather broadcast reduce reduce_scatter; do
|
||||
opn=0
|
||||
for op in sum prod min max; do
|
||||
for op in sum prod min max avg; do
|
||||
dtn=0
|
||||
for dt in i8 u8 i32 u32 i64 u64 f16 f32 f64; do
|
||||
# Order must match that of the ncclDataType_t enum
|
||||
for dt in ${datatypes}; do
|
||||
echo "${dir}/${base}_${op}_${dt}.o : ${base}.cu ${dir}/${base}.dep"
|
||||
echo " @printf \"Compiling %-35s > %s\\\\n\" ${base}.cu ${dir}/${base}_${op}_${dt}.o"
|
||||
echo " mkdir -p ${dir}"
|
||||
|
||||
@@ -10,7 +10,23 @@
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) {
|
||||
return (uint64_t*)shmemGenericPtr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
|
||||
}
|
||||
|
||||
inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) {
|
||||
}
|
||||
|
||||
inline __device__ void loadShmem128(uint64_t* shmemAsmPtr, uint64_t &v0, uint64_t &v1) {
|
||||
}
|
||||
|
||||
inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_t v1) {
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1) {
|
||||
}
|
||||
#else
|
||||
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
|
||||
@@ -38,6 +54,38 @@ inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_
|
||||
asm volatile("st.volatile.shared.v2.u64 [%2], {%0,%1};"
|
||||
:: "l"(v0), "l"(v1), "l"(shmemAsmPtr));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1) {
|
||||
union {
|
||||
uint32_t tmp4[4];
|
||||
uint64_t tmp8[2];
|
||||
};
|
||||
if(sizeof(T) < 4) {
|
||||
uint32_t *ptr4 = reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(ptr) & -uintptr_t(4));
|
||||
#pragma unroll
|
||||
for(int e=0; e < 4; e++) {
|
||||
// Produce 4 bytes of sub-register type by reading 2 4-byte
|
||||
// aligned values and shifting.
|
||||
uint32_t lo, hi;
|
||||
asm("ld.shared.b32 %0,[%1];" : "=r"(lo) : "l"(ptr4+e+0));
|
||||
asm("ld.shared.b32 %0,[%1];" : "=r"(hi) : "l"(ptr4+e+1));
|
||||
tmp4[e] = __funnelshift_r(lo, hi, 8*(int(reinterpret_cast<uintptr_t>(ptr))%4));
|
||||
}
|
||||
}
|
||||
else if(sizeof(T) == 4) {
|
||||
#pragma unroll
|
||||
for(int e=0; e < 4; e++)
|
||||
asm("ld.shared.b32 %0,[%1];" : "=r"(tmp4[e]) : "l"(ptr+e));
|
||||
}
|
||||
else /*sizeof(T)==8*/ {
|
||||
#pragma unroll
|
||||
for(int e=0; e < 2; e++)
|
||||
asm("ld.shared.b64 %0,[%1];" : "=l"(tmp8[e]) : "l"(ptr+e));
|
||||
}
|
||||
v0 = tmp8[0];
|
||||
v1 = tmp8[1];
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -12,25 +12,7 @@
|
||||
#include "reduce_kernel.h" // for reduction funcs
|
||||
#include "common.h"
|
||||
|
||||
#define SPINS_BEFORE_CHECK_ABORT 1000000
|
||||
|
||||
// Unroll unconditionally the first send/recv since nsend/nrecv should be at
|
||||
// least 1 if SEND/RECV is set.
|
||||
#define FOR_SEND(func, ...) do { \
|
||||
if (SEND) { \
|
||||
/* Send to far first, then close */ \
|
||||
for (int i=1; i<NSEND && i<nsend; i++) func(i, ##__VA_ARGS__); \
|
||||
func(0, ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define FOR_RECV(func, ...) do { \
|
||||
if (RECV) { \
|
||||
/* Recv from close first, then far */ \
|
||||
func(0, ##__VA_ARGS__); \
|
||||
for (int i=1; i<NRECV && i<nrecv; i++) func(i, ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000
|
||||
|
||||
#define barrier_by_group() do { \
|
||||
const int w = threadIdx.x/WARP_SIZE; \
|
||||
@@ -42,389 +24,132 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ROLE_SRC 0x01
|
||||
#define ROLE_DST 0x02
|
||||
#define ROLE_WAIT_RECV 0x04
|
||||
#define ROLE_WAIT_SEND 0x08
|
||||
#define ROLE_POST_SEND 0x10
|
||||
#define ROLE_POST_RECV 0x20
|
||||
/* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128
|
||||
* We use these as template args to the Primtiives class instead of integral
|
||||
* enums (e.g. NCCL_PROTO_LL) because for SIMPLE we need to carry a few extra
|
||||
* numbers. Also these types hold methods which let us compute numbers important
|
||||
* to how that protocol operates with a consistent interface so that our
|
||||
* algorithm code can operate protocol parametrically.
|
||||
*/
|
||||
template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1 = COLL_UNROLL>
|
||||
struct ProtoSimple {
|
||||
static constexpr int Id = NCCL_PROTO_SIMPLE;
|
||||
static constexpr int SlicePerChunk = SlicePerChunk_1;
|
||||
static constexpr int StepPerSlice = StepPerSlice_1;
|
||||
static constexpr int Unroll = Unroll_1;
|
||||
|
||||
// Connection index is used to select P2P and NET and needs to be passed into ncclPrimitives constructor.
|
||||
// To avoid adding another parameter which requires changes to every places ncclPrimitives are constructed,
|
||||
// we pack group (max 7) and connection index (max 2) to original group which is 32-bit.
|
||||
#define PACK_GROUP(gr, idx) (gr | (idx<<16))
|
||||
#define TO_GR(group) (group&0xffff)
|
||||
#define TO_IDX(group) (group>>16)
|
||||
|
||||
// Implementation of primitive types
|
||||
template <int UNROLL, int SLICESPERCHUNK, int SLICESTEPS, typename T, int NRECV, int NSEND, int DIRECT, class FUNC>
|
||||
class ncclPrimitives {
|
||||
private:
|
||||
const int tid;
|
||||
int nthreads;
|
||||
int nworkers;
|
||||
const int stepSize;
|
||||
int nrecv = 0;
|
||||
int nsend = 0;
|
||||
struct ncclConnInfo* conn = NULL;
|
||||
volatile int* connSizesFifoPtr = NULL;
|
||||
void** connPtrsFifoPtr = NULL;
|
||||
volatile uint64_t* connHeadPtr = NULL;
|
||||
volatile uint64_t* connTailPtr = NULL;
|
||||
uint64_t connTailCache; // Cache last seen value
|
||||
uint64_t connHeadCache; // Cache last seen value
|
||||
|
||||
int index; // Peer index I'm responsible for
|
||||
int peer = -1;
|
||||
int role = 0;
|
||||
int group;
|
||||
uint64_t step;
|
||||
T* direct = NULL;
|
||||
T* buff;
|
||||
struct ncclDevComm* comm;
|
||||
const int connIndex;
|
||||
|
||||
const T** srcs;
|
||||
T** dsts;
|
||||
|
||||
uint64_t* barriers;
|
||||
uint64_t* barrier_next;
|
||||
// Don't use barrier 0 as it's used by the final sync
|
||||
inline __device__ void barrier() {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if (nthreads == WARP_SIZE) __syncwarp();
|
||||
else barrier_by_group();
|
||||
#else
|
||||
if (nthreads == WARP_SIZE) __syncwarp();
|
||||
else asm volatile ("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads));
|
||||
#endif
|
||||
// Data bytes (no flags etc) in one step of the fifo queue.
|
||||
__device__ static int calcBytePerStep() {
|
||||
return ncclShmem->comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
|
||||
}
|
||||
|
||||
inline __device__ void subBarrier() {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
barrier();
|
||||
#else
|
||||
if (nworkers == nthreads) barrier();
|
||||
else asm volatile ("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers));
|
||||
#endif
|
||||
// Granularity of data bytes transferred per thread.
|
||||
__device__ static int calcBytePerGrain() {
|
||||
return sizeof(uint64_t); // Bogus value? Nobody queries this metric for simple.
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
inline __device__ int checkAbort() {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = LOAD(comm->abortFlag);
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
}
|
||||
|
||||
template <int DIRECTPTR>
|
||||
inline __device__ T* directPtr(ssize_t directOffset) {
|
||||
return DIRECTPTR && direct ? direct+directOffset : buff+(step%NCCL_STEPS)*stepSize;
|
||||
}
|
||||
|
||||
template <int DST, int DIRECTSEND>
|
||||
inline __device__ void waitSend(ssize_t directOffset, int nbytes) {
|
||||
spins = 0;
|
||||
while (connHeadCache + NCCL_STEPS < step + SLICESTEPS) {
|
||||
connHeadCache = LOAD(connHeadPtr);
|
||||
if (checkAbort()) break;
|
||||
}
|
||||
if (connSizesFifoPtr) {
|
||||
STORE(connSizesFifoPtr+step%NCCL_STEPS, nbytes);
|
||||
}
|
||||
|
||||
if (connPtrsFifoPtr) dsts[DST+index] = (T *)LOAD(connPtrsFifoPtr+step%NCCL_STEPS);
|
||||
else dsts[DST+index] = directPtr<DIRECTSEND>(directOffset);
|
||||
step += SLICESTEPS;
|
||||
}
|
||||
|
||||
template <int SRC, int DIRECTRECV>
|
||||
inline __device__ void waitRecv(ssize_t directOffset) {
|
||||
spins = 0;
|
||||
#ifdef ENABLE_PROFILING
|
||||
uint64_t t0 = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
while (connTailCache < step + SLICESTEPS) {
|
||||
connTailCache = LOAD(connTailPtr);
|
||||
if (checkAbort()) break;
|
||||
}
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0) __atomic_fetch_add(&comm->devProf->wait_recv_cycle[blockIdx.x], __builtin_amdgcn_s_memrealtime() - t0, __ATOMIC_SEQ_CST);
|
||||
#endif
|
||||
if (connPtrsFifoPtr) srcs[SRC+index] = (const T *)LOAD(connPtrsFifoPtr+step%NCCL_STEPS);
|
||||
else srcs[SRC+index] = directPtr<DIRECTRECV>(directOffset);
|
||||
step += SLICESTEPS;
|
||||
}
|
||||
|
||||
inline __device__ void postRecv() {
|
||||
STORE(connHeadPtr, step += SLICESTEPS);
|
||||
}
|
||||
|
||||
inline __device__ void postSend() {
|
||||
if (conn->next_hdp_reg) STORE(conn->next_hdp_reg, 0x1);
|
||||
STORE(connTailPtr, step += SLICESTEPS);
|
||||
}
|
||||
|
||||
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
|
||||
inline __device__ void
|
||||
GenericOp(const T* srcPtr, T* dstPtr, int nelem, ssize_t directOffset) {
|
||||
int offset = 0;
|
||||
int sliceSize = stepSize*SLICESTEPS;
|
||||
int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32);
|
||||
|
||||
#pragma unroll
|
||||
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
|
||||
int realSize = max(0, min(dataSize, nelem-offset));
|
||||
#ifdef ENABLE_PROFILING
|
||||
uint64_t t0 = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
if (tid < nworkers) {
|
||||
if (SRC && (role & ROLE_SRC)) srcs[0] = srcPtr+offset;
|
||||
if (RECV && (role & ROLE_WAIT_RECV)) waitRecv<SRC, DIRECTRECV>(directOffset+offset);
|
||||
if (DST && (role & ROLE_DST)) dsts[0] = dstPtr+offset;
|
||||
if (SEND && (role & ROLE_WAIT_SEND)) waitSend<DST, DIRECTSEND>(directOffset+offset, realSize*sizeof(T));
|
||||
if (realSize > 0) {
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0) __atomic_fetch_add(&comm->devProf->wait_cycle[blockIdx.x], __builtin_amdgcn_s_memrealtime() - t0, __ATOMIC_SEQ_CST);
|
||||
#endif
|
||||
subBarrier();
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nworkers, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
if (SEND && (role & ROLE_POST_SEND)) postSend();
|
||||
if (RECV && (role & ROLE_POST_RECV)) postRecv();
|
||||
offset += realSize;
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter and gather do not support DIRECT
|
||||
template <int RECV, int SEND>
|
||||
inline __device__ void
|
||||
ScatterGatherOp(const T* srcPtr, T* dstPtr, int totalElem, int peerElem, int skip, int shift) {
|
||||
int offset = 0; // slice offset
|
||||
int sliceSize = stepSize*SLICESTEPS;
|
||||
int dataSize = max(DIVUP(peerElem, 16*SLICESPERCHUNK)*16, sliceSize/32); // per-peer slice size
|
||||
|
||||
#pragma unroll
|
||||
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
|
||||
int realSize = max(0, min(dataSize, peerElem-offset));
|
||||
if (tid < nworkers) {
|
||||
if (RECV && (role & ROLE_WAIT_RECV)) waitRecv<0, 0>(0);
|
||||
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
|
||||
if (SEND && (role & ROLE_WAIT_SEND)) waitSend<0, 0>(0, realSize*sizeof(T));
|
||||
subBarrier();
|
||||
if (SEND) {
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<nsend; j++) {
|
||||
int i = (j+shift)%nsend;
|
||||
int peerOffset = i*peerElem + offset;
|
||||
if (skip >=0 && i >= skip) peerOffset += peerElem;
|
||||
const T* src0 = srcPtr + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nworkers, 1, &src0, 1, dsts+i, realPeerSize);
|
||||
}
|
||||
} else if (RECV) {
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<nrecv; j++) {
|
||||
int i = (j+shift)%nrecv;
|
||||
int peerOffset = i*peerElem + offset;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
T* dst0 = dstPtr + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nworkers, 1, srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
if (SEND && (role & ROLE_POST_SEND)) postSend();
|
||||
if (RECV && (role & ROLE_POST_RECV)) postRecv();
|
||||
offset += realSize;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclChannel* channel, T* directBuff) {
|
||||
if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) {
|
||||
// For oneshot: groups 0,1 use conn 0, groups 2,3 use conn 1
|
||||
conn = &channel->devPeers[peer].recv[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
|
||||
if (role & ROLE_POST_RECV) {
|
||||
connHeadPtr = conn->head;
|
||||
// Return credits in case we rounded up.
|
||||
STORE(connHeadPtr, step);
|
||||
}
|
||||
if (role & ROLE_WAIT_RECV) {
|
||||
buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
//if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
|
||||
// direct = directBuff;
|
||||
// *conn->ptrExchange = directBuff;
|
||||
//}
|
||||
connTailPtr = conn->tail;
|
||||
connTailCache = LOAD(connTailPtr);
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadSendConn(struct ncclChannel* channel) {
|
||||
if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) {
|
||||
// For oneshot: groups 0,1 use conn 0, groups 2,3 use conn 1
|
||||
conn = &channel->devPeers[peer].send[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
|
||||
if (role & ROLE_POST_SEND) {
|
||||
connTailPtr = conn->tail;
|
||||
}
|
||||
if (role & ROLE_WAIT_SEND) {
|
||||
buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
#if 0
|
||||
if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
|
||||
void* volatile* ptr = conn->ptrExchange;
|
||||
while ((direct = (T*)(*ptr)) == NULL) { if (checkAbort()) break; }
|
||||
*ptr = NULL;
|
||||
}
|
||||
#endif
|
||||
connHeadPtr = conn->head;
|
||||
connHeadCache = LOAD(connHeadPtr);
|
||||
connSizesFifoPtr = conn->sizesFifo;
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveSync() {
|
||||
if (role & (ROLE_POST_SEND|ROLE_POST_RECV)) {
|
||||
conn->step = step;
|
||||
__threadfence_system();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclPrimitives(const int tid, const int nworkers, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, struct ncclShmemPtrs* ptrs, int group)
|
||||
: comm(comm), tid(tid), nworkers(nworkers), stepSize(stepSize), srcs((const T**)ptrs[TO_GR(group)].srcs), dsts((T**)ptrs[TO_GR(group)].dsts),
|
||||
group(TO_GR(group)), barriers(&ptrs[TO_GR(group)].barrier), barrier_next(ptrs[TO_GR(group)].barrier_next),
|
||||
connIndex((NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? TO_GR(group)/2 : TO_IDX(group)) {
|
||||
nthreads = nworkers;
|
||||
// For send operations, we need an extra warp to overlap the threadfence and the copy
|
||||
// int postThreads = NSEND && nworkers >= 64 ? WARP_SIZE : 0;
|
||||
// nthreads += postThreads;
|
||||
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
for (int i=0; i<NRECV; i++) if (recvPeers[i] != -1) nrecv++;
|
||||
for (int i=0; i<NSEND; i++) if (sendPeers[i] != -1) nsend++;
|
||||
|
||||
#define SYNC_GROUP 8
|
||||
static_assert(NSEND < SYNC_GROUP && NRECV < SYNC_GROUP, "Not enough threads to cover all peers");
|
||||
|
||||
int g = tid / SYNC_GROUP;
|
||||
int ng = nthreads / SYNC_GROUP;
|
||||
index = tid % SYNC_GROUP;
|
||||
|
||||
if (g == 0) {
|
||||
if (index < nrecv) role |= ROLE_WAIT_RECV;
|
||||
if (index == nrecv) role |= ROLE_SRC;
|
||||
} else if (g == 1) {
|
||||
if (index < nsend) role |= ROLE_WAIT_SEND;
|
||||
if (index == nsend) role |= ROLE_DST;
|
||||
} else if (g == ng - 2) {
|
||||
if (index < nrecv) role |= ROLE_POST_RECV;
|
||||
} else if (g == ng - 1) {
|
||||
if (index < nsend) role |= ROLE_POST_SEND;
|
||||
}
|
||||
|
||||
if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) peer = recvPeers[index];
|
||||
if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) peer = sendPeers[index];
|
||||
|
||||
loadRecvConn(channel, directBuff);
|
||||
loadSendConn(channel);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
send(const T* src, int nelem) {
|
||||
GenericOp<0, 0, 0, 1, 1, 0>(src, NULL, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directSend(const T* src, ssize_t directOffset, int nelem) {
|
||||
GenericOp<0, 1, 0, 1, 1, 0>(src, NULL, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recv(T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 0, 0, 1>(NULL, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directRecv(T* dst, ssize_t directOffset, int nelem) {
|
||||
GenericOp<1, 0, 1, 0, 0, 1>(NULL, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
copySend(const T* src, T* dst, int nelem) {
|
||||
GenericOp<0, 0, 0, 1, 1, 1>(src, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directCopySend(const T* src, T* dst, ssize_t directOffset, int nelem) {
|
||||
GenericOp<0, 1, 0, 1, 1, 1>(src, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvCopySend(T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 1, 0, 1>(NULL, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directRecvCopySend(T* dst, ssize_t directOffset, int nelem) {
|
||||
GenericOp<1, 1, 1, 1, 0, 1>(NULL, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvReduceCopy(const T* src, T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 0, 1, 1>(src, dst, nelem, 0);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvReduceSend(const T* src, int nelem) {
|
||||
GenericOp<0, 0, 1, 1, 1, 0>(src, NULL, nelem, 0);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
recvReduceCopySend(const T* src, T* dst, int nelem) {
|
||||
GenericOp<0, 0, 1, 1, 1, 1>(src, dst, nelem, 0);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directRecvReduceCopySend(const T* src, T* dst, ssize_t directOffset, int nelem) {
|
||||
// Direct is only for the send part
|
||||
GenericOp<0, 1, 1, 1, 1, 1>(src, dst, nelem, directOffset);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
scatter(const T* src, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<0, 1>(src, NULL, totalElem, peerElem, skip, shift);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
gather(T* dst, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<1, 0>(NULL, dst, totalElem, peerElem, skip, shift);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ ~ncclPrimitives() {
|
||||
// Save steps for the next operation
|
||||
saveSync();
|
||||
// Group width is how many consecutive group values a subchannel occupies.
|
||||
static constexpr int MaxGroupWidth = 1;
|
||||
__device__ static int calcGroupWidth(bool send, int nthreads) {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProtoLL {
|
||||
static constexpr int Id = NCCL_PROTO_LL;
|
||||
|
||||
// Data bytes (no flags etc) in one step of the fifo queue.
|
||||
__device__ static int calcBytePerStep() {
|
||||
return ncclShmem->comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/2; // Half is data
|
||||
}
|
||||
// Granularity of data bytes transferred per thread.
|
||||
__device__ static int calcBytePerGrain() {
|
||||
return sizeof(uint64_t); // One 16-byte line has 8-bytes of data
|
||||
}
|
||||
// Group width is how many consecutive group values a subchannel occupies.
|
||||
static constexpr int MaxGroupWidth = 1;
|
||||
__device__ static int calcGroupWidth(bool send, int nthreads) {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProtoLL128 {
|
||||
static constexpr int Id = NCCL_PROTO_LL128;
|
||||
|
||||
// Data bytes (no flags etc) in one step of the fifo queue.
|
||||
__device__ static int calcBytePerStep() {
|
||||
return (ncclShmem->comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS)*NCCL_LL128_DATAELEMS/NCCL_LL128_LINEELEMS;
|
||||
}
|
||||
// Granularity of data bytes transferred per thread.
|
||||
__device__ static int calcBytePerGrain() {
|
||||
return NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_DATAELEMS*sizeof(uint64_t)/NCCL_LL128_LINEELEMS;
|
||||
}
|
||||
// Group width is how many consecutive group values a subchannel occupies.
|
||||
static constexpr int MaxGroupWidth = 1;
|
||||
__device__ static int calcGroupWidth(bool send, int nthreads) {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
/* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template
|
||||
* arguments are static bounds on the maximum values. Asymmetric counts are
|
||||
* independent. Symmetric is a static guarantee that nrecv==nsend, so it only
|
||||
* stores one value at runtime. This optimization save 32-bit register, but more
|
||||
* importantly uses fewer predicate registers when unrolling loops.
|
||||
*/
|
||||
template<int MaxRecv_, int MaxSend_>
|
||||
struct FanAsymmetric {
|
||||
static constexpr int MaxRecv = MaxRecv_, MaxSend = MaxSend_;
|
||||
int nr, ns;
|
||||
FanAsymmetric() = default;
|
||||
__device__ FanAsymmetric(int nrecv, int nsend): nr(nrecv), ns(nsend) {
|
||||
// assert(nrecv <= MaxRecv && nsend <= MaxSend);
|
||||
}
|
||||
__device__ int nrecv() const { return MaxRecv ? nr : 0; }
|
||||
__device__ int nsend() const { return MaxSend ? ns : 0; }
|
||||
};
|
||||
|
||||
template<int MaxArity>
|
||||
struct FanSymmetric {
|
||||
static constexpr int MaxRecv = MaxArity, MaxSend = MaxArity;
|
||||
int n;
|
||||
FanSymmetric() = default;
|
||||
__device__ FanSymmetric(int nrecv, int nsend): n(nrecv) {
|
||||
// assert(nrecv == nsend && nrecv <= MaxArity);
|
||||
}
|
||||
__device__ int nrecv() const { return n; }
|
||||
__device__ int nsend() const { return n; }
|
||||
};
|
||||
|
||||
// The primitives class. Specialized per protocol in the other headers.
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, typename Proto>
|
||||
class Primitives;
|
||||
|
||||
// Used by LL & LL128 to implement direct members in the naive way.
|
||||
template<typename RealPrimitives>
|
||||
struct PrimitivesWithoutDirect {
|
||||
__device__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->send(inpIx, eltN);
|
||||
}
|
||||
__device__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->sendFromOutput(outIx, eltN);
|
||||
}
|
||||
__device__ void directRecv(intptr_t outIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->recv(outIx, eltN, /*postOp=*/false);
|
||||
}
|
||||
__device__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
static_cast<RealPrimitives*>(this)->copySend(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
static_cast<RealPrimitives*>(this)->recvCopySend(outIx, eltN, /*postOp=*/false);
|
||||
}
|
||||
__device__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
// Direct is only for the send part
|
||||
static_cast<RealPrimitives*>(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_simple.h"
|
||||
#include "prims_ll.h"
|
||||
//#include "prims_ll128.h"
|
||||
#include "prims_ll128.h"
|
||||
|
||||
#ifdef ENABLE_PROFILING
|
||||
#define INIT_COUNTER \
|
||||
|
||||
@@ -5,18 +5,25 @@
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
template <typename T, class FUNC, int NRECV, int NSEND>
|
||||
class ncclLLPrimitives {
|
||||
private:
|
||||
template<typename T, typename RedOp, typename Fan, int Direct>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
RedOp redOp;
|
||||
const int tid;
|
||||
const int nthreads;
|
||||
const int wid;
|
||||
const int group;
|
||||
const int stepLines;
|
||||
int nrecv = 0;
|
||||
int nsend = 0;
|
||||
Fan fan;
|
||||
T *userBufs[2];
|
||||
struct ncclConnInfo* recvConn = NULL;
|
||||
volatile uint64_t* recvConnHeadPtr = NULL;
|
||||
uint64_t recvConnHead;
|
||||
uint64_t* barriers;
|
||||
uint64_t* barrier_next;
|
||||
|
||||
struct ncclConnInfo* sendConn = NULL;
|
||||
volatile int* sendConnFifoPtr = NULL;
|
||||
@@ -24,11 +31,10 @@ class ncclLLPrimitives {
|
||||
uint64_t sendConnHead;
|
||||
uint64_t sendConnHeadCache; // Cache last seen value
|
||||
|
||||
uint64_t recvStep[NRECV];
|
||||
uint64_t sendStep[NSEND];
|
||||
union ncclLLFifoLine* recvBuff[NRECV];
|
||||
union ncclLLFifoLine* sendBuff[NSEND];
|
||||
struct ncclDevComm* comm;
|
||||
uint64_t recvStep[MaxRecv];
|
||||
uint64_t sendStep[MaxSend];
|
||||
union ncclLLFifoLine* recvBuff[MaxRecv];
|
||||
union ncclLLFifoLine* sendBuff[MaxSend];
|
||||
|
||||
inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; }
|
||||
inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
|
||||
@@ -41,28 +47,32 @@ class ncclLLPrimitives {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
__syncthreads();
|
||||
#else
|
||||
asm volatile ("basync 1, %0;" :: "r"(nthreads));
|
||||
asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(1+group));
|
||||
#endif
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
static inline __device__ uint32_t __funnelshift_r(uint32_t lo, uint32_t hi, uint32_t shift) {
|
||||
uint64_t val64 = ((uint64_t)lo+((uint64_t)hi<<32))>>(shift&31);
|
||||
return (uint32_t)val64;
|
||||
}
|
||||
|
||||
uint32_t abort = 0;
|
||||
|
||||
inline __device__ int checkAbort(int i, int send) {
|
||||
inline __device__ int checkAbort(int &spins, int send) {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = LOAD(comm->abortFlag);
|
||||
if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = LOAD(ncclShmem->comm.abortFlag);
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
}
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
int spins = 0;
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
|
||||
sendConnHeadCache = LOAD(sendConnHeadPtr);
|
||||
if (checkAbort(wid, 1)) break;
|
||||
if (checkAbort(spins, 1)) break;
|
||||
}
|
||||
if (sendConnFifoPtr) {
|
||||
int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
|
||||
@@ -83,36 +93,68 @@ class ncclLLPrimitives {
|
||||
|
||||
inline __device__ void incSend(int i, int offset) {
|
||||
// LL Cleanup : write all flags in the slice to make sure we don't have
|
||||
// data corruption when flag loops ove
|
||||
// data corruption when flag loops over.
|
||||
if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
|
||||
for (int o = offset; o<stepLines; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
|
||||
}
|
||||
sendStep[i]++;
|
||||
}
|
||||
|
||||
__device__ uint64_t readLL(int i, int offset) {
|
||||
__device__ uint64_t readLL(int offset, int i) {
|
||||
union ncclLLFifoLine* src = recvPtr(i) + offset;
|
||||
uint32_t flag = recvFlag(i);
|
||||
uint32_t data1, flag1, data2, flag2;
|
||||
spins = 0;
|
||||
int spins = 0;
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
union ncclLLFifoLine i4;
|
||||
do {
|
||||
i4.v[0] = __builtin_nontemporal_load(src->v);
|
||||
i4.v[1] = __builtin_nontemporal_load(src->v+1);
|
||||
if (checkAbort(i, 0)) break;
|
||||
if (checkAbort(spins, 0)) break;
|
||||
} while ((i4.flag1 != flag) || (i4.flag2 != flag));
|
||||
uint64_t val64 = (uint64_t)(i4.data1) + (((uint64_t)i4.data2) << 32);
|
||||
#else
|
||||
do {
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
|
||||
if (checkAbort(i, 0)) break;
|
||||
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
|
||||
if (checkAbort(spins, 0)) break;
|
||||
} while ((flag1 != flag) || (flag2 != flag));
|
||||
uint64_t val64 = data1 + (((uint64_t)data2) << 32);
|
||||
#endif
|
||||
return val64;
|
||||
}
|
||||
|
||||
template<int BeginIx>
|
||||
__device__ void readLLBeginAll(int offset, ncclLLFifoLine(&line)[MaxRecv]) {
|
||||
#pragma unroll 1
|
||||
for (int i=BeginIx; i < MaxRecv; i++) {
|
||||
if (i < fan.nrecv()) {
|
||||
union ncclLLFifoLine* src = recvPtr(i) + offset;
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
line[i].v[0] = __builtin_nontemporal_load(src->v);
|
||||
line[i].v[1] = __builtin_nontemporal_load(src->v+1);
|
||||
#else
|
||||
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
__device__ uint64_t readLLFinish(int offset, ncclLLFifoLine(&line)[MaxRecv], int i) {
|
||||
union ncclLLFifoLine* src = recvPtr(i) + offset;
|
||||
uint32_t flag = recvFlag(i);
|
||||
int spins = 0;
|
||||
do {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
line[i].v[0] = __builtin_nontemporal_load(src->v);
|
||||
line[i].v[1] = __builtin_nontemporal_load(src->v+1);
|
||||
#else
|
||||
asm("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4));
|
||||
#endif
|
||||
if (checkAbort(spins, 0)) break;
|
||||
} while(line[i].flag1 != flag || line[i].flag2 != flag);
|
||||
uint64_t val64 = line[i].data1 + (((uint64_t)line[i].data2) << 32);
|
||||
return val64;
|
||||
}
|
||||
|
||||
__device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
union ncclLLFifoLine i4;
|
||||
@@ -127,66 +169,203 @@ class ncclLLPrimitives {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Using memcpy handles misaligned pointer
|
||||
__device__ uint64_t readAL(uint64_t* src) {
|
||||
uint64_t val;
|
||||
memcpy((char*)&val, (char*)src, sizeof(uint64_t));
|
||||
return val;
|
||||
static constexpr int EltPerLine = sizeof(uint64_t)/sizeof(T);
|
||||
|
||||
template<typename U>
|
||||
__device__ static U load(U *src) {
|
||||
union {
|
||||
U elt;
|
||||
uint8_t u1;
|
||||
uint16_t u2;
|
||||
uint32_t u4;
|
||||
uint64_t u8;
|
||||
};
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if(sizeof(U) == 1)
|
||||
u1 = __builtin_nontemporal_load((uint8_t*)src);
|
||||
else if(sizeof(U) == 2)
|
||||
u2 = __builtin_nontemporal_load((uint16_t*)src);
|
||||
else if(sizeof(U) == 4)
|
||||
u4 = __builtin_nontemporal_load((uint32_t*)src);
|
||||
else
|
||||
u8 = __builtin_nontemporal_load((uint64_t*)src);
|
||||
#else
|
||||
if(sizeof(U) == 1)
|
||||
asm("ld.volatile.global.b8 %0,[%1];" : "=r"(u4) : "l"(src));
|
||||
else if(sizeof(U) == 2)
|
||||
asm("ld.volatile.global.b16 %0,[%1];" : "=h"(u2) : "l"(src));
|
||||
else if(sizeof(U) == 4)
|
||||
asm("ld.volatile.global.b32 %0,[%1];" : "=r"(u4) : "l"(src));
|
||||
else
|
||||
asm("ld.volatile.global.b64 %0,[%1];" : "=l"(u8) : "l"(src));
|
||||
#endif
|
||||
return elt;
|
||||
}
|
||||
|
||||
__device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
|
||||
memcpy((char*)dst, (char*)&val, nbytes);
|
||||
template<typename U>
|
||||
__device__ static void store(U *dst, U val) {
|
||||
union {
|
||||
U elt;
|
||||
uint8_t u1;
|
||||
uint16_t u2;
|
||||
uint32_t u4;
|
||||
uint64_t u8;
|
||||
};
|
||||
elt = val;
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if(sizeof(U) == 1)
|
||||
__builtin_nontemporal_store(u1, (uint8_t*)dst);
|
||||
else if(sizeof(U) == 2)
|
||||
__builtin_nontemporal_store(u2, (uint16_t*)dst);
|
||||
else if(sizeof(U) == 4)
|
||||
__builtin_nontemporal_store(u4, (uint32_t*)dst);
|
||||
else
|
||||
__builtin_nontemporal_store(u8, (uint64_t*)dst);
|
||||
#else
|
||||
if(sizeof(U) == 1)
|
||||
asm("st.volatile.global.b8 [%0],%1;" :: "l"(dst), "r"(u4));
|
||||
else if(sizeof(U) == 2)
|
||||
asm("st.volatile.global.b16 [%0],%1;" :: "l"(dst), "h"(u2));
|
||||
else if(sizeof(U) == 4)
|
||||
asm("st.volatile.global.b32 [%0],%1;" :: "l"(dst), "r"(u4));
|
||||
else
|
||||
asm("st.volatile.global.b64 [%0],%1;" :: "l"(dst), "l"(u8));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int RECV, int SEND, int SRC, int DST>
|
||||
__device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
|
||||
uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
|
||||
uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
|
||||
uint64_t* srcPack = (uint64_t*)srcPtr;
|
||||
uint64_t* dstPack = (uint64_t*)dstPtr;
|
||||
int offset = tid;
|
||||
struct DataLoader {
|
||||
int misalign;
|
||||
union {
|
||||
uint32_t u4[sizeof(T) <= 2 ? 3 : 2];
|
||||
uint64_t u8;
|
||||
T elt[EltPerLine];
|
||||
};
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine));
|
||||
|
||||
// Do multiples of 64 bits
|
||||
#pragma unroll 1
|
||||
for (; offset<npack; offset+=nthreads) {
|
||||
// Recv : local, then intra-node, then inter-node
|
||||
uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
|
||||
if (RECV) {
|
||||
if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
|
||||
for (int i=1; i<NRECV && i<nrecv; i++) {
|
||||
val = MULTI<FUNC, T>()(readLL(i, offset), val);
|
||||
}
|
||||
__device__ void loadBegin(T *src, int eltN) {
|
||||
if (sizeof(T) <= 2) {
|
||||
misalign = reinterpret_cast<uintptr_t>(src)%4;
|
||||
uint32_t *p = reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(4));
|
||||
u4[0] = load(p+0);
|
||||
u4[1] = misalign + eltN*sizeof(T) > 4 ? load(p+1) : 0;
|
||||
// u4[2] would be simpler, but that throws warnings on some compilers
|
||||
u4[sizeof(T) <= 2 ? 2 : 0] = misalign + eltN*sizeof(T) > 8 ? load(p+2) : 0;
|
||||
}
|
||||
|
||||
// Send : inter-node, then intra-node, then local
|
||||
if (SEND) {
|
||||
for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
|
||||
storeLL(sendPtr(0)+offset, val, sendFlag(0));
|
||||
}
|
||||
if (DST) {
|
||||
if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
|
||||
// Last incomplete word
|
||||
storeAL(dstPack+offset, val, nbytes & 0x7);
|
||||
} else {
|
||||
storeAL(dstPack+offset, val, sizeof(uint64_t));
|
||||
else {
|
||||
#pragma unroll
|
||||
for(int i=0; i < EltPerLine; i++) {
|
||||
if(i==0 || i < eltN)
|
||||
elt[i] = load(src + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
FOR_RECV(incRecv); if (RECV) postRecv();
|
||||
FOR_SEND(incSend, offset);
|
||||
|
||||
__device__ uint64_t loadFinish() {
|
||||
if (sizeof(T) <= 2) {
|
||||
u4[0] = __funnelshift_r(u4[0], u4[1], 8*misalign);
|
||||
// u4[2] would be simpler, but that throws warnings on some compilers
|
||||
u4[1] = __funnelshift_r(u4[1], u4[sizeof(T) <= 2 ? 2 : 0], 8*misalign);
|
||||
}
|
||||
return u8;
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void storeData(T *dst, uint64_t val, int eltN) {
|
||||
union {
|
||||
uint64_t u8;
|
||||
T elt[EltPerLine];
|
||||
};
|
||||
u8 = val;
|
||||
#pragma unroll
|
||||
for(int i=0; i < EltPerLine; i++) {
|
||||
if (i==0 || i < eltN)
|
||||
//store(dst+i, elt[i]);
|
||||
dst[i] = elt[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
#ifdef ENABLE_PROFILING
|
||||
uint64_t t0 = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
if (SEND) waitSend(divUp(nelem, EltPerLine)*sizeof(ncclLLFifoLine));
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (SEND && tid == 0) __atomic_fetch_add(&ncclShmem->comm.devProf->wait_cycle[blockIdx.x], __builtin_amdgcn_s_memrealtime() - t0, __ATOMIC_SEQ_CST);
|
||||
#endif
|
||||
|
||||
nelem -= tid*EltPerLine;
|
||||
srcElts += tid*EltPerLine;
|
||||
dstElts += tid*EltPerLine;
|
||||
int offset = tid;
|
||||
int eltPerTrip = nthreads*EltPerLine;
|
||||
while (nelem > 0) {
|
||||
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
|
||||
|
||||
DataLoader dl;
|
||||
ncclLLFifoLine line[MaxRecv];
|
||||
uint64_t data, peerData;
|
||||
if (SRC) {
|
||||
dl.loadBegin(srcElts, eltInLine);
|
||||
srcElts += eltPerTrip;
|
||||
}
|
||||
if (RECV) {
|
||||
readLLBeginAll<1>(offset, line);
|
||||
peerData = readLL(offset, 0);
|
||||
}
|
||||
if (SRC) {
|
||||
data = dl.loadFinish();
|
||||
if (SrcBuf == Input) data = MULTI<RedOp, T>().preOp(redOp, data);
|
||||
}
|
||||
if (RECV) {
|
||||
data = !SRC ? peerData : MULTI<RedOp,T>()(redOp, peerData, data);
|
||||
#pragma unroll 1
|
||||
for (int i=1; i < MaxRecv && i < fan.nrecv(); i++) {
|
||||
peerData = readLLFinish(offset, line, i);
|
||||
data = MULTI<RedOp,T>()(redOp, peerData, data);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) data = MULTI<RedOp, T>().postOp(redOp, data);
|
||||
|
||||
// Send : inter-node, then intra-node, then local
|
||||
if (SEND) {
|
||||
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
|
||||
storeLL(sendPtr(i)+offset, data, sendFlag(i));
|
||||
storeLL(sendPtr(0)+offset, data, sendFlag(0));
|
||||
}
|
||||
if (DST) {
|
||||
storeData(dstElts, data, eltInLine);
|
||||
dstElts += eltPerTrip;
|
||||
}
|
||||
nelem -= eltPerTrip;
|
||||
offset += nthreads;
|
||||
}
|
||||
|
||||
if (RECV) {
|
||||
for (int i=0; i < MaxRecv; i++) incRecv(i);
|
||||
postRecv();
|
||||
}
|
||||
if (SEND) {
|
||||
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
|
||||
incSend(i, offset);
|
||||
incSend(0, offset);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
|
||||
recvBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL);
|
||||
recvStep[i] = LOAD(&conn->step);
|
||||
if (wid == i) recvConn = conn;
|
||||
nrecv++;
|
||||
}
|
||||
__device__ __forceinline__ void loadRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) {
|
||||
recvConnHeadPtr = LOAD(&recvConn->head);
|
||||
recvConnHead = LOAD(&recvConn->step);
|
||||
}
|
||||
@@ -196,10 +375,9 @@ class ncclLLPrimitives {
|
||||
sendBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL);
|
||||
sendStep[i] = LOAD(&conn->step);
|
||||
if (wid == i) sendConn = conn;
|
||||
nsend++;
|
||||
}
|
||||
__device__ __forceinline__ void loadSendSync() {
|
||||
if (tid < nsend) {
|
||||
if (tid < fan.nsend()) {
|
||||
sendConnHeadPtr = LOAD(&sendConn->head);
|
||||
sendConnHeadCache = LOAD(sendConnHeadPtr);
|
||||
sendConnHead = LOAD(&sendConn->step);
|
||||
@@ -207,65 +385,75 @@ class ncclLLPrimitives {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
STORE(&recvConn->step, recvConnHead);
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
STORE(&sendConn->step, sendConnHead);
|
||||
__threadfence_block();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
__device__ Primitives(
|
||||
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0, int connIndex=0
|
||||
):
|
||||
redOp(FuncTraits<RedOp>().make(ncclShmem->comm.nRanks)),
|
||||
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group),
|
||||
stepLines(ncclShmem->comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)),
|
||||
barriers(&ncclShmem->groups[group].barrier), barrier_next(ncclShmem->groups[group].barrier_next) {
|
||||
|
||||
auto *channel = &ncclShmem->channel;
|
||||
// If we are going to support oneshot collNet + LL, then we would need to add connector index here
|
||||
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv->conn, i);
|
||||
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send->conn, i);
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) {
|
||||
loadRecvConn(&channel->devPeers[recvPeers[nrecv]].recv->conn, nrecv);
|
||||
nrecv++;
|
||||
}
|
||||
while (nsend < MaxSend && sendPeers[nsend] >= 0) {
|
||||
loadSendConn(&channel->devPeers[sendPeers[nsend]].send->conn, nsend);
|
||||
nsend++;
|
||||
}
|
||||
this->fan = Fan(nrecv, nsend);
|
||||
loadRecvSync();
|
||||
loadSendSync();
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
}
|
||||
|
||||
__device__ void send(const T* src, int nelem) {
|
||||
return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recv(T* dst, int nelem) {
|
||||
return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceSend(const T* src, int nelem) {
|
||||
return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
|
||||
return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void copySend(const T* src, T* dst, int nelem) {
|
||||
return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvCopySend(T* dst, int nelem) {
|
||||
return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
|
||||
return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ ~ncclLLPrimitives() {
|
||||
__device__ ~Primitives() {
|
||||
// Save steps for the next operation
|
||||
saveRecvSync();
|
||||
saveSendSync();
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv())
|
||||
STORE(&recvConn->step, recvConnHead);
|
||||
if (tid < fan.nsend())
|
||||
STORE(&sendConn->step, sendConnHead);
|
||||
// Ensure all steps written back
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
userBufs[Input] = (T*)inputBuf;
|
||||
userBufs[Output] = (T*)outputBuf;
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
userBufs[Input] += delta;
|
||||
userBufs[Output] += delta;
|
||||
}
|
||||
|
||||
__device__ void send(intptr_t inpIx, int eltN) {
|
||||
return LLGenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void sendFromOutput(intptr_t outIx, int eltN) {
|
||||
return LLGenericOp<0, 1, Output, -1>(outIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 0, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceSend(intptr_t inpIx, int eltN) {
|
||||
return LLGenericOp<1, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 1, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return LLGenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -9,17 +9,24 @@
|
||||
|
||||
#define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1)
|
||||
|
||||
template <typename T, class FUNC, int NRECV, int NSEND>
|
||||
class ncclLL128Primitives {
|
||||
private:
|
||||
#define __any_sync(WARP_MASK, needReload) (true)
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL128>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
RedOp redOp;
|
||||
const int tid;
|
||||
const int nthreads;
|
||||
const int wid;
|
||||
const int stepSize;
|
||||
const int warp;
|
||||
const bool flagThread;
|
||||
int nrecv = 0;
|
||||
int nsend = 0;
|
||||
const int group;
|
||||
Fan fan;
|
||||
T *userBufs[2];
|
||||
struct ncclConnInfo* recvConn = NULL;
|
||||
volatile uint64_t* recvConnHeadPtr = NULL;
|
||||
uint64_t recvConnHead;
|
||||
@@ -32,14 +39,10 @@ class ncclLL128Primitives {
|
||||
uint64_t sendConnHead;
|
||||
uint64_t sendConnHeadCache; // Cache last seen value
|
||||
|
||||
uint64_t recvStep[NRECV];
|
||||
uint64_t sendStep[NSEND];
|
||||
uint64_t* recvBuff[NRECV];
|
||||
uint64_t* sendBuff[NSEND];
|
||||
struct ncclDevComm* comm;
|
||||
|
||||
volatile uint64_t* shmem;
|
||||
uint32_t* sync;
|
||||
uint64_t recvStep[MaxRecv];
|
||||
uint64_t sendStep[MaxSend];
|
||||
uint64_t* recvBuff[MaxRecv];
|
||||
uint64_t* sendBuff[MaxSend];
|
||||
|
||||
inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; }
|
||||
inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
|
||||
@@ -49,388 +52,376 @@ class ncclLL128Primitives {
|
||||
inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; }
|
||||
|
||||
inline __device__ void barrier() {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
__syncthreads();
|
||||
#else
|
||||
if (NSEND>NRECV) {
|
||||
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
|
||||
} else {
|
||||
asm volatile ("bar.sync 2, %0;" :: "r"(nthreads));
|
||||
}
|
||||
#endif
|
||||
asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(1+group));
|
||||
}
|
||||
|
||||
uint32_t spins = 0;
|
||||
uint32_t abort = 0;
|
||||
|
||||
inline __device__ int checkAbort(int i, int send) {
|
||||
inline __device__ int checkAbort(int &spins, int i, int send) {
|
||||
spins++;
|
||||
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = LOAD(comm->abortFlag);
|
||||
if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
|
||||
abort = *ncclShmem->comm.abortFlag;
|
||||
spins = 0;
|
||||
}
|
||||
return abort;
|
||||
}
|
||||
|
||||
inline __device__ void waitSend(int nbytes) {
|
||||
spins = 0;
|
||||
if (sendConnHeadPtr) {
|
||||
int spins = 0;
|
||||
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
|
||||
sendConnHeadCache = LOAD(sendConnHeadPtr);
|
||||
if (checkAbort(wid, 1)) break;
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
if (checkAbort(spins, wid, 1)) break;
|
||||
}
|
||||
if (sendConnFifoPtr) {
|
||||
STORE(sendConnFifoPtr+sendStep[wid]%NCCL_STEPS, nbytes);
|
||||
sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes;
|
||||
}
|
||||
sendConnHead += 1;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void incRecv(int i) {
|
||||
recvStep[i] += 1;
|
||||
}
|
||||
inline __device__ void postRecv() {
|
||||
if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += 1);
|
||||
}
|
||||
|
||||
inline __device__ void incSend(int i) {
|
||||
sendStep[i] += 1;
|
||||
if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
|
||||
}
|
||||
inline __device__ void postSend() {
|
||||
if (sendConnTailPtr) { __threadfence(); STORE(sendConnTailPtr, sendConnTail += 1); }
|
||||
if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; }
|
||||
}
|
||||
|
||||
template <int ELEMS_PER_THREAD>
|
||||
inline __device__ void loadSrcToShmem128(int maxOffset, const uint64_t* src64Ptr) {
|
||||
#if 0
|
||||
uint64_t v[ELEMS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
if (u*WARP_SIZE < maxOffset) load128(src64Ptr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
}
|
||||
uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
storeShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
}
|
||||
#else
|
||||
uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
if (u*WARP_SIZE < maxOffset) {
|
||||
using Vec = uint64_t __attribute__((ext_vector_type(2)));
|
||||
Vec i2;
|
||||
//load128(src64Ptr+u*WARP_SIZE, v0, v1);
|
||||
asm volatile ("flat_load_dwordx4 %0, %1\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(i2) : "v"(src64Ptr+u*WARP_SIZE));
|
||||
//storeShmem128(shmemAsmPtr+u*WARP_SIZE, i2[0], i2[1]);
|
||||
*(shmemAsmPtr+u*WARP_SIZE) = i2[0];
|
||||
*(shmemAsmPtr+u*WARP_SIZE+1) = i2[1];
|
||||
template<int WordPerThread>
|
||||
__device__ __forceinline__ void loadRegsBegin(uint64_t(®s)[WordPerThread], T const *src, int eltN) {
|
||||
constexpr int EltPer16B = 16/sizeof(T);
|
||||
if(reinterpret_cast<uintptr_t>(src)%16 == 0) {
|
||||
/* We are aligned to 16 bytes, so load directly to registers no shmem.
|
||||
* Flag threads load half as much data which gets shuffled to the even
|
||||
* registers during Finish. The point of splitting into two phases is to
|
||||
* defer that shuffle, which incurs a dependency stall, until after other
|
||||
* memops are launched by the caller.
|
||||
*/
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++) {
|
||||
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
|
||||
if(!flagThread || g%2==0) {
|
||||
if(ix*EltPer16B < eltN)
|
||||
load128((uint64_t*)(src + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
// Not aligned. Stage the smallest 16 byte aligned region subsuming the
|
||||
// buffer into shmem.
|
||||
int misalignment = reinterpret_cast<uintptr_t>(src) % 16;
|
||||
uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16));
|
||||
uint64_t *shm8 = shmemCvtPtr(ncclShmem->ll128warp[warp]);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++)
|
||||
if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T))
|
||||
load128(src8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++)
|
||||
storeShmem128(shm8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]);
|
||||
|
||||
inline __device__ void loadSrcToShmem(int start, int end, const T* srcPtr) {
|
||||
T* shmemPtr = (T*)(shmem-2*wid);
|
||||
for (int offset = start+wid; offset < end; offset += WARP_SIZE) {
|
||||
shmemPtr[offset] = srcPtr[offset];
|
||||
__syncwarp();
|
||||
|
||||
// Now load from shmem stage to regs. Preserve the same pre-shuffled layout
|
||||
// as the aligned case since Finish() will be applied regardless.
|
||||
T *shm = (T*)shm8 + misalignment/sizeof(T);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++) {
|
||||
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
|
||||
if(!flagThread || g%2==0) {
|
||||
if(ix*EltPer16B < eltN)
|
||||
loadShmemMisaligned128(shm + ix*EltPer16B, regs[2*g+0], regs[2*g+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int ELEMS_PER_THREAD>
|
||||
inline __device__ void storeShmemToDst128(int maxOffset, uint64_t* dst64Ptr) {
|
||||
using Velem = uint64_t __attribute__((ext_vector_type(ELEMS_PER_THREAD)));
|
||||
Velem v;
|
||||
uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
|
||||
template<int WordPerThread>
|
||||
__device__ __forceinline__ void loadRegsFinish(uint64_t(®s)[WordPerThread]) {
|
||||
// Move data out of flag registers into the vacant registers.
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = *(shmemAsmPtr+u*WARP_SIZE);
|
||||
v[u+1] = *(shmemAsmPtr+u*WARP_SIZE+1);
|
||||
//loadShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
//if (u*WARP_SIZE < maxOffset) store128(dst64Ptr+u*WARP_SIZE, v[u], v[u+1]);
|
||||
using Vec = uint64_t __attribute__((ext_vector_type(2)));
|
||||
Vec i2;
|
||||
i2[0] = v[u];
|
||||
i2[1] = v[u+1];//
|
||||
if (u*WARP_SIZE < maxOffset) asm volatile ("flat_store_dwordx4 %0, %1\n"
|
||||
"s_waitcnt vmcnt(0)\n" : : "v"(dst64Ptr+u*WARP_SIZE), "v"(i2));
|
||||
for (int g=1; g < WordPerThread/2; g+=2) {
|
||||
if (flagThread) regs[2*g] = regs[2*g-1];
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void storeShmemToDst(int start, int end, T* dstPtr) {
|
||||
T* shmemPtr = (T*)(shmem-2*wid);
|
||||
for (int offset = start+wid; offset < end; offset += WARP_SIZE) {
|
||||
dstPtr[offset] = shmemPtr[offset];
|
||||
template<int WordPerThread>
|
||||
__device__ __forceinline__ void storeRegs(T *dst, uint64_t(®s)[WordPerThread], int eltN) {
|
||||
constexpr int EltPer16B = 16/sizeof(T);
|
||||
// Reverse Finish() register permuatation.
|
||||
#pragma unroll
|
||||
for (int g=1; g < WordPerThread/2; g+=2) {
|
||||
if (flagThread) regs[2*g-1] = regs[2*g];
|
||||
}
|
||||
// Write to dst if 16-byte aligned, shmem otherwise.
|
||||
int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
|
||||
uint64_t *shm8 = shmemCvtPtr(ncclShmem->ll128warp[warp]);
|
||||
#pragma unroll
|
||||
for(int g=0; g < WordPerThread/2; g++) {
|
||||
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
|
||||
if (!flagThread || g%2==0) {
|
||||
if(misalignment == 0 && (ix+1)*EltPer16B <= eltN)
|
||||
store128((uint64_t*)(dst + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
|
||||
else
|
||||
storeShmem128(shm8+2*ix, regs[2*g+0], regs[2*g+1]);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
// Write rest from shmem to dst. No need to coalesce stores to 16-bytes,
|
||||
// the hardware keeps up fine.
|
||||
T *shm = (T*)ncclShmem->ll128warp[warp];
|
||||
int skip = misalignment == 0 ? eltN & -EltPer16B : 0;
|
||||
for(int i=skip+wid; i < eltN; i += WARP_SIZE)
|
||||
dst[i] = shm[i];
|
||||
}
|
||||
|
||||
#define WARP_MASK 0xffffffff
|
||||
|
||||
template <int ELEMS_PER_THREAD, int RECV, int SEND, int SRC, int DST>
|
||||
__device__ __forceinline__ void recvReduceSendCopy(int ll128Offset) {
|
||||
uint64_t v[ELEMS_PER_THREAD];
|
||||
template <int ELEMS_PER_THREAD, int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ __forceinline__ void recvReduceSendCopy(uint64_t(&v)[ELEMS_PER_THREAD], int ll128Offset, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
uint64_t vr[ELEMS_PER_THREAD];
|
||||
|
||||
/************* Data Loading : SHMEM -> REG **************/
|
||||
if (SRC) {
|
||||
volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = shmem64Ptr[u*(WARP_SIZE-2)];
|
||||
if (!flagThread) v[u+1] = shmem64Ptr[u*(WARP_SIZE-2)+1];
|
||||
}
|
||||
}
|
||||
/*********** End Data Loading : SHMEM -> REG ************/
|
||||
|
||||
/************************ Recv **************************/
|
||||
__syncwarp();
|
||||
/************************ Wait first recv ********************/
|
||||
if (RECV) {
|
||||
uint64_t flag = recvFlag(0);
|
||||
uint64_t* ptr = recvPtr(0)+ll128Offset;
|
||||
uint64_t flag = recvFlag(0);
|
||||
bool needReload;
|
||||
using Vec = uint64_t __attribute__((ext_vector_type(2)));
|
||||
Vec i2;
|
||||
int spins = 0;
|
||||
do {
|
||||
if (wid == 0) STORE(sync, 0);
|
||||
needReload = false;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
asm volatile ("flat_load_dwordx4 %0, %1, glc, slc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(i2) : "v"(ptr+u*WARP_SIZE));
|
||||
//load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
needReload |= flagThread && (i2[1] != flag);
|
||||
load128(ptr+u*WARP_SIZE, vr[u], vr[u+1]);
|
||||
needReload |= flagThread && (vr[u+1] != flag);
|
||||
}
|
||||
if (needReload) __atomic_fetch_add(sync, 1, __ATOMIC_SEQ_CST);
|
||||
} while (LOAD(sync) && checkAbort(0, 0) == 0);
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
asm volatile ("flat_load_dwordx4 %0, %1, glc, slc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(i2) : "v"(ptr+u*WARP_SIZE));
|
||||
//load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
v[u] = SRC ? MULTI<FUNC, T>()(i2[0], v[u]) : i2[0];
|
||||
v[u+1] = SRC ? MULTI<FUNC, T>()(i2[1], v[u+1]) : i2[1];
|
||||
}
|
||||
} while (__any_sync(WARP_MASK, needReload) && checkAbort(spins, 0, 0) == 0);
|
||||
}
|
||||
|
||||
for (int i=1; i<NRECV && i<nrecv; i++) {
|
||||
uint64_t flag = recvFlag(i);
|
||||
uint64_t* ptr = recvPtr(i)+ll128Offset;
|
||||
Vec i2;
|
||||
do {
|
||||
if (wid == 0) STORE(sync, 0);
|
||||
needReload = 0;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
asm volatile ("flat_load_dwordx4 %0, %1, glc, slc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(i2) : "v"(ptr+u*WARP_SIZE));
|
||||
//load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
needReload |= flagThread && (i2[1] != flag);
|
||||
}
|
||||
if (needReload) __atomic_fetch_add(sync, 1, __ATOMIC_SEQ_CST);
|
||||
} while (LOAD(sync) && checkAbort(i, 0) == 0);
|
||||
/************* Finish register load **************/
|
||||
if (SRC) {
|
||||
// By deferring register shuffle here we've overlapped spinning on first
|
||||
// peer's data with memory loads of src data.
|
||||
loadRegsFinish(v);
|
||||
if (SrcBuf == Input) {
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
asm volatile ("flat_load_dwordx4 %0, %1, glc, slc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(i2) : "v"(ptr+u*WARP_SIZE));
|
||||
//load128(ptr+u*WARP_SIZE, v0, v1);
|
||||
v[u] = MULTI<FUNC, T>()(i2[0], v[u]);
|
||||
v[u+1] = MULTI<FUNC, T>()(i2[1], v[u+1]);
|
||||
v[u] = MULTI<RedOp, T>().preOp(redOp, v[u]);
|
||||
if (!flagThread)
|
||||
v[u+1] = MULTI<RedOp, T>().preOp(redOp, v[u+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/************************ Recv rest *********************/
|
||||
if (RECV) {
|
||||
{ // Consume data from first recv
|
||||
uint64_t* ptr = recvPtr(0)+ll128Offset;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = SRC ? MULTI<RedOp, T>()(redOp, vr[u], v[u]) : vr[u];
|
||||
v[u+1] = SRC ? MULTI<RedOp, T>()(redOp, vr[u+1], v[u+1]) : vr[u+1];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i=1; i<MaxRecv && i<fan.nrecv(); i++) {
|
||||
uint64_t flag = recvFlag(i);
|
||||
uint64_t* ptr = recvPtr(i)+ll128Offset;
|
||||
bool needReload;
|
||||
int spins = 0;
|
||||
do {
|
||||
needReload = false;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
load128(ptr+u*WARP_SIZE, vr[u], vr[u+1]);
|
||||
needReload |= flagThread && (vr[u+1] != flag);
|
||||
}
|
||||
} while (__any_sync(WARP_MASK, needReload) && checkAbort(spins, i, 0) == 0);
|
||||
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = MULTI<RedOp, T>()(redOp, vr[u], v[u]);
|
||||
v[u+1] = MULTI<RedOp, T>()(redOp, vr[u+1], v[u+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
/********************** End Recv ************************/
|
||||
|
||||
if (postOp && !FuncTraits<RedOp>::IsPostOpIdentity) {
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
v[u] = MULTI<RedOp, T>().postOp(redOp, v[u]);
|
||||
v[u+1] = MULTI<RedOp, T>().postOp(redOp, v[u+1]);
|
||||
}
|
||||
}
|
||||
|
||||
/************************ Send **************************/
|
||||
if (SEND) {
|
||||
for (int i=1; i<NSEND && i<nsend; i++) {
|
||||
for (int i=1; i<MaxSend && i<fan.nsend(); i++) {
|
||||
uint64_t flag = sendFlag(i);
|
||||
uint64_t* ptr = sendPtr(i)+ll128Offset;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
//store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
|
||||
using Vec = uint64_t __attribute__((ext_vector_type(2)));
|
||||
Vec i2;
|
||||
i2[0] = v[u];
|
||||
i2[1] = flagThread ? flag : v[u+1];//
|
||||
asm volatile ("flat_store_dwordx4 %0, %1, glc, slc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : : "v"(ptr+u*WARP_SIZE), "v"(i2));
|
||||
store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
|
||||
}
|
||||
}
|
||||
uint64_t flag = sendFlag(0);
|
||||
uint64_t* ptr = sendPtr(0)+ll128Offset;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
//store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
|
||||
using Vec = uint64_t __attribute__((ext_vector_type(2)));
|
||||
Vec i2;
|
||||
i2[0] = v[u];
|
||||
i2[1] = flagThread ? flag : v[u+1];//
|
||||
asm volatile ("flat_store_dwordx4 %0, %1, glc, slc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : : "v"(ptr+u*WARP_SIZE), "v"(i2));
|
||||
store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
|
||||
}
|
||||
}
|
||||
/********************** End Send ************************/
|
||||
|
||||
/************* Data Storing : REG -> SHMEM **************/
|
||||
if (DST) {
|
||||
volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS;
|
||||
#pragma unroll
|
||||
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
|
||||
shmem64Ptr[u*(WARP_SIZE-2)] = v[u];
|
||||
if (!flagThread) shmem64Ptr[u*(WARP_SIZE-2)+1] = v[u+1];
|
||||
}
|
||||
}
|
||||
/*********** End data Storing : REG -> SHMEM ************/
|
||||
}
|
||||
|
||||
#define LL128INC (WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD)
|
||||
#define ELEMINC (LL128INC-(LL128INC/NCCL_LL128_LINEELEMS))
|
||||
static constexpr int WireWordPerSlice = WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD;
|
||||
static constexpr int DataEltPerSlice = (WireWordPerSlice - WireWordPerSlice/NCCL_LL128_LINEELEMS)*(sizeof(uint64_t)/sizeof(T));
|
||||
|
||||
template <int RECV, int SEND, int SRC, int DST>
|
||||
__device__ void GenericOp(const T* srcPtr, T* dstPtr, int nelem) {
|
||||
if (nelem <= 0) {
|
||||
// Don't move any data but still increase steps and sync with prev/next
|
||||
if (SEND) waitSend(0);
|
||||
FOR_SEND(incSend); if (SEND) postSend();
|
||||
FOR_RECV(incRecv); if (RECV) postRecv();
|
||||
return;
|
||||
}
|
||||
const int nelem64 = ((nelem*sizeof(T))/(2*sizeof(uint64_t)))*2;
|
||||
const uint64_t* src64Ptr = ((uint64_t*)srcPtr);
|
||||
uint64_t* dst64Ptr = ((uint64_t*)dstPtr);
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh");
|
||||
static_assert(-1<=DstBuf && DstBuf < 2, "Uhoh");
|
||||
static_assert(DstBuf!=Input, "Mistake?");
|
||||
#if 0
|
||||
assert((SrcBuf==-1) == (srcIx==-1));
|
||||
assert((DstBuf==-1) == (dstIx==-1));
|
||||
#endif
|
||||
|
||||
int ll128Offset = LL128INC*warp+2*wid;
|
||||
int elemOffset = ELEMINC*warp;
|
||||
T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
int wireOffset = WireWordPerSlice*warp + 2*wid;
|
||||
const int nwarps = nthreads/WARP_SIZE;
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
|
||||
if (SEND) waitSend(DIVUP(nelem*sizeof(T), ELEMINC*sizeof(uint64_t))*LL128INC*sizeof(uint64_t));
|
||||
if (SEND) waitSend(divUp(nelem, DataEltPerSlice)*WireWordPerSlice*sizeof(uint64_t));
|
||||
barrier();
|
||||
nelem -= DataEltPerSlice*warp;
|
||||
srcPtr += DataEltPerSlice*warp;
|
||||
dstPtr += DataEltPerSlice*warp;
|
||||
while (nelem > 0) {
|
||||
const int eltInSlice = min(nelem, DataEltPerSlice);
|
||||
uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
|
||||
if (SRC) loadRegsBegin(regs, srcPtr, eltInSlice);
|
||||
recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SrcBuf, DstBuf>(regs, wireOffset, postOp);
|
||||
if (DST) storeRegs(dstPtr, regs, eltInSlice);
|
||||
|
||||
while (elemOffset*(sizeof(uint64_t)/sizeof(T)) < nelem) {
|
||||
const int maxOffset128 = min(nelem64-elemOffset, (int)ELEMINC);
|
||||
const int maxOffset = min(nelem-(elemOffset*((int)(sizeof(uint64_t)/sizeof(T)))), (int)(ELEMINC*(sizeof(uint64_t)/sizeof(T))));
|
||||
if (SRC) {
|
||||
int done = 0;
|
||||
if ((((uint64_t)srcPtr)&0x3) == 0) {
|
||||
loadSrcToShmem128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, src64Ptr+elemOffset+2*wid);
|
||||
done = maxOffset128*(sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
loadSrcToShmem(done, maxOffset, (T*)(src64Ptr+elemOffset));
|
||||
}
|
||||
__syncwarp();
|
||||
recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SRC, DST>(ll128Offset);
|
||||
__syncwarp();
|
||||
if (DST) {
|
||||
int done = 0;
|
||||
if ((((uint64_t)dstPtr)&0x3) == 0) {
|
||||
storeShmemToDst128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, dst64Ptr+elemOffset+2*wid);
|
||||
done = maxOffset128*(sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
storeShmemToDst(done, maxOffset, (T*)(dst64Ptr+elemOffset));
|
||||
}
|
||||
__syncwarp();
|
||||
ll128Offset += LL128INC*nwarps;
|
||||
elemOffset += ELEMINC*nwarps;
|
||||
wireOffset += WireWordPerSlice*nwarps;
|
||||
srcPtr += DataEltPerSlice*nwarps;
|
||||
dstPtr += DataEltPerSlice*nwarps;
|
||||
nelem -= DataEltPerSlice*nwarps;
|
||||
}
|
||||
|
||||
barrier();
|
||||
FOR_SEND(incSend); if (SEND) postSend();
|
||||
FOR_RECV(incRecv); if (RECV) postRecv();
|
||||
if (SEND) for (int i=0; i < MaxSend; i++) sendStep[i] += 1;
|
||||
if (SEND) postSend();
|
||||
if (RECV) for (int i=0; i < MaxRecv; i++) recvStep[i] += 1;
|
||||
if (RECV) postRecv();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
|
||||
recvBuff[i] = (uint64_t*)LOAD(conn->buffs+NCCL_PROTO_LL128);
|
||||
recvStep[i] = LOAD(&conn->step);
|
||||
recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
|
||||
recvStep[i] = conn->step;
|
||||
if (wid == i) recvConn = conn;
|
||||
nrecv++;
|
||||
}
|
||||
__device__ __forceinline__ void loadRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
recvConnHeadPtr = LOAD(&recvConn->head);
|
||||
recvConnHead = LOAD(&recvConn->step);
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv()) {
|
||||
recvConnHeadPtr = recvConn->head;
|
||||
recvConnHead = recvConn->step;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
|
||||
sendBuff[i] = (uint64_t*)LOAD(conn->buffs+NCCL_PROTO_LL128);
|
||||
sendStep[i] = LOAD(&conn->step);
|
||||
sendBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
|
||||
sendStep[i] = conn->step;
|
||||
if (wid == i) sendConn = conn;
|
||||
nsend++;
|
||||
}
|
||||
__device__ __forceinline__ void loadSendSync() {
|
||||
if (tid < nsend) {
|
||||
sendConnHeadPtr = LOAD(&sendConn->head);
|
||||
sendConnHeadCache = LOAD(sendConnHeadPtr);
|
||||
sendConnHead = LOAD(&sendConn->step);
|
||||
sendConnFifoPtr = LOAD(&sendConn->sizesFifo);
|
||||
if (tid < fan.nsend()) {
|
||||
sendConnHeadPtr = sendConn->head;
|
||||
sendConnHeadCache = *sendConnHeadPtr;
|
||||
sendConnHead = sendConn->step;
|
||||
sendConnFifoPtr = sendConn->sizesFifo;
|
||||
}
|
||||
if (tid >= nthreads-WARP_SIZE && wid<nsend) {
|
||||
if (tid >= nthreads-WARP_SIZE && wid<fan.nsend()) {
|
||||
if (sendConn->sizesFifo) {
|
||||
sendConnTailPtr = LOAD(&sendConn->tail);
|
||||
sendConnTail = LOAD(&sendConn->step);
|
||||
sendConnTailPtr = sendConn->tail;
|
||||
sendConnTail = sendConn->step;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveRecvSync() {
|
||||
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
|
||||
STORE(&recvConn->step, recvConnHead);
|
||||
__threadfence_block();
|
||||
public:
|
||||
__device__ Primitives(
|
||||
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0, int connIndex=0
|
||||
):
|
||||
redOp(FuncTraits<RedOp>().make(ncclShmem->comm.nRanks)),
|
||||
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
|
||||
flagThread((tid%8)==7), group(group),
|
||||
stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
|
||||
|
||||
auto *channel = &ncclShmem->channel;
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) {
|
||||
loadRecvConn(&channel->devPeers[recvPeers[nrecv]].recv->conn, nrecv);
|
||||
nrecv++;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void saveSendSync() {
|
||||
if (tid < nsend) {
|
||||
STORE(&sendConn->step, sendConnHead);
|
||||
__threadfence_block();
|
||||
while (nsend < MaxSend && sendPeers[nsend] >= 0) {
|
||||
loadSendConn(&channel->devPeers[sendPeers[nsend]].send->conn, nsend);
|
||||
nsend++;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ __forceinline__
|
||||
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
|
||||
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem->data+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid), sync(ncclShmem->sync+warp) {
|
||||
// Make sure step is updated before we read it.
|
||||
barrier();
|
||||
|
||||
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv->conn, i);
|
||||
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send->conn, i);
|
||||
this->fan = Fan(nrecv, nsend);
|
||||
loadRecvSync();
|
||||
loadSendSync();
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
}
|
||||
|
||||
__device__ void send(const T* src, int nelem) {
|
||||
return GenericOp<0, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recv(T* dst, int nelem) {
|
||||
return GenericOp<1, 0, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceSend(const T* src, int nelem) {
|
||||
return GenericOp<1, 1, 1, 0>(src, NULL, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
|
||||
return GenericOp<1, 0, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void copySend(const T* src, T* dst, int nelem) {
|
||||
return GenericOp<0, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvCopySend(T* dst, int nelem) {
|
||||
return GenericOp<1, 1, 0, 1>(NULL, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
|
||||
return GenericOp<1, 1, 1, 1>(src, dst, nelem);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ ~ncclLL128Primitives() {
|
||||
__device__ ~Primitives() {
|
||||
// Save steps for the next operation
|
||||
saveRecvSync();
|
||||
saveSendSync();
|
||||
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv())
|
||||
recvConn->step = recvConnHead;
|
||||
if (tid < fan.nsend())
|
||||
sendConn->step = sendConnHead;
|
||||
// Ensure all steps written back
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
userBufs[Input] = (T*)inputBuf;
|
||||
userBufs[Output] = (T*)outputBuf;
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
userBufs[Input] += delta;
|
||||
userBufs[Output] += delta;
|
||||
}
|
||||
|
||||
__device__ void send(intptr_t inpIx, int eltN) {
|
||||
return GenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void sendFromOutput(intptr_t outIx, int eltN) {
|
||||
return GenericOp<0, 1, Output, -1>(outIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 0, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceSend(intptr_t inpIx, int eltN) {
|
||||
return GenericOp<1, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
__device__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 1, -1, Output>(-1, outIx, eltN, postOp);
|
||||
}
|
||||
__device__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
return GenericOp<1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,478 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct,
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll>
|
||||
class Primitives<
|
||||
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll>
|
||||
> {
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
static constexpr int RoleInput = 0x01,
|
||||
RoleOutput = 0x02,
|
||||
RoleWaitRecv = 0x04,
|
||||
RoleWaitSend = 0x08,
|
||||
RolePostSend = 0x10,
|
||||
RolePostRecv = 0x20,
|
||||
Aborted = 0x40,
|
||||
PtrsFifoEnabled = 0x80,
|
||||
SizesFifoEnabled = 0x100,
|
||||
DirectEnabled = 0x200,
|
||||
ThreadsSynced = 0x400;
|
||||
const int tid;
|
||||
int nthreads;
|
||||
int nworkers;
|
||||
const int stepSize;
|
||||
Fan fan;
|
||||
RedOp const redOp;
|
||||
int index; // Peer index I'm responsible for
|
||||
int flags;
|
||||
int group;
|
||||
uint64_t step;
|
||||
union {
|
||||
void **connPtrsFifoPtr; // (flags & PtrsFifoEnabled)
|
||||
T *userBuff; // (flags & (RoleInput|RoleOutput))
|
||||
T *connEltsFifo; // !(flags & (PtrsFifoEnabled|RoleInput|RoleOutput))
|
||||
};
|
||||
union {
|
||||
int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled)
|
||||
T *directBuff; // !(flags & SizesFifoEnabled)
|
||||
};
|
||||
uint64_t volatile *connStepPtr;
|
||||
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
|
||||
uint64_t* barriers;
|
||||
uint64_t* barrier_next;
|
||||
const int connIndex;
|
||||
|
||||
// Don't use barrier 0 as it's used by the final sync
|
||||
inline __device__ void barrier() {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if (nthreads == WARP_SIZE)
|
||||
__syncwarp();
|
||||
else
|
||||
barrier_by_group();
|
||||
#else
|
||||
if (nthreads == WARP_SIZE)
|
||||
__syncwarp();
|
||||
else
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads));
|
||||
#endif
|
||||
flags |= ThreadsSynced;
|
||||
}
|
||||
inline __device__ void subBarrier() {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
barrier();
|
||||
#else
|
||||
if (nworkers == nthreads)
|
||||
barrier();
|
||||
else
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ bool checkAbort(int &spins) {
|
||||
spins++;
|
||||
if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
|
||||
flags |= LOAD(ncclShmem->comm.abortFlag) ? Aborted : 0;
|
||||
spins = 0;
|
||||
}
|
||||
return flags & Aborted;
|
||||
}
|
||||
|
||||
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
|
||||
inline __device__ void waitPeer(intptr_t dstIx, intptr_t remoteOutIx, int offset, int nelts) {
|
||||
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
|
||||
bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
|
||||
int spins = 0;
|
||||
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
|
||||
connStepCache = LOAD(connStepPtr);
|
||||
if (checkAbort(spins)) break;
|
||||
//if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
|
||||
}
|
||||
|
||||
if (isSendNotRecv && (flags & SizesFifoEnabled))
|
||||
STORE(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T));
|
||||
|
||||
void **ptrs = isSendNotRecv ? (ncclShmem->groups[group].dsts + Dst)
|
||||
: (ncclShmem->groups[group].srcs + Src);
|
||||
if (flags & PtrsFifoEnabled)
|
||||
loadPtr(connPtrsFifoPtr + step%NCCL_STEPS, ptrs[index]);
|
||||
else if ((isSendNotRecv ? DirectSend : DirectRecv) && (flags & DirectEnabled))
|
||||
ptrs[index] = directBuff + (isSendNotRecv ? remoteOutIx : dstIx) + offset;
|
||||
else
|
||||
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
|
||||
step += StepPerSlice;
|
||||
}
|
||||
}
|
||||
|
||||
template<int Recv, int Send>
|
||||
inline __device__ void postPeer() {
|
||||
if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
|
||||
step += StepPerSlice;
|
||||
STORE(connStepPtr, step);
|
||||
}
|
||||
}
|
||||
|
||||
template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
|
||||
inline __device__ void genericOp(
|
||||
intptr_t srcIx, intptr_t dstIx, intptr_t remoteOutIx, int nelem, bool postOp
|
||||
) {
|
||||
constexpr int DirectRecv = 1 && Direct && DirectRecv1;
|
||||
constexpr int DirectSend = 1 && Direct && DirectSend1;
|
||||
constexpr int Src = SrcBuf != -1;
|
||||
constexpr int Dst = DstBuf != -1;
|
||||
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
int sliceSize = stepSize*StepPerSlice;
|
||||
sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32);
|
||||
int slice = 0;
|
||||
int offset = 0;
|
||||
|
||||
if (tid < nworkers && offset < nelem) {
|
||||
// Worker-only loop for non-empty slices. Non-workers and empty slices are
|
||||
// processed in the loop following this if block. The benefit of splitting
|
||||
// the loop like this is we pull two branches out of the critical path.
|
||||
// Using "number of branch insns (taken or not) encountered dynamically"
|
||||
// as the performance metric, then:
|
||||
// perf_orig = 2*numslices
|
||||
// perf_new = 2+numslices
|
||||
// So the new code and old code behave the same for numslices=2, and for
|
||||
// numslices>2 the new code is superior. And note that in the case
|
||||
// numslices=1, the loop is trivially unrollable (single iteration) so we
|
||||
// don't incur that that tail branch and we still have perf_new=2.
|
||||
//
|
||||
// ORIGINAL CODE:
|
||||
// unrolled for(slices) {
|
||||
// if(worker) { // This branch removed
|
||||
// wait();
|
||||
// subBarrier();
|
||||
// if(slice not empty) // This branch removed
|
||||
// ReduceCopyMulti();
|
||||
// }
|
||||
// barrier();
|
||||
// post();
|
||||
// } // Since we no longer unroll, new branch added here
|
||||
#pragma unroll 1
|
||||
do {
|
||||
sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
|
||||
if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput)))
|
||||
ncclShmem->groups[group].srcs[0] = userBuff + srcIx + offset;
|
||||
if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput)))
|
||||
ncclShmem->groups[group].dsts[0] = userBuff + dstIx + offset;
|
||||
#ifdef ENABLE_PROFILING
|
||||
uint64_t t0 = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(dstIx, remoteOutIx, offset, sliceSize);
|
||||
subBarrier();
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0) __atomic_fetch_add(&ncclShmem->comm.devProf->wait_cycle[blockIdx.x], __builtin_amdgcn_s_memrealtime() - t0, __ATOMIC_SEQ_CST);
|
||||
#endif
|
||||
if (DirectRecv && ncclShmem->groups[group].srcs[0] == ncclShmem->groups[group].dsts[0]) {
|
||||
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
|
||||
if (Send) {
|
||||
// (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0).
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, (1-Send)+MaxSend>
|
||||
(tid, nworkers, redOp, false, false,
|
||||
1, (T const**)ncclShmem->groups[group].srcs,
|
||||
fan.nsend(), (T**)ncclShmem->groups[group].dsts+1,
|
||||
sliceSize);
|
||||
}
|
||||
} else {
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, Recv+Src, Recv*MaxRecv+Src, Send+Dst, Send*MaxSend+Dst>
|
||||
(tid, nworkers, redOp, SrcBuf==Input, postOp,
|
||||
Recv*fan.nrecv()+Src, (T const**)ncclShmem->groups[group].srcs,
|
||||
Send*fan.nsend()+Dst, (T**)ncclShmem->groups[group].dsts,
|
||||
sliceSize);
|
||||
}
|
||||
barrier(); // This barrier has a counterpart in following loop
|
||||
if (Send && (flags & RolePostSend) && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += sliceSize;
|
||||
slice += 1;
|
||||
} while (slice < SlicePerChunk && offset < nelem);
|
||||
}
|
||||
|
||||
// Non-workers come straight here. Workers too but only once the remaining
|
||||
// slices are all empty. Since empty slices are the uncommon case, and
|
||||
// worker perf is the limiter, perf-wise this loop is effectively unentered,
|
||||
// hence just a single branch insn.
|
||||
#pragma unroll 1
|
||||
while (slice < SlicePerChunk) {
|
||||
sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
|
||||
{ // Only workers could have Wait roles so we know the slice must be empty
|
||||
// since we've exited the loop above.
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(0, 0, 0, 0);
|
||||
}
|
||||
barrier(); // Has couterpart in preceding worker-only loop.
|
||||
if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += sliceSize;
|
||||
slice += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter and gather do not support Direct
|
||||
template <int Recv, int Send>
|
||||
inline __device__ void
|
||||
ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) {
|
||||
int offset = 0; // slice offset
|
||||
int sliceSize = stepSize*StepPerSlice;
|
||||
int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size
|
||||
|
||||
#pragma unroll 1
|
||||
for (int slice=0; slice<SlicePerChunk; ++slice) {
|
||||
int realSize = max(0, min(dataSize, peerElem-offset));
|
||||
if (tid < nworkers) {
|
||||
if (Send && (flags & RoleInput)) ncclShmem->groups[group].srcs[0] = userBuff + inpIx + offset;
|
||||
if (Recv && (flags & RoleOutput)) ncclShmem->groups[group].dsts[0] = userBuff + outIx + offset;
|
||||
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
|
||||
waitPeer<0, 0, Recv, Send, 0, 0>(0, 0, 0, realSize);
|
||||
subBarrier();
|
||||
if (Send) {
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<fan.nsend(); j++) {
|
||||
int i = (j+shift)%fan.nsend();
|
||||
int peerOffset = i*peerElem;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
const T* src0 = (T*)ncclShmem->groups[group].srcs[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1>(tid, nworkers, redOp, true, false, 1, &src0, 1, (T**)ncclShmem->groups[group].dsts+i, realPeerSize);
|
||||
}
|
||||
} else if (Recv) {
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<fan.nrecv(); j++) {
|
||||
int i = (j+shift)%fan.nrecv();
|
||||
int peerOffset = i*peerElem;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
T* dst0 = (T*)ncclShmem->groups[group].dsts[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1>(tid, nworkers, redOp, false, postOp, 1, (T const**)ncclShmem->groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (Send && (flags & RolePostSend) && realSize > 0 && index == 0) __threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += realSize;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(ncclPeer *peer) {
|
||||
if (flags & (RoleWaitRecv|RolePostRecv)) {
|
||||
// For other colls: group <= 2, hence always use conn 0
|
||||
// For P2P: Direct is set to 1, hence always use conn 0
|
||||
// Ideally we should be accepting connIndex from the constructor!
|
||||
auto *conn = &peer->recv[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = roundUp(step, SlicePerChunk*StepPerSlice);
|
||||
if (flags & RolePostRecv) {
|
||||
connStepPtr = conn->head;
|
||||
STORE(connStepPtr, step); // Return credits in case we rounded up.
|
||||
}
|
||||
if (flags & RoleWaitRecv) {
|
||||
ncclShmem->groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs()
|
||||
connStepPtr = conn->tail;
|
||||
connStepCache = LOAD(connStepPtr);
|
||||
flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0;
|
||||
flags |= (Direct && (conn->direct & NCCL_DIRECT_GPU)) ? DirectEnabled : 0;
|
||||
if (flags & PtrsFifoEnabled)
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
else
|
||||
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadSendConn(ncclPeer *peer) {
|
||||
if (flags & (RoleWaitSend|RolePostSend)) {
|
||||
// For other colls: group <= 2, hence always use conn 0
|
||||
// For P2P: Direct is set to 1, hence always use conn 0
|
||||
// Ideally we should be accepting connIndex from the constructor!
|
||||
auto *conn = &peer->send[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = roundUp(step, SlicePerChunk*StepPerSlice);
|
||||
if (flags & RolePostSend) {
|
||||
connStepPtr = conn->tail;
|
||||
}
|
||||
if (flags & RoleWaitSend) {
|
||||
ncclShmem->groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs()
|
||||
connStepPtr = conn->head;
|
||||
connStepCache = LOAD(connStepPtr);
|
||||
flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0;
|
||||
if (flags & PtrsFifoEnabled)
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
else
|
||||
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
|
||||
|
||||
if (conn->sizesFifo != nullptr) {
|
||||
flags |= SizesFifoEnabled;
|
||||
connSizesFifoPtr = conn->sizesFifo;
|
||||
}
|
||||
else if (Direct && (conn->direct & NCCL_DIRECT_GPU))
|
||||
flags |= DirectEnabled;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
__device__ Primitives(
|
||||
int tid, int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0, int connIndex=0
|
||||
):
|
||||
tid(tid),
|
||||
stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)),
|
||||
redOp(FuncTraits<RedOp>::make(ncclShmem->comm.nRanks)),
|
||||
connIndex((NCCL_MAX_DIRECT_ARITY==Fan::MaxSend || NCCL_MAX_DIRECT_ARITY==Fan::MaxRecv)?(group/2):connIndex),
|
||||
barriers(&ncclShmem->groups[group].barrier), barrier_next(ncclShmem->groups[group].barrier_next) {
|
||||
|
||||
// For send operations, we need an extra warp to overlap the threadfence and the copy
|
||||
this->nthreads = nthreads;
|
||||
this->nworkers = nthreads;
|
||||
this->group = group;
|
||||
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
|
||||
while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
|
||||
this->fan = Fan(nrecv, nsend);
|
||||
|
||||
constexpr int ThreadPerSync = 8;
|
||||
static_assert(MaxSend < ThreadPerSync && MaxRecv < ThreadPerSync, "Not enough threads to cover all peers");
|
||||
|
||||
int g = tid / ThreadPerSync;
|
||||
int ng = nthreads / ThreadPerSync;
|
||||
index = tid % ThreadPerSync;
|
||||
flags = 0;
|
||||
if (g == 0) {
|
||||
if (index < nrecv) flags |= RoleWaitRecv;
|
||||
if (index == nrecv) flags |= RoleInput;
|
||||
} else if (g == 1) {
|
||||
if (index < nsend) flags |= RoleWaitSend;
|
||||
if (index == nsend) flags |= RoleOutput;
|
||||
} else if (g == ng - 2) {
|
||||
if (index < nrecv) flags |= RolePostRecv;
|
||||
} else if (g == ng - 1) {
|
||||
if (index < nsend) flags |= RolePostSend;
|
||||
}
|
||||
|
||||
int peer = 0;
|
||||
if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
|
||||
if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];
|
||||
|
||||
loadRecvConn(&ncclShmem->channel.devPeers[peer]);
|
||||
loadSendConn(&ncclShmem->channel.devPeers[peer]);
|
||||
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
}
|
||||
|
||||
__device__ ~Primitives() {
|
||||
// Ensure ncclShmem->groups[].send/recvConns are available
|
||||
if (!(flags & ThreadsSynced))
|
||||
barrier();
|
||||
// Save steps for the next operation
|
||||
if (flags & (RolePostSend|RolePostRecv)) {
|
||||
auto *conns = (flags & RolePostSend) ? ncclShmem->groups[group].sendConns : ncclShmem->groups[group].recvConns;
|
||||
STORE(&conns[index]->step, step);
|
||||
}
|
||||
// Make sure all threads are done writing back conn->step and done using
|
||||
// ncclShmem->groups[group]
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
if (flags & RoleInput) userBuff = (T*)inputBuf;
|
||||
if (flags & RoleOutput) userBuff = (T*)outputBuf;
|
||||
if (Direct && flags == (flags|RoleWaitRecv|DirectEnabled)) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem->groups[group].recvConns[index]->ptrExchange;
|
||||
// Wait for consumer to consume previous value before trampling it.
|
||||
while (LOAD(slot) != nullptr && !checkAbort(spins));
|
||||
directBuff = (T*)outputBuf;
|
||||
// Encode pointer by XOR'ing against some address they definitely wouldn't send
|
||||
// since we want to allow them sending us nullptr while not colliding with
|
||||
// the empty slot value.
|
||||
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(outputBuf) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
}
|
||||
if (Direct && flags == (flags|RoleWaitSend|DirectEnabled)) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem->groups[group].sendConns[index]->ptrExchange;
|
||||
void *ptr;
|
||||
while (true) {
|
||||
ptr = LOAD(slot);
|
||||
if (ptr != nullptr || checkAbort(spins)) break;
|
||||
}
|
||||
directBuff = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
*slot = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
if (flags & (RoleInput|RoleOutput))
|
||||
userBuff += delta;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void send(intptr_t inpIx, int eltN) {
|
||||
genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, -1, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) {
|
||||
genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, -1, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<0, 1, 0, 1, Input, -1>(inpIx, -1, remoteOutIx, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void directSendFromOutput(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<0, 1, 0, 1, Output, -1>(outIx, -1, remoteOutIx, eltN, false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) {
|
||||
genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, -1, eltN, /*postOp=*/false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, -1, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
// Direct is only for the send part
|
||||
genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) {
|
||||
ScatterGatherOp<1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp);
|
||||
}
|
||||
};
|
||||
@@ -6,148 +6,87 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * REDUCE_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ring->devUserRanks[0];
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem->channel.ring;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCE_CHUNKSTEPS : 1));
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
|
||||
const int nranks = ncclShmem->comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = ncclShmem->comm.rank;
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex);
|
||||
|
||||
ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex));
|
||||
auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int {
|
||||
int realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize);
|
||||
return realChunkSize;
|
||||
};
|
||||
|
||||
if (prevRank == root) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
int realChunkSize = calcChunkSize(gridOffset);
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
if (prevRank == root) {
|
||||
prims.send(thisInput+offset, nelem);
|
||||
} else if (rank == root) {
|
||||
prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
prims.send(offset, nelem);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = comm->rank;
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
else if (rank == root) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (prevRank == root) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else if (rank == root) {
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
int realChunkSize = calcChunkSize(gridOffset);
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
const int rank = comm->rank;
|
||||
const int prevRank = ring->devUserRanks[nranks-1];
|
||||
const int root = args->coll.root;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
else {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
ssize_t offset = gridOffset + bid*chunkSize;
|
||||
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
if (prevRank == root) {
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
} else if (rank == root) {
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
|
||||
} else {
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
int realChunkSize = calcChunkSize(gridOffset);
|
||||
ssize_t offset = gridOffset + bid*realChunkSize;
|
||||
int nelem = min(realChunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduce, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "common_kernel.h"
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
template<typename T>
|
||||
struct FuncNull {
|
||||
@@ -19,203 +20,6 @@ struct FuncNull {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
|
||||
//we really don't need any specializations and we don't need
|
||||
//to break things into uint32_t
|
||||
template<typename T>
|
||||
__device__ inline T ncclMinFunc(T x, T y) { return y < x ? y : x; }
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T ncclMaxFunc(T x, T y) { return y < x ? x : y; }
|
||||
|
||||
template<typename T>
|
||||
class FuncBase {
|
||||
protected:
|
||||
static constexpr auto n = sizeof(PackType) / sizeof(T);
|
||||
|
||||
union Cvt {
|
||||
using Vec = T __attribute__((ext_vector_type(n)));
|
||||
|
||||
PackType data;
|
||||
Vec vec;
|
||||
|
||||
static_assert(sizeof(Vec) == sizeof(data), "Vec must be the same size of data.");
|
||||
};
|
||||
};
|
||||
|
||||
template<>
|
||||
class FuncBase<half> {
|
||||
protected:
|
||||
static constexpr auto n = sizeof(PackType) / sizeof(_Float16);
|
||||
union Cvt {
|
||||
using Vec = _Float16 __attribute__((ext_vector_type(n)));
|
||||
|
||||
PackType data;
|
||||
Vec vec;
|
||||
|
||||
static_assert(sizeof(Vec) == sizeof(data), "Vec must be the same size of data.");
|
||||
};
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncSum : private FuncBase<T> {
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
using Cvt = typename FuncBase<T>::Cvt;
|
||||
|
||||
Cvt tmp_x{x};
|
||||
tmp_x.vec += Cvt{y}.vec;
|
||||
|
||||
return tmp_x.data;
|
||||
}
|
||||
template<typename U = T, typename std::enable_if<!std::is_same<T, U>{}>* = nullptr>
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncProd : private FuncBase<T> {
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
using Cvt = typename FuncBase<T>::Cvt;
|
||||
|
||||
Cvt tmp_x{x};
|
||||
tmp_x.vec *= Cvt{y}.vec;
|
||||
|
||||
return tmp_x.data;
|
||||
}
|
||||
template<typename U = T, typename std::enable_if<!std::is_same<T, U>{}>* = nullptr>
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncMax : private FuncBase<T> {
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
using Cvt = typename FuncBase<T>::Cvt;
|
||||
|
||||
Cvt tmp_x{x};
|
||||
Cvt tmp_y{y};
|
||||
|
||||
for (auto i = 0u; i != FuncBase<T>::n; ++i) {
|
||||
tmp_x.vec[i] = ncclMaxFunc(tmp_x.vec[i], tmp_y.vec[i]);
|
||||
}
|
||||
|
||||
return tmp_x.data;
|
||||
}
|
||||
template<typename U = T, typename std::enable_if<!std::is_same<T, U>{}>* = nullptr>
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return (x < y) ? y : x;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncMin : private FuncBase<T> {
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
using Cvt = typename FuncBase<T>::Cvt;
|
||||
|
||||
Cvt tmp_x{x};
|
||||
Cvt tmp_y{y};
|
||||
|
||||
for (auto i = 0u; i != FuncBase<T>::n; ++i) {
|
||||
tmp_x.vec[i] = ncclMinFunc(tmp_x.vec[i], tmp_y.vec[i]);
|
||||
}
|
||||
|
||||
return tmp_x.data;
|
||||
}
|
||||
template<typename U = T, typename std::enable_if<!std::is_same<T, U>{}>* = nullptr>
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return (x < y) ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncSum<rccl_bfloat16> {
|
||||
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
union converter { PackType storage; rccl_bfloat16 vec[n]; };
|
||||
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
for (auto i = 0u; i != n; ++i) {
|
||||
cr.vec[i] = cx.vec[i] + cy.vec[i];
|
||||
}
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncProd<rccl_bfloat16> {
|
||||
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
union converter { PackType storage; rccl_bfloat16 vec[n]; };
|
||||
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
for (auto i = 0u; i != n; ++i) {
|
||||
cr.vec[i] = cx.vec[i] * cy.vec[i];
|
||||
}
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncMax<rccl_bfloat16> {
|
||||
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
union converter { PackType storage; rccl_bfloat16 vec[n]; };
|
||||
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
for (auto i = 0u; i != n; ++i) {
|
||||
cr.vec[i] = cx.vec[i] < cy.vec[i] ? cy.vec[i] : cx.vec[i];
|
||||
}
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
return x < y ? y : x;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncMin<rccl_bfloat16> {
|
||||
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
|
||||
__device__ PackType operator()(PackType x, PackType y) const
|
||||
{
|
||||
union converter { PackType storage; rccl_bfloat16 vec[n]; };
|
||||
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
|
||||
converter cx, cy, cr;
|
||||
cx.storage = x;
|
||||
cy.storage = y;
|
||||
for (auto i = 0u; i != n; ++i) {
|
||||
cr.vec[i] = cx.vec[i] < cy.vec[i] ? cx.vec[i] : cy.vec[i];
|
||||
}
|
||||
return cr.storage;
|
||||
}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template<typename T>
|
||||
struct FuncSum {
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
@@ -244,17 +48,29 @@ struct FuncMin {
|
||||
}
|
||||
};
|
||||
|
||||
#define MASK0 0x00ff00ff
|
||||
#define MASK1 0xff00ff00
|
||||
template<typename Fn>
|
||||
struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
__device__ static Fn make(int rankN) { return Fn(); }
|
||||
template<typename T>
|
||||
__device__ static T preOp(Fn, T x) { return x; }
|
||||
template<typename T>
|
||||
__device__ static T postOp(Fn, T x) { return x; }
|
||||
};
|
||||
|
||||
#define NCCL_MASK0 0x00ff00ff
|
||||
#define NCCL_MASK1 0xff00ff00
|
||||
static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
|
||||
/* This can be used both for signed and unsigned 8-bit addition */
|
||||
const uint32_t x0 = x & MASK0;
|
||||
const uint32_t x1 = x & MASK1;
|
||||
const uint32_t y0 = y & MASK0;
|
||||
const uint32_t y1 = y & MASK1;
|
||||
const uint32_t x0 = x & NCCL_MASK0;
|
||||
const uint32_t x1 = x & NCCL_MASK1;
|
||||
const uint32_t y0 = y & NCCL_MASK0;
|
||||
const uint32_t y1 = y & NCCL_MASK1;
|
||||
const uint32_t r0 = (x0+y0);
|
||||
const uint32_t r1 = (x1+y1);
|
||||
return (r0 & MASK0) | (r1 & MASK1);
|
||||
return (r0 & NCCL_MASK0) | (r1 & NCCL_MASK1);
|
||||
}
|
||||
|
||||
template<>
|
||||
@@ -437,6 +253,19 @@ struct FuncSum<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncSum<rccl_bfloat16> {
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hadd(x, y);
|
||||
#else
|
||||
return x + y;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncProd<half> {
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
@@ -460,6 +289,19 @@ struct FuncProd<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncProd<rccl_bfloat16> {
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmul(x, y);
|
||||
#else
|
||||
return x * y;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMax<half> {
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
@@ -479,6 +321,19 @@ struct FuncMax<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncMax<rccl_bfloat16> {
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmax(x, y);
|
||||
#else
|
||||
return x < y ? y : x;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMin<half> {
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
@@ -498,6 +353,226 @@ struct FuncMin<half> {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncMin<rccl_bfloat16> {
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmin(x, y);
|
||||
#else
|
||||
return x < y ? x : y;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMax<float> {
|
||||
__device__ float operator()(float x, float y) const {
|
||||
return fmaxf(x, y);
|
||||
}
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<float> {
|
||||
__device__ float operator()(float x, float y) const {
|
||||
return fminf(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncMax<double> {
|
||||
__device__ double operator()(double x, double y) const {
|
||||
return fmax(x, y);
|
||||
}
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<double> {
|
||||
__device__ double operator()(double x, double y) const {
|
||||
return fmin(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncAvg: FuncSum<T> {
|
||||
static_assert(!std::is_floating_point<T>::value, "Uhoh");
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = false;
|
||||
int n;
|
||||
|
||||
template<typename ...Arg>
|
||||
__device__ FuncAvg(int n): n(n) {}
|
||||
|
||||
__device__ T preOp(T x) const {
|
||||
return x;
|
||||
}
|
||||
__device__ T postOp(T x) const {
|
||||
return T(x/n);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<double>: FuncSum<double> {
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
double rcp;
|
||||
__device__ FuncAvg(int n) {
|
||||
rcp = __drcp_rn(double(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ double preOp(double x) const {
|
||||
return IsPreOpIdentity ? x : x*rcp;
|
||||
}
|
||||
__device__ double postOp(double x) const {
|
||||
return IsPostOpIdentity ? x : x*rcp;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<float>: FuncSum<float> {
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
float rcp;
|
||||
__device__ FuncAvg(int n) {
|
||||
rcp = __frcp_rn(float(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ float preOp(float x) const {
|
||||
return IsPreOpIdentity ? x : x*rcp;
|
||||
}
|
||||
__device__ float postOp(float x) const {
|
||||
return IsPostOpIdentity ? x : x*rcp;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<half>: FuncSum<half> {
|
||||
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
|
||||
// Obviously, the only invalid combination is both true. An improvement would be
|
||||
// make this parameterized as a build time setting and passed here through
|
||||
// preprocessor definitions.
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
|
||||
half2 scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale.x = __float2half(__frsqrt_rn(float(n)));
|
||||
else
|
||||
scale.x = __float2half(__frcp_rn(float(n)));
|
||||
scale.y = scale.x;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ half preOp(half x) const {
|
||||
return IsPreOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ half2 preOp(half2 x) const {
|
||||
return IsPreOpIdentity ? x : __hmul2(x, scale);
|
||||
}
|
||||
__device__ half postOp(half x) const {
|
||||
return IsPostOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ half2 postOp(half2 x) const {
|
||||
return IsPostOpIdentity ? x : __hmul2(x, scale);
|
||||
}
|
||||
#else
|
||||
float scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale = __frsqrt_rn(float(n));
|
||||
else
|
||||
scale = __frcp_rn(float(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ half preOp(half x) const {
|
||||
return IsPreOpIdentity ? x : __float2half(__half2float(x)*scale);
|
||||
}
|
||||
__device__ half2 preOp(half2 x) const {
|
||||
if (IsPreOpIdentity)
|
||||
return x;
|
||||
else {
|
||||
float2 a = __half22float2(x);
|
||||
a.x *= scale;
|
||||
a.y *= scale;
|
||||
return __float22half2_rn(a);
|
||||
}
|
||||
}
|
||||
__device__ half postOp(half x) const {
|
||||
return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale);
|
||||
}
|
||||
__device__ half2 postOp(half2 x) const {
|
||||
if (IsPostOpIdentity)
|
||||
return x;
|
||||
else {
|
||||
float2 a = __half22float2(x);
|
||||
a.x *= scale;
|
||||
a.y *= scale;
|
||||
return __float22half2_rn(a);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncAvg<rccl_bfloat16>: FuncSum<rccl_bfloat16> {
|
||||
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
|
||||
// Obviously, the only invalid combination is both true. An improvement would be
|
||||
// make this parameterized as a build time setting and passed here through
|
||||
// preprocessor definitions.
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale.x = __float2bfloat16(__frsqrt_rn(float(n)));
|
||||
else
|
||||
scale.x = __float2bfloat16(__frcp_rn(float(n)));
|
||||
scale.y = scale.x;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const {
|
||||
return IsPreOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const {
|
||||
return IsPostOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
#else
|
||||
float scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale = __frsqrt_rn(float(n));
|
||||
else
|
||||
scale = __frcp_rn(float(n));
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const {
|
||||
return IsPreOpIdentity ? x : (rccl_bfloat16)(x*scale);
|
||||
}
|
||||
__device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const {
|
||||
return IsPostOpIdentity ? x : (rccl_bfloat16)(x*scale);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
struct FuncTraits<FuncAvg<T>> {
|
||||
static constexpr bool IsPreOpIdentity = FuncAvg<T>::IsPreOpIdentity;
|
||||
static constexpr bool IsPostOpIdentity = FuncAvg<T>::IsPostOpIdentity;
|
||||
|
||||
__device__ static FuncAvg<T> make(int rankN) {
|
||||
return FuncAvg<T>(rankN);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U preOp(FuncAvg<T> fn, U x) {
|
||||
return fn.preOp(x);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U postOp(FuncAvg<T> fn, U x) {
|
||||
return fn.postOp(x);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // REDUCE_KERNEL_H_
|
||||
|
||||
@@ -6,192 +6,85 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
|
||||
const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
ncclRing *ring = &ncclShmem->channel.ring;
|
||||
int const *ringRanks = ring->devUserRanks;
|
||||
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCESCATTER_CHUNKSTEPS : 1));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2);
|
||||
const int nranks = ncclShmem->comm.nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex);
|
||||
|
||||
ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex));
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
|
||||
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
ssize_t chunkOffset = gridOffset + bid*realChunkSize;
|
||||
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[nranks-1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final result
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
prims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
||||
realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
|
||||
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
||||
}
|
||||
}
|
||||
};
|
||||
else if (Proto::Id == NCCL_PROTO_LL)
|
||||
realChunkSize = size-gridOffset < loopSize ? args->coll.lastChunkSize : chunkSize;
|
||||
else if (Proto::Id == NCCL_PROTO_LL128)
|
||||
realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize);
|
||||
realChunkSize = int(realChunkSize);
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
ssize_t chunkOffset = gridOffset + bid*int(realChunkSize);
|
||||
|
||||
ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(realChunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ringRanks[nranks-1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
prims.send(offset, nelem);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
if (size-gridOffset < loopSize) {
|
||||
chunkSize = args->coll.lastChunkSize;
|
||||
}
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[nranks-1];
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ringRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final result
|
||||
rankDest = ringRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
prims.recvReduceCopy(offset, chunkOffset, nelem, /*postOp=*/true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
};
|
||||
|
||||
#include "prims_ll128.h"
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
const int nChannels = args->coll.nChannels;
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
struct ncclRing* ring = &channel->ring;
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
|
||||
ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
|
||||
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
||||
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
|
||||
const int nranks = comm->nRanks;
|
||||
const ssize_t loopSize = nChannels*chunkSize;
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
|
||||
|
||||
// Compute pointers
|
||||
const T * __restrict__ thisInput = (const T*)args->sendbuff;
|
||||
T * __restrict__ thisOutput = (T*)args->recvbuff;
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
|
||||
|
||||
ssize_t chunkOffset = gridOffset + bid*chunkSize;
|
||||
|
||||
/////////////// begin ReduceScatter steps ///////////////
|
||||
ssize_t offset;
|
||||
int nelem = min(chunkSize, size-chunkOffset);
|
||||
int rankDest;
|
||||
|
||||
// step 0: push data to next GPU
|
||||
rankDest = ring->devUserRanks[nranks-1];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.send(thisInput+offset, nelem);
|
||||
|
||||
// k-2 steps: reduce and copy to next GPU
|
||||
for (int j=2; j<nranks; ++j) {
|
||||
rankDest = ring->devUserRanks[nranks-j];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceSend(thisInput+offset, nelem);
|
||||
}
|
||||
|
||||
// step k-1: reduce this buffer and data, which will produce the final
|
||||
// result that we store in this data
|
||||
rankDest = ring->devUserRanks[0];
|
||||
offset = chunkOffset + rankDest * size;
|
||||
|
||||
LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
|
||||
}
|
||||
}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
};
|
||||
|
||||
template<int PROTO, class REDOP, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -6,89 +6,87 @@
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "primitives.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
template<class FUNC, typename T, int UNROLL>
|
||||
class ncclFunction<ncclFuncSendRecv, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
|
||||
public:
|
||||
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* firstArgs) {
|
||||
struct ncclWorkElem* args = firstArgs;
|
||||
int tid = threadIdx.x;
|
||||
int group = 0;
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
|
||||
int nThreadsSegment = args->p2p.nThreads;
|
||||
if (nThreadsSegment == 0) return; // Nothing else to do
|
||||
int groupRecv = group;
|
||||
group += 1;
|
||||
int groupSend = group;
|
||||
group += 1;
|
||||
if (tid < nThreadsSegment) {
|
||||
const int nThreads = nThreadsSegment;
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void run(ncclWork *work) {
|
||||
int tid = threadIdx.x;
|
||||
int group = 0;
|
||||
const int rank = ncclShmem->comm.rank;
|
||||
const int nRanks = ncclShmem->comm.nRanks;
|
||||
using Proto = ProtoSimple<1, 1>;
|
||||
|
||||
// Compute pointers
|
||||
const T* sendbuff = (const T*)args->sendbuff;
|
||||
T* recvbuff = (T*)args->recvbuff;
|
||||
const ssize_t sendCount = args->p2p.sendCount;
|
||||
const ssize_t recvCount = args->p2p.recvCount;
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
|
||||
ncclWorkElem *args = &work->elems[s];
|
||||
int nThreadsSegment = args->p2p.nThreads;
|
||||
if (args->active == 0 || nThreadsSegment == 0) break;
|
||||
|
||||
const int delta = args->p2p.delta;
|
||||
if (delta == 0) {
|
||||
if (tid < nThreads && sendbuff != recvbuff) {
|
||||
// local copy : ReduceOrCopyMulti takes an int as number of elements,
|
||||
// so we split it in blocks of 1G elements.
|
||||
int blockSize = 1<<30;
|
||||
for (size_t offset=0; offset<sendCount; offset += blockSize) {
|
||||
size_t remaining = sendCount - offset;
|
||||
if (remaining < blockSize) blockSize = remaining;
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nThreads, 1, &sendbuff, 1, &recvbuff, blockSize);
|
||||
sendbuff += blockSize; recvbuff += blockSize;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
struct ncclDevComm* comm = args->comm;
|
||||
struct ncclChannel* channel = comm->channels+blockIdx.x;
|
||||
int nThreadsSplit = nThreadsSegment/2;
|
||||
int groupRecv = group;
|
||||
group += Proto::calcGroupWidth(/*send=*/false, nThreadsSplit);
|
||||
int groupSend = group;
|
||||
group += Proto::calcGroupWidth(/*send=*/true, nThreadsSegment - nThreadsSplit);
|
||||
|
||||
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(sizeof(T)*NCCL_STEPS);
|
||||
if (tid < nThreadsSegment) {
|
||||
// Compute pointers
|
||||
T const* sendbuff = (const T*)args->sendbuff;
|
||||
T* recvbuff = (T*)args->recvbuff;
|
||||
ssize_t const sendCount = args->p2p.sendCount;
|
||||
ssize_t const recvCount = args->p2p.recvCount;
|
||||
int const delta = args->p2p.delta;
|
||||
|
||||
int nThreadsSplit = nThreads/2;
|
||||
if ((tid < nThreadsSplit) && recvCount >= 0) {
|
||||
const int chunkSize = args->p2p.recvChunkSize/sizeof(T);
|
||||
int peer = (comm->rank-delta+comm->nRanks)%comm->nRanks;
|
||||
int nt = nThreadsSplit;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 1, 0, 1, FUNC>
|
||||
prims(tid, nt, &peer, NULL, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(groupRecv, args->p2p.recvIdx));
|
||||
|
||||
if (recvCount == 0) {
|
||||
prims.recv(recvbuff, 0);
|
||||
} else for (ssize_t offset = 0; offset < recvCount; offset += chunkSize) {
|
||||
int realChunkSize = min(chunkSize, recvCount-offset);
|
||||
ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T));
|
||||
int nelem = min(realChunkSize, recvCount-offset);
|
||||
prims.directRecv(recvbuff+offset, offset, nelem);
|
||||
}
|
||||
}
|
||||
if ((tid >= nThreadsSplit) && sendCount >= 0) {
|
||||
const int chunkSize = args->p2p.sendChunkSize/sizeof(T);
|
||||
int peer = (comm->rank+delta)%comm->nRanks;
|
||||
int nt = nThreads-nThreadsSplit;
|
||||
ncclPrimitives<UNROLL, 1, 1, T, 0, 1, 1, FUNC>
|
||||
prims(tid-nThreadsSplit, nt, NULL, &peer, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(groupSend, args->p2p.sendIdx));
|
||||
|
||||
if (sendCount == 0) {
|
||||
prims.send(sendbuff, 0);
|
||||
} else for (ssize_t offset = 0; offset < sendCount; offset += chunkSize) {
|
||||
int realChunkSize = min(chunkSize, sendCount-offset);
|
||||
ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T));
|
||||
int nelem = min(realChunkSize, sendCount-offset);
|
||||
prims.directSend(sendbuff+offset, offset, nelem);
|
||||
}
|
||||
if (delta == 0) {
|
||||
if (sendbuff != recvbuff) {
|
||||
// local copy : ReduceOrCopyMulti takes an int as number of elements,
|
||||
// so we split it in blocks of 1G elements.
|
||||
int blockSize = 1<<30;
|
||||
for (size_t offset=0; offset<sendCount; offset += blockSize) {
|
||||
size_t remaining = sendCount - offset;
|
||||
if (remaining < blockSize) blockSize = remaining;
|
||||
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1>(tid, nThreadsSegment, RedOp(), false, false, 1, &sendbuff, 1, &recvbuff, blockSize);
|
||||
sendbuff += blockSize;
|
||||
recvbuff += blockSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
tid -= nThreadsSegment;
|
||||
if (tid < 0) return;
|
||||
args++;
|
||||
else {
|
||||
if ((tid < nThreadsSplit) && recvCount >= 0) {
|
||||
int const peer = (rank - delta + nRanks)%nRanks;
|
||||
int const t0 = 0;
|
||||
int const nt = nThreadsSplit;
|
||||
int const chunkSize = args->p2p.recvChunkSize/sizeof(T);
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, 0, Proto> prims
|
||||
(tid-t0, nt, &peer, nullptr, nullptr, recvbuff, groupRecv, args->p2p.recvIdx);
|
||||
ssize_t offset = 0;
|
||||
do {
|
||||
int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T)));
|
||||
nelem = min(chunkSize, recvCount-offset);
|
||||
prims.directRecv(offset, nelem);
|
||||
offset += nelem;
|
||||
} while(offset < recvCount);
|
||||
}
|
||||
|
||||
if ((tid >= nThreadsSplit) && sendCount >= 0) {
|
||||
int const peer = (rank + delta)%nRanks;
|
||||
int const t0 = nThreadsSplit;
|
||||
int const nt = nThreadsSegment - nThreadsSplit;
|
||||
int const chunkSize = args->p2p.sendChunkSize/sizeof(T);
|
||||
Primitives<T, RedOp, FanAsymmetric<0, 1>, 0, Proto> prims
|
||||
(tid-t0, nt, nullptr, &peer, sendbuff, nullptr, groupSend, args->p2p.sendIdx);
|
||||
ssize_t offset = 0;
|
||||
do {
|
||||
int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T)));
|
||||
nelem = min(chunkSize, sendCount-offset);
|
||||
prims.directSend(offset, offset, nelem);
|
||||
offset += nelem;
|
||||
} while(offset < sendCount);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
tid -= nThreadsSegment;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -24,6 +24,31 @@
|
||||
NCCL_FUNC5(func, RING, redop, type), \
|
||||
NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, uint8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int32_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, uint32_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int64_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, uint64_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, half), \
|
||||
(void*)NCCL_FUNC4(func, redop, float), \
|
||||
(void*)NCCL_FUNC4(func, redop, double), \
|
||||
(void*)NCCL_FUNC4(func, redop, __nv_bfloat16)
|
||||
#define NCCL_FUNCS3B(func, redop) \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t)
|
||||
#else
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
@@ -47,16 +72,19 @@
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t)
|
||||
#endif
|
||||
|
||||
// Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum)
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
@@ -112,11 +140,11 @@ static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** wor
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
struct ncclWorkElem* e = w->elems;
|
||||
volatile uint8_t* activePtr = (volatile uint8_t*)&e->active;
|
||||
while (LOAD(activePtr) != 0) sched_yield();
|
||||
while (activePtr[0] != 0) sched_yield();
|
||||
memset(w, 0, sizeof(struct ncclWork));
|
||||
// Initialize with work elem if provided
|
||||
if (base) memcpy(e, base, sizeof(struct ncclWorkElem));
|
||||
STORE(&e->active, 1);
|
||||
e->active = 1;
|
||||
e->index = opIndex;
|
||||
channel->workFifoTail++;
|
||||
channel->workCount++;
|
||||
@@ -151,19 +179,18 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph
|
||||
e->funcIndex = FUNC_INDEX_P2P;
|
||||
e->p2p.nThreads = 0;
|
||||
}
|
||||
STORE(&channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].elems[0].active, 2);
|
||||
channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].elems[0].active = 2;
|
||||
|
||||
if (c == 0) {
|
||||
// Find the first operation, choose the kernel accordingly and pass it as the first argument.
|
||||
// Note that changing cuda launch argument after capture is not supported by cudaGraph
|
||||
// As we inline the first coll directly, we can free it immediately.
|
||||
// Except P2P or aggregation cases
|
||||
struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS);
|
||||
struct ncclWorkElem* elem = work->elems;
|
||||
if (!usingCudaGraph) {
|
||||
params->func = (void *)ncclKerns[0];
|
||||
memcpy(&comm->args, elem, sizeof(struct ncclWorkElem));
|
||||
}
|
||||
// As we inline the first coll directly, we can free it immediately.
|
||||
if (elem->funcIndex != FUNC_INDEX_P2P) elem->active = 0;
|
||||
if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1) elem->active = 0;
|
||||
}
|
||||
|
||||
if (channel->gdrMemDesc) {
|
||||
@@ -186,7 +213,7 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph
|
||||
|
||||
ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast) {
|
||||
volatile int* ptr = (volatile int*)(comm->intraBarrier+comm->intraPhase);
|
||||
int val = LOAD(ptr);
|
||||
int val = *ptr;
|
||||
bool done = false;
|
||||
while (done == false) {
|
||||
if (val >= comm->intraRanks) {
|
||||
@@ -208,7 +235,7 @@ ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast) {
|
||||
|
||||
ncclResult_t ncclCpuBarrierLast(struct ncclComm* comm) {
|
||||
volatile int* ptr = (volatile int*)(comm->intraBarrier+comm->intraPhase);
|
||||
int val = LOAD(ptr);
|
||||
int val = *ptr;
|
||||
if (__sync_bool_compare_and_swap(ptr, val, val+1) != true) {
|
||||
WARN("Trying to launch too many work elements, max is %d", NCCL_MAX_OPS);
|
||||
return ncclInternalError;
|
||||
@@ -218,7 +245,7 @@ ncclResult_t ncclCpuBarrierLast(struct ncclComm* comm) {
|
||||
|
||||
ncclResult_t ncclCpuBarrierOut(struct ncclComm* comm) {
|
||||
volatile int* ptr = (volatile int*)(comm->intraBarrier+comm->intraPhase);
|
||||
while (LOAD(ptr) < comm->intraRanks) pthread_yield();
|
||||
while (*ptr < comm->intraRanks) pthread_yield();
|
||||
comm->intraPhase ^= 1;
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -292,6 +319,7 @@ static ncclResult_t ncclLaunchProxy(struct ncclQueueInfo* eqInfo) {
|
||||
for (int r=0; r<eqInfo->maxChannels; r++) {
|
||||
struct ncclChannel* channel = comm->channels+r;
|
||||
channel->workCount = 0;
|
||||
channel->totalSize = 0;
|
||||
}
|
||||
comm->lastChannel = 0;
|
||||
NCCLCHECK(ncclProxyStart(comm));
|
||||
@@ -323,8 +351,7 @@ ncclResult_t ncclLaunchReset(ncclComm_t comm) {
|
||||
// But we need to keep the current enqueue info for CUDA graph
|
||||
// Thus we need to creating a new enqueue info for the next run
|
||||
if (comm->usingCudaGraph) {
|
||||
NCCLCHECK(ncclCalloc(&comm->enqueueInfo, 1));
|
||||
comm->enqueueInfo->comm = comm;
|
||||
NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm));
|
||||
} else {
|
||||
// If not in CUDA graph mode, we reuse the same info space
|
||||
NCCLCHECK(ncclResetQueueInfo(comm->enqueueInfo));
|
||||
@@ -346,23 +373,29 @@ ncclResult_t ncclLaunchReset(ncclComm_t comm) {
|
||||
/*****************************************************************************/
|
||||
RCCL_PARAM(SharpThreshold, "SHARP_THRESHOLD", 16384);
|
||||
|
||||
static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
|
||||
static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetTypeSupport) {
|
||||
if (info->comm->collNetSupport > 0 && info->nBytes < rcclParamSharpThreshold()) {
|
||||
ncclRedOp_t netOp = info->op == ncclAvg ? ncclSum : info->op;
|
||||
NCCLCHECK(collNetReduceSupport(info->datatype, netOp, collNetTypeSupport));
|
||||
} else {
|
||||
*collNetTypeSupport = 0;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) {
|
||||
struct ncclComm* comm = info->comm;
|
||||
float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete.
|
||||
// Find algorithm / protocol.
|
||||
info->algorithm = -1;
|
||||
info->protocol = -1;
|
||||
if (comm->nRanks == 1) return ncclSuccess;
|
||||
int nAlgos = NCCL_NUM_ALGORITHMS;
|
||||
|
||||
// Check collNet support
|
||||
int collNetTypeSupport = 0;
|
||||
if (info->comm->collNetSupport > 0 && info->nBytes < rcclParamSharpThreshold())
|
||||
NCCLCHECK(collNetReduceSupport(info->datatype, info->op, &collNetTypeSupport));
|
||||
for (int a=0; a<nAlgos; a++) {
|
||||
if (a == NCCL_ALGO_COLLNET && collNetTypeSupport != 1) continue;
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
float time;
|
||||
NCCLCHECK(ncclTopoGetAlgoTime(info, a, p, &time));
|
||||
NCCLCHECK(ncclTopoGetAlgoTime(info, a, p, numPipeOps, &time));
|
||||
if (time >= 0 && time < minTime) {
|
||||
info->algorithm = a;
|
||||
info->protocol = p;
|
||||
@@ -405,7 +438,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
|
||||
#else
|
||||
if (info->protocol == NCCL_PROTO_SIMPLE) {
|
||||
nt += WARP_SIZE; // Extra warp for sync
|
||||
if (info->algorithm == NCCL_ALGO_TREE) nt += WARP_SIZE;
|
||||
if (info->algorithm == NCCL_ALGO_TREE) nt += 3*WARP_SIZE;
|
||||
if (info->algorithm == NCCL_ALGO_COLLNET) nt += 3*WARP_SIZE;
|
||||
}
|
||||
#endif
|
||||
@@ -458,8 +491,14 @@ RCCL_PARAM(IntraNetThreshold, "RCCL_INTRANET_THRESHOLD", 8388608);
|
||||
static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWorkElem* work, struct ncclProxyArgs* proxyArgs /* output */) {
|
||||
work->comm = info->comm->devComm;
|
||||
|
||||
int collNetTypeSupport = 0;
|
||||
// Check whether algo and proto have been preset
|
||||
if (info->nChannels > 0 && info->nThreads > 0) goto comp_next;
|
||||
NCCLCHECK(getCollNetSupport(info, &collNetTypeSupport));
|
||||
NCCLCHECK(getAlgoInfo(info, collNetTypeSupport, 1));
|
||||
|
||||
comp_next:
|
||||
// Set nstepsPerLoop and nchunksPerLoop
|
||||
NCCLCHECK(getAlgoInfo(info));
|
||||
NCCLCHECK(getPatternInfo(info));
|
||||
NCCLCHECK(getLoopInfo(info));
|
||||
|
||||
@@ -523,10 +562,9 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
|
||||
work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
|
||||
} else if (info->algorithm == NCCL_ALGO_COLLNET && info->protocol == NCCL_PROTO_SIMPLE) {
|
||||
// Optimize chunkSize / nSteps
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*32 && chunkSize > 262144) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*16 && chunkSize > 131072) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*64 && chunkSize > 131072) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 65536) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 32768) chunkSize /= 2;
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth/2 && chunkSize > 16384) chunkSize /= 2;
|
||||
// Use lastChunkSize as chunkSize
|
||||
work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
|
||||
} else if (info->protocol == NCCL_PROTO_LL) {
|
||||
@@ -557,7 +595,9 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
|
||||
proxyArgs->chunkSize = chunkSize;
|
||||
proxyArgs->protocol = info->protocol;
|
||||
proxyArgs->dtype = info->datatype;
|
||||
proxyArgs->redOp = (info->algorithm == NCCL_ALGO_COLLNET) ? info->op : ncclNumOps; // Only set redOp when using CollNet
|
||||
proxyArgs->redOp = info->algorithm != NCCL_ALGO_COLLNET ? ncclNumOps : // Only set redOp when using CollNet
|
||||
info->op == ncclAvg ? ncclSum : // Network sees avg as sum
|
||||
info->op;
|
||||
proxyArgs->pattern = info->pattern;
|
||||
proxyArgs->root = info->root;
|
||||
// This is used by P2P to reduce the receive buffer size. We don't use it in collectives
|
||||
@@ -595,7 +635,7 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
|
||||
|
||||
// Compute cuda kernel arg and proxy arg templates
|
||||
struct ncclQueueElem* eqElem;
|
||||
NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem));
|
||||
NCCLCHECK(comm->enqueueInfo->elemList->getNewElem(&eqElem));
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
eqElem->proxyArgs.nsubs = 1;
|
||||
NCCLCHECK(computeColl(info, work, &eqElem->proxyArgs));
|
||||
@@ -618,6 +658,29 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static inline int findShortestChannel(ncclComm_t comm) {
|
||||
size_t minSize = SIZE_MAX;
|
||||
int minC = 0;
|
||||
for (int c=0; c<comm->nChannels; c++) {
|
||||
struct ncclChannel* channel = comm->channels+c;
|
||||
if (channel->totalSize < minSize) {
|
||||
minSize = channel->totalSize;
|
||||
minC = c;
|
||||
}
|
||||
}
|
||||
return minC;
|
||||
}
|
||||
|
||||
static inline ncclResult_t getNextChannel(ncclComm_t comm, int* nextChannel) {
|
||||
if (comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) {
|
||||
*nextChannel = findShortestChannel(comm);
|
||||
} else {
|
||||
*nextChannel = comm->lastChannel % comm->nChannels;
|
||||
comm->lastChannel++;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Dynamic enqueue code
|
||||
static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem* eqElem) {
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
@@ -645,9 +708,6 @@ static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem*
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
|
||||
#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
|
||||
|
||||
ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
if (comm->asyncOpCount == 0) {
|
||||
return ncclSuccess;
|
||||
@@ -658,19 +718,47 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
NCCLCHECK(ncclSetupCollKernel(info));
|
||||
} else {
|
||||
// Aggregation
|
||||
size_t channelSize = NCCL_AGG_CHANNEL_SIZE * comm->nRanks; // scale channel size based on nranks as latency increases
|
||||
size_t channelSize;
|
||||
if (comm->channelSize > 0) {
|
||||
channelSize = comm->channelSize;
|
||||
} else if (comm->collNetSupport && comm->asyncOps[0].coll == ncclFuncAllReduce) {
|
||||
channelSize = 256 * 1024;
|
||||
} else {
|
||||
channelSize = NCCL_AGG_CHANNEL_SIZE * std::min(16, comm->nRanks); // scale channel size based on nranks as latency increases
|
||||
}
|
||||
// Reduce the per-channel size if we cannot fully utilize the channels
|
||||
while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
|
||||
int channelUsed = 0;
|
||||
ncclFunc_t commonColl = ncclNumFuncs;
|
||||
int fastPath = 1;
|
||||
int allCollNetSupport = comm->collNetSupport;
|
||||
for (int c = 0; c < comm->asyncOpCount; c++) {
|
||||
struct ncclInfo* info = comm->asyncOps+c;
|
||||
info->nChannels = std::min((int)DIVUP(info->nBytes, channelSize), comm->nChannels); // assign number of channels
|
||||
info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels
|
||||
channelUsed += info->nChannels;
|
||||
// We can use fast path if all collectives are the same
|
||||
if (commonColl == ncclNumFuncs) commonColl = info->coll;
|
||||
else if (commonColl != info->coll) fastPath = 0;
|
||||
else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
|
||||
}
|
||||
// Compute algo, proto, nthreads for the entire kernel
|
||||
struct ncclInfo total;
|
||||
total.comm = comm;
|
||||
total.coll = commonColl;
|
||||
total.nBytes = comm->asyncTotalSize;
|
||||
total.nChannels = std::min(channelUsed, comm->nChannels);
|
||||
int perChannelOps = DIVUP(channelUsed, total.nChannels);
|
||||
if (fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps));
|
||||
for (int c = 0; c < comm->asyncOpCount; c++) {
|
||||
struct ncclInfo* info = comm->asyncOps+c;
|
||||
if (fastPath) {
|
||||
info->algorithm = total.algorithm;
|
||||
info->protocol = total.protocol;
|
||||
info->nThreads = total.nThreads;
|
||||
}
|
||||
NCCLCHECK(ncclSetupCollKernel(info));
|
||||
}
|
||||
// If we wrap around on channels, then the inlined op on channel 0 is not the last one on this channel
|
||||
// Then we need to change active from 2 to 1
|
||||
if (channelUsed > comm->nChannels) comm->args.active = 1;
|
||||
comm->args.active = 0; // disable inline argument
|
||||
}
|
||||
// Reset counters
|
||||
comm->asyncOpCount = 0;
|
||||
@@ -711,7 +799,7 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
|
||||
}
|
||||
}
|
||||
}
|
||||
NCCLCHECK(enqueueP2pInfo(comm->p2pSends+info->root, (void*)info->sendbuff, nBytes));
|
||||
NCCLCHECK(ncclSaveP2pInfo(comm->p2pSends[info->root], (void*)info->sendbuff, nBytes));
|
||||
comm->p2pSendCount++;
|
||||
} else {
|
||||
if (peer != comm->rank) {
|
||||
@@ -728,15 +816,22 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
|
||||
}
|
||||
}
|
||||
}
|
||||
NCCLCHECK(enqueueP2pInfo(comm->p2pRecvs+info->root, info->recvbuff, nBytes));
|
||||
NCCLCHECK(ncclSaveP2pInfo(comm->p2pRecvs[info->root], info->recvbuff, nBytes));
|
||||
comm->p2pRecvCount++;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static int getSegment(int delta, struct ncclWork* work) {
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != delta; s++) {
|
||||
if (work->elems[s].p2p.nThreads == 0) return s;
|
||||
enum { COLL_SEGMENT=0, P2P_SEGMENT=1 };
|
||||
static int getSegment(int type, int delta, struct ncclWork* work) {
|
||||
if (type == P2P_SEGMENT) { // P2P
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != delta; s++) {
|
||||
if (work->elems[s].active == 0) return s;
|
||||
}
|
||||
} else { // aggregation
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
|
||||
if (work->elems[s].active == 0) return s;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -756,16 +851,19 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t enqueueP2pOp(struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) {
|
||||
static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) {
|
||||
// Copy element into corresponding segment of ncclWork
|
||||
memcpy(work->elems+s, elem, sizeof(struct ncclWorkElem));
|
||||
work->elems[s].active = 1;
|
||||
|
||||
// Determine nThreads at dynamic time
|
||||
const int nsegments = s+1;
|
||||
int nThreads = 512;
|
||||
while (nsegments*nThreads > 256) nThreads /= 2;
|
||||
//if (nThreads >= 128) nThreads += WARP_SIZE;
|
||||
for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
|
||||
if (type == P2P_SEGMENT) {
|
||||
const int nsegments = s+1;
|
||||
int nThreads = 512;
|
||||
while (nsegments*nThreads > 256) nThreads /= 2;
|
||||
//if (nThreads >= 128) nThreads += WARP_SIZE;
|
||||
for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
|
||||
}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -779,9 +877,9 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
int segment = -1;
|
||||
if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].p2p.nThreads == 0) {
|
||||
if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
|
||||
// Try to pack more segments into a single operation
|
||||
segment = getSegment(workElem->p2p.delta, w);
|
||||
segment = getSegment(P2P_SEGMENT, workElem->p2p.delta, w);
|
||||
}
|
||||
if (segment == -1) {
|
||||
NCCLCHECK(getNextOp(channel, &w, NULL));
|
||||
@@ -790,7 +888,7 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
|
||||
|
||||
// store work element into FIFO
|
||||
NCCLCHECK(ncclProxySaveP2p(comm, proxyArgs));
|
||||
NCCLCHECK(enqueueP2pOp(workElem, w, segment));
|
||||
NCCLCHECK(enqueueSegOp(P2P_SEGMENT, workElem, w, segment));
|
||||
comm->p2pOpCount++;
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -799,8 +897,7 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) {
|
||||
ncclComm* comm = info->comm;
|
||||
// Compute cuda kernel arg and proxy arg templates
|
||||
struct ncclQueueElem* eqElem;
|
||||
NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem));
|
||||
|
||||
NCCLCHECK(comm->enqueueInfo->elemList->getNewElem(&eqElem));
|
||||
// The proxy code will set and tune the send/recv chunk size, make sure to run it first.
|
||||
NCCLCHECK(ncclProxyComputeP2p(info, &eqElem->proxyArgs));
|
||||
NCCLCHECK(computeP2pWorkElem(info, &eqElem->work));
|
||||
@@ -821,11 +918,51 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) {
|
||||
// The CUDA kernel does not use the inlined first work element as fastpath argument
|
||||
if (params->func == NULL) {
|
||||
params->func = (void *)ncclKerns[0];
|
||||
memcpy(&comm->args, &eqElem->work, sizeof(struct ncclWorkElem));
|
||||
comm->args.comm = eqElem->work.comm;
|
||||
comm->args.active = 0;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem) {
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs;
|
||||
|
||||
int nChannels = work->coll.nChannels;
|
||||
size_t channelSize = work->coll.count*ncclTypeSize(proxyArgs->dtype)/work->coll.nChannels;
|
||||
for (int bid=0; bid<nChannels; bid++) {
|
||||
int channelId;
|
||||
NCCLCHECK(getNextChannel(comm, &channelId));
|
||||
struct ncclChannel* channel = comm->channels+channelId;
|
||||
|
||||
// Proxy
|
||||
proxyArgs->subs[0].channel = channel;
|
||||
proxyArgs->opCount = comm->collOpCount;
|
||||
proxyArgs->commOpCount = comm->opCount;
|
||||
if (proxyArgs->subs[0].nsteps) NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks));
|
||||
|
||||
// Try to reuse last work if not full yet
|
||||
work->coll.bid = bid % nChannels;
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
int segment = -1;
|
||||
if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
|
||||
// Try to pack more segments into a single operation
|
||||
segment = getSegment(COLL_SEGMENT, 0, w);
|
||||
}
|
||||
if (segment == -1) {
|
||||
NCCLCHECK(getNextOp(channel, &w, NULL));
|
||||
segment = 0;
|
||||
}
|
||||
|
||||
// store work element into FIFO
|
||||
NCCLCHECK(enqueueSegOp(COLL_SEGMENT, work, w, segment));
|
||||
channel->totalSize += channelSize;
|
||||
}
|
||||
comm->collOpCount++;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
template<int USING_CUDA_GRAPH>
|
||||
void HIPRT_CB ncclEnqueueHostSetup(void* arg) {
|
||||
ncclResult_t ret;
|
||||
@@ -833,14 +970,17 @@ void HIPRT_CB ncclEnqueueHostSetup(void* arg) {
|
||||
ncclComm_t comm = eqInfo->comm;
|
||||
|
||||
// Iterate through the element list
|
||||
struct ncclQueueElem* eqElem = eqInfo->elemList.head;
|
||||
while (eqElem != eqInfo->elemList.tail) { // The queue always has one extra element
|
||||
struct ncclQueueElem* eqElem = eqInfo->elemList->begin();
|
||||
while (eqElem != NULL) {
|
||||
if (eqElem->work.funcIndex == FUNC_INDEX_P2P) {
|
||||
NCCLCHECKGOTO(ncclEnqueueP2pKernel(comm, eqElem), ret, cb_end);
|
||||
} else if (eqInfo->elemList->count() > 1) {
|
||||
// We have more than one operation, hence aggregating
|
||||
NCCLCHECKGOTO(ncclEnqueueAsyncKernel(comm, eqElem), ret, cb_end);
|
||||
} else {
|
||||
NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem), ret, cb_end);
|
||||
}
|
||||
eqElem = eqElem->next;
|
||||
eqElem = eqInfo->elemList->getNext();
|
||||
}
|
||||
|
||||
NCCLCHECKGOTO(setupLaunch(eqInfo, USING_CUDA_GRAPH), ret, cb_end);
|
||||
|
||||
@@ -405,7 +405,9 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclPeer
|
||||
struct ncclPeerInfo* srcInfo = peerInfos+system->nodes[GPU].nodes[p].gpu.rank;
|
||||
int shm;
|
||||
NCCLCHECK(ncclTransports[TRANSPORT_SHM].canConnect(&shm, system, NULL, srcInfo, dstInfo));
|
||||
if (shm == 0) {
|
||||
int p2p;
|
||||
NCCLCHECK(ncclTransports[TRANSPORT_P2P].canConnect(&p2p, system, NULL, srcInfo, dstInfo));
|
||||
if (shm == 0 && p2p == 0) {
|
||||
// Mark this peer as inaccessible. We'll trim it later.
|
||||
system->nodes[GPU].nodes[p].paths[GPU][g].count = 0;
|
||||
}
|
||||
|
||||
@@ -333,7 +333,7 @@ ncclResult_t ncclTopoCompareGraphs(struct ncclTopoSystem* system, struct ncclTop
|
||||
return ncclSuccess;
|
||||
}
|
||||
// 3. Less hops (but not at the price of going cross NICs)
|
||||
if (graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1;
|
||||
if (graph->pattern == refGraph->pattern && graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1;
|
||||
|
||||
// 4. Prefer graph with more XGMI connections
|
||||
if (graph->nChannels == refGraph->nChannels
|
||||
@@ -758,11 +758,14 @@ ncclResult_t ncclTopoGetXmlFromGraphs(int ngraphs, struct ncclTopoGraph** graphs
|
||||
}
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
float speedArray[] = { 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
float speedArrayIntra[] = { 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
float speedArrayInter[] = { 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
#else
|
||||
float speedArray[] = { 42.0, 30.0, 24.0, 21.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
float speedArrayIntra[] = { 44.0, 30.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0 };
|
||||
float speedArrayInter[] = { 48.0, 30.0, 24.0, 22.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
|
||||
#endif
|
||||
#define NSPEEDS (sizeof(speedArray)/sizeof(float))
|
||||
#define NSPEEDSINTRA (sizeof(speedArrayIntra)/sizeof(float))
|
||||
#define NSPEEDSINTER (sizeof(speedArrayInter)/sizeof(float))
|
||||
|
||||
RCCL_PARAM(ModelMatchingDisable, "MODEL_MATCHING_DISABLE", 0);
|
||||
RCCL_PARAM(EnableMultipleSAT, "ENABLE_MULTIPLE_SAT", 0);
|
||||
@@ -816,23 +819,26 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph
|
||||
graph->maxChannels = 1;
|
||||
if (ngpus == 1) if (graph->pattern != NCCL_TOPO_PATTERN_RING) graph->pattern = NCCL_TOPO_PATTERN_TREE;
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
// TODO: benchmark balance tree vs split tree
|
||||
//if (graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) graph->pattern = NCCL_TOPO_PATTERN_SPLIT_TREE;
|
||||
#else
|
||||
// SPLIT_TREE works better on older archs.
|
||||
int ccMin;
|
||||
NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL));
|
||||
if (ccMin < 80 && graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) graph->pattern = NCCL_TOPO_PATTERN_SPLIT_TREE;
|
||||
#endif
|
||||
|
||||
struct ncclTopoGraph tmpGraph;
|
||||
memcpy(&tmpGraph, graph, sizeof(struct ncclTopoGraph));
|
||||
|
||||
// First try crossnic, then decrease speed and finally increase speedIntra.
|
||||
int nspeeds = 0;
|
||||
float* speedArray = NULL;
|
||||
if (system->nodes[NET].count == 0) {
|
||||
nspeeds = NSPEEDSINTRA;
|
||||
speedArray = speedArrayIntra;
|
||||
} else {
|
||||
nspeeds = NSPEEDSINTER;
|
||||
speedArray = speedArrayInter;
|
||||
}
|
||||
int pass = 1;
|
||||
int speedIndex = 0;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex];
|
||||
int64_t globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
|
||||
|
||||
@@ -899,12 +905,12 @@ search:
|
||||
tmpGraph.crossNic = 0;
|
||||
|
||||
// Decrease speed until we find a solution
|
||||
if ((speedIndex < NSPEEDS-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->speedInter > .49))) {
|
||||
if ((speedIndex < nspeeds-1) && (graph->nChannels == 0 || (speedArray[speedIndex+1]/graph->speedInter > .49))) {
|
||||
tmpGraph.speedInter = tmpGraph.speedIntra = speedArray[++speedIndex];
|
||||
goto search;
|
||||
}
|
||||
speedIndex = 0;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++;
|
||||
while (speedArray[speedIndex] > system->maxWidth && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex];
|
||||
|
||||
}
|
||||
@@ -915,7 +921,7 @@ done:
|
||||
time = -1;
|
||||
memcpy(&tmpGraph, graph, sizeof(tmpGraph));
|
||||
speedIndex = 0;
|
||||
while (speedArray[speedIndex] > graph->speedInter && speedIndex < NSPEEDS-1) speedIndex++;
|
||||
while (speedArray[speedIndex] > graph->speedInter && speedIndex < nspeeds-1) speedIndex++;
|
||||
tmpGraph.speedIntra = tmpGraph.speedInter = speedArray[speedIndex];
|
||||
tmpGraph.minChannels = graph->nChannels;
|
||||
pass = 2;
|
||||
|
||||
@@ -14,10 +14,6 @@
|
||||
#include "coll_net.h"
|
||||
#include <sys/stat.h>
|
||||
#include <fcntl.h>
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
#include <hsa/hsa.h>
|
||||
#include <hsa/hsa_ext_amd.h>
|
||||
#endif
|
||||
#include "xml.h"
|
||||
#include "cpuset.h"
|
||||
|
||||
@@ -662,7 +658,10 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
|
||||
char* xmlTopoFile = getenv("NCCL_TOPO_FILE");
|
||||
if (xmlTopoFile) {
|
||||
INFO(NCCL_ENV, "NCCL_TOPO_FILE set by environment to %s", xmlTopoFile);
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml));
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml, 1));
|
||||
} else {
|
||||
// Try default XML topology location
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile("/var/run/nvidia-topologyd/virtualTopology.xml", xml, 0));
|
||||
}
|
||||
if (xml->maxIndex == 0) {
|
||||
// Create top tag
|
||||
@@ -770,7 +769,7 @@ ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vend
|
||||
|
||||
NCCL_PARAM(IgnoreCpuAffinity, "IGNORE_CPU_AFFINITY", 0);
|
||||
|
||||
ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) {
|
||||
ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu_set_t* affinity) {
|
||||
struct ncclTopoNode* cpu = NULL, *gpu = NULL;
|
||||
for (int g=0; g<system->nodes[GPU].count; g++) {
|
||||
if (system->nodes[GPU].nodes[g].gpu.rank == rank) {
|
||||
@@ -823,12 +822,13 @@ ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) {
|
||||
// Use a subset of the GPU affinity set
|
||||
CPU_AND(&finalMask, &mask, &cpuMask);
|
||||
|
||||
memcpy(affinity, &finalMask, sizeof(cpu_set_t));
|
||||
|
||||
// If there is a non empty set, use it to set affinity
|
||||
if (CPU_COUNT(&finalMask)) {
|
||||
char affinityStr[sizeof(cpu_set_t)*2];
|
||||
NCCLCHECK(ncclCpusetToStr(&finalMask, affinityStr));
|
||||
INFO(NCCL_INIT, "Setting affinity for GPU %d to %s", gpu->gpu.dev, affinityStr);
|
||||
SYSCHECK(sched_setaffinity(0, sizeof(cpu_set_t), &finalMask), "sched_setaffinity");
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -10,12 +10,11 @@
|
||||
|
||||
#include "graph.h"
|
||||
#include "core.h"
|
||||
#include <sched.h>
|
||||
|
||||
#define LOC_WIDTH 5000.0
|
||||
#define SM60_NVLINK_WIDTH 18.0
|
||||
#define SM70_NVLINK_WIDTH 21.0
|
||||
#define SM80_NVLINK_WIDTH 21.0
|
||||
#define SM70_NVLINK_WIDTH 22.0
|
||||
#define SM80_NVLINK_WIDTH 22.0
|
||||
#define SM86_NVLINK_WIDTH 12.0
|
||||
#define PCI_WIDTH 12.0 // PCI Gen3 x16
|
||||
#define QPI_WIDTH 6.0
|
||||
|
||||
@@ -61,7 +61,7 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 39.0
|
||||
#define NCCL_HW_PCI 1
|
||||
#define NCCL_HW_NET 2
|
||||
// Tree/Simple is the latency a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network).
|
||||
static const float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
{ /* NVLINK */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 2.5, 2.5, 5.5 }, /* Ring (LL/LL128/Simple)*/ { 2.5, 2.5, 5 }, /* CollNet (LL/LL128/Simple)*/ { 1.2, 1.2, 3.8 } },
|
||||
/* PCI */
|
||||
@@ -70,11 +70,10 @@ static const float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 28.0, 28.0, 66.0 }, /* Ring (LL/LL128/Simple)*/ { 8.5, 8.5, 19.0 }, /* CollNet (LL/LL128/Simple)*/ { 9.8, 9.8, 19.5 } }
|
||||
};
|
||||
|
||||
// LL128 max BW (per channel) for the different collectives
|
||||
// ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce
|
||||
static const double ll128MaxBwPerCh[NCCL_NUM_FUNCTIONS] = { 18.8, 12.0, 18.3, 15.2, 16.9 };
|
||||
// LL128 max BW per channel
|
||||
static const double ll128MaxBwPerCh = 20.0;
|
||||
static const double llMaxBws[2][3] = { /* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4}, /* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0} };
|
||||
static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 22.5, 16.0} };
|
||||
static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8} };
|
||||
|
||||
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) {
|
||||
int simpleDefaultThreads = (ringGraph->speedIntra*ringGraph->nChannels <= PCI_WIDTH) ? 256 : NCCL_SIMPLE_MAX_NTHREADS;
|
||||
@@ -109,6 +108,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
int index1 = nNodes == 1 ? compCap80 : cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD ? 1 : 0;
|
||||
double llMaxBw = llMaxBws[index1][index2];
|
||||
double perChMaxTreeBw = perChMaxTreeBws[compCap80][index2];
|
||||
// De-penalize Tree/Simple latency on Power systems to favor Tree than Ring
|
||||
if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE];
|
||||
float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
|
||||
|
||||
struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph };
|
||||
@@ -134,22 +135,22 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
// Various model refinements
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) busBw *= 1.0/5.0;
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels);
|
||||
double maxTreeBw = comm->nNodes > 2 ?
|
||||
compCap80 && p == NCCL_PROTO_LL128 ? 105.0 : 80.0 :
|
||||
compCap80 && p == NCCL_PROTO_LL128 ? 130.0 : 110.0;
|
||||
if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.27, comm->nNodes > 1 ? 70.0 : 90.0);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw *= 1.0/2.3;
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (comm->nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels*7.0/9.0);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (comm->nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels*7.0/9.0);
|
||||
#else
|
||||
if (compCap80) busBw = std::min(busBw, 235.0f);
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); }
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels);
|
||||
if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw);
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
|
||||
#endif
|
||||
if (a == NCCL_ALGO_COLLNET) busBw *= .9;
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh*graphs[a]->nChannels);
|
||||
#endif
|
||||
if (a == NCCL_ALGO_COLLNET && p != NCCL_PROTO_SIMPLE) busBw = 0; // Oneshot CollNet only supports Simple
|
||||
|
||||
// Convert bus BW to algorithm BW
|
||||
@@ -177,7 +178,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
2 * ((nRanks/nNodes-1) * intraLat + log2i(nNodes) * interLat);
|
||||
} else {
|
||||
comm->latencies[coll][a][p] +=
|
||||
2 * (nRanks/nNodes-1) * intraLat + interLat;
|
||||
2 * (std::min(1, (nRanks/nNodes-1)) * intraLat + (nRanks/nNodes-1) * 0.5) + interLat; // Add 0.5 arity serialization latency
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -296,7 +297,7 @@ static float ringCorrectionFactor[NCCL_NUM_PROTOCOLS][22] = {
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .04, .08, .09, .09, .11, .13, .25, .40, .59, .76, .86, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 }
|
||||
};
|
||||
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time) {
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time) {
|
||||
float bw = info->comm->bandwidths[info->coll][algorithm][protocol];
|
||||
float lat = info->comm->latencies[info->coll][algorithm][protocol];
|
||||
if (bw == 0) {
|
||||
@@ -313,6 +314,8 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto
|
||||
if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1
|
||||
&& info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) lat *= 1.9; // Plateau effect of ring
|
||||
#endif
|
||||
*time = lat + (info->nBytes) / (1000 * bw);
|
||||
// Tree pipelining saves latency in aggregation cases
|
||||
int latCount = algorithm == NCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, NCCL_MAX_WORK_ELEMENTS);
|
||||
*time = lat * latCount + (info->nBytes) / (1000 * bw);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -10,10 +10,6 @@
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <ctype.h>
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
#include <hsa/hsa.h>
|
||||
#include <hsa/hsa_ext_amd.h>
|
||||
#endif
|
||||
#include "core.h"
|
||||
#include "nvmlwrap.h"
|
||||
#include "xml.h"
|
||||
@@ -310,12 +306,15 @@ ncclResult_t ncclTopoXmlLoadSystem(FILE* file, struct ncclXml* xml, struct ncclX
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml) {
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml, int warn) {
|
||||
FILE* file = fopen(xmlTopoFile, "r");
|
||||
if (file == NULL) {
|
||||
WARN("Could not open XML topology file %s : %s", xmlTopoFile, strerror(errno));
|
||||
if (warn) {
|
||||
WARN("Could not open XML topology file %s : %s", xmlTopoFile, strerror(errno));
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
INFO(NCCL_GRAPH, "Loading topology file %s", xmlTopoFile);
|
||||
struct xmlHandler handlers[] = { { "system", ncclTopoXmlLoadSystem } };
|
||||
xml->maxIndex = 0;
|
||||
NCCLCHECK(xmlLoadSub(file, xml, NULL, handlers, 1));
|
||||
@@ -451,8 +450,8 @@ ncclResult_t ncclTopoGetPciNode(struct ncclXml* xml, const char* busId, struct n
|
||||
NCCLCHECK(xmlFindTagKv(xml, "pci", pciNode, "busid", busId));
|
||||
if (*pciNode == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, NULL, "pci", pciNode));
|
||||
NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId));
|
||||
}
|
||||
NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -473,109 +472,123 @@ ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml*
|
||||
const char* busId;
|
||||
NCCLCHECK(xmlGetAttr(pciNode, "busid", &busId));
|
||||
char* path = NULL;
|
||||
int index;
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "class", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
ncclDebugNoWarn = NCCL_GRAPH;
|
||||
getPciPath(busId, &path);
|
||||
ncclDebugNoWarn = 0;
|
||||
|
||||
if (path) {
|
||||
NCCLCHECK(ncclTopoSetAttrFromSys(pciNode, path, "class", "class"));
|
||||
}
|
||||
int index;
|
||||
ncclDebugNoWarn = NCCL_GRAPH;
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "vendor", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "vendor", "vendor");
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "device", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "device", "device");
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "subsystem_vendor", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_vendor", "subsystem_vendor");
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "subsystem_device", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) getPciPath(busId, &path);
|
||||
if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_device", "subsystem_device");
|
||||
}
|
||||
ncclDebugNoWarn = 0;
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "link_speed", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
char deviceSpeedStr[MAX_STR_LEN];
|
||||
float deviceSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr));
|
||||
sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed);
|
||||
char portSpeedStr[MAX_STR_LEN];
|
||||
float portSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr));
|
||||
if (portSpeedStr[0])
|
||||
sscanf(portSpeedStr, "%f GT/s", &portSpeed);
|
||||
else
|
||||
portSpeed = deviceSpeed;
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr));
|
||||
if (path) {
|
||||
char deviceSpeedStr[MAX_STR_LEN];
|
||||
float deviceSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr));
|
||||
sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed);
|
||||
char portSpeedStr[MAX_STR_LEN];
|
||||
float portSpeed;
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr));
|
||||
if (portSpeedStr[0])
|
||||
sscanf(portSpeedStr, "%f GT/s", &portSpeed);
|
||||
else
|
||||
portSpeed = deviceSpeed;
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr));
|
||||
} else {
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_speed", ""));
|
||||
}
|
||||
}
|
||||
NCCLCHECK(xmlGetAttrIndex(pciNode, "link_width", &index));
|
||||
if (index == -1) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
char strValue[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue));
|
||||
int deviceWidth = strtol(strValue, NULL, 0);
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue));
|
||||
int portWidth;
|
||||
if (strValue[0])
|
||||
portWidth = strtol(strValue, NULL, 0);
|
||||
else
|
||||
portWidth = deviceWidth;
|
||||
NCCLCHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth)));
|
||||
if (path) {
|
||||
char strValue[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue));
|
||||
int deviceWidth = strtol(strValue, NULL, 0);
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue));
|
||||
int portWidth;
|
||||
if (strValue[0])
|
||||
portWidth = strtol(strValue, NULL, 0);
|
||||
else
|
||||
portWidth = deviceWidth;
|
||||
NCCLCHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth)));
|
||||
} else {
|
||||
NCCLCHECK(xmlSetAttr(pciNode, "link_width", ""));
|
||||
}
|
||||
}
|
||||
struct ncclXmlNode* parent = pciNode->parent;
|
||||
if (parent == NULL) {
|
||||
if (path == NULL) NCCLCHECK(getPciPath(busId, &path));
|
||||
if (path) {
|
||||
// Save that for later in case next step is a CPU
|
||||
char numaIdStr[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr));
|
||||
// Workaround kernel bug for now
|
||||
if (strcmp(numaIdStr, "-1") == 0) strcpy(numaIdStr, "0");
|
||||
|
||||
// Save that for later in case next step is a CPU
|
||||
char numaIdStr[MAX_STR_LEN];
|
||||
NCCLCHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr));
|
||||
// Workaround kernel bug for now
|
||||
if (strcmp(numaIdStr, "-1") == 0) strcpy(numaIdStr, "0");
|
||||
|
||||
// Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI
|
||||
// switch, or stop if we reach a CPU root complex.
|
||||
int slashCount = 0;
|
||||
int parentOffset;
|
||||
for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) {
|
||||
if (path[parentOffset] == '/') {
|
||||
slashCount++;
|
||||
path[parentOffset] = '\0';
|
||||
int start = parentOffset - 1;
|
||||
while (start>0 && path[start] != '/') start--;
|
||||
// Check whether the parent path looks like "BBBB:BB:DD.F" or not.
|
||||
if (checkBDFFormat(path+start+1) == 0) {
|
||||
// This a CPU root complex. Create a CPU tag and stop there.
|
||||
struct ncclXmlNode* topNode;
|
||||
NCCLCHECK(xmlFindTag(xml, "system", &topNode));
|
||||
NCCLCHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "numaid", numaIdStr));
|
||||
}
|
||||
} else if (slashCount == 2) {
|
||||
// Continue on the upper PCI switch
|
||||
for (int i = strlen(path)-1; i>0; i--) {
|
||||
if (path[i] == '/') {
|
||||
NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "busid", path+i+1));
|
||||
// Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI
|
||||
// switch, or stop if we reach a CPU root complex.
|
||||
int slashCount = 0;
|
||||
int parentOffset;
|
||||
for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) {
|
||||
if (path[parentOffset] == '/') {
|
||||
slashCount++;
|
||||
path[parentOffset] = '\0';
|
||||
int start = parentOffset - 1;
|
||||
while (start>0 && path[start] != '/') start--;
|
||||
// Check whether the parent path looks like "BBBB:BB:DD.F" or not.
|
||||
if (checkBDFFormat(path+start+1) == 0) {
|
||||
// This a CPU root complex. Create a CPU tag and stop there.
|
||||
struct ncclXmlNode* topNode;
|
||||
NCCLCHECK(xmlFindTag(xml, "system", &topNode));
|
||||
NCCLCHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "numaid", numaIdStr));
|
||||
}
|
||||
} else if (slashCount == 2) {
|
||||
// Continue on the upper PCI switch
|
||||
for (int i = strlen(path)-1; i>0; i--) {
|
||||
if (path[i] == '/') {
|
||||
NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1));
|
||||
if (parent == NULL) {
|
||||
NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "busid", path+i+1));
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent) break;
|
||||
}
|
||||
} else {
|
||||
// No information on /sys, attach GPU to unknown CPU
|
||||
NCCLCHECK(xmlFindTagKv(xml, "cpu", &parent, "numaid", "-1"));
|
||||
if (parent == NULL) {
|
||||
struct ncclXmlNode* topNode;
|
||||
NCCLCHECK(xmlFindTag(xml, "system", &topNode));
|
||||
NCCLCHECK(xmlAddNode(xml, topNode, "cpu", &parent));
|
||||
NCCLCHECK(xmlSetAttr(parent, "numaid", "-1"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromCpu(parent, xml));
|
||||
}
|
||||
if (parent) break;
|
||||
}
|
||||
pciNode->parent = parent;
|
||||
parent->subs[parent->nSubs++] = pciNode;
|
||||
@@ -735,12 +748,14 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
|
||||
if (index == -1) {
|
||||
const char* busId;
|
||||
NCCLCHECK(xmlGetAttr(sub, "target", &busId));
|
||||
if (strcmp(busId, "fffffff:ffff:ff") == 0) {
|
||||
char* path;
|
||||
ncclDebugNoWarn = NCCL_GRAPH;
|
||||
getPciPath(busId, &path);
|
||||
ncclDebugNoWarn = 0;
|
||||
if (path == NULL || strcmp(busId, "fffffff:ffff:ff") == 0) {
|
||||
// Remote NVLink device is not visible inside this VM. Assume NVSwitch.
|
||||
NCCLCHECK(xmlSetAttr(sub, "tclass", "0x068000"));
|
||||
} else {
|
||||
char* path;
|
||||
NCCLCHECK(getPciPath(busId, &path));
|
||||
NCCLCHECK(ncclTopoSetAttrFromSys(sub, path, "class", "tclass"));
|
||||
free(path);
|
||||
}
|
||||
@@ -753,6 +768,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
|
||||
ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode) {
|
||||
struct ncclXmlNode* node;
|
||||
NCCLCHECK(ncclTopoGetPciNode(xml, busId, &node));
|
||||
NCCLCHECK(xmlSetAttrIfUnset(node, "class", "0x03"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromSys(node, xml));
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
uint32_t devIndex;
|
||||
@@ -817,6 +833,7 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha
|
||||
char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
|
||||
strcpy(busId, pciSysPath+offset+1);
|
||||
NCCLCHECK(ncclTopoGetPciNode(xml, busId, &parent));
|
||||
NCCLCHECK(xmlSetAttrIfUnset(parent, "class", "0x02"));
|
||||
NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml));
|
||||
} else {
|
||||
// Virtual NIC, no PCI device, attach to first CPU
|
||||
|
||||
@@ -38,7 +38,7 @@ struct ncclXml {
|
||||
|
||||
/* File functions */
|
||||
#define NCCL_TOPO_XML_VERSION 2
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml);
|
||||
ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml, int warn);
|
||||
ncclResult_t ncclTopoDumpXmlToFile(const char* xmlTopoFile, struct ncclXml* xml);
|
||||
#define NCCL_GRAPH_XML_VERSION 1
|
||||
ncclResult_t ncclTopoGetXmlGraphFromFile(const char* xmlGraphFile, struct ncclXml* xml);
|
||||
@@ -137,6 +137,18 @@ static ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, c
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t xmlSetAttrIfUnset(struct ncclXmlNode* node, const char* attrName, const char* value) {
|
||||
int index;
|
||||
NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
|
||||
if (index != -1) return ncclSuccess;
|
||||
index = node->nAttrs++;
|
||||
strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
|
||||
node->attrs[index].key[MAX_STR_LEN] = '\0';
|
||||
strncpy(node->attrs[index].value, value, MAX_STR_LEN);
|
||||
node->attrs[index].value[MAX_STR_LEN] = '\0';
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t xmlSetAttrInt(struct ncclXmlNode* node, const char* attrName, const int value) {
|
||||
int index;
|
||||
NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
|
||||
|
||||
@@ -139,6 +139,7 @@ void* ncclAsyncThreadPreconnect(void* args_) {
|
||||
struct ncclAsyncArgs* args = (struct ncclAsyncArgs*)args_;
|
||||
struct ncclComm* comm = args->coll.comm;
|
||||
CUDACHECKTHREAD(hipSetDevice(comm->cudaDev));
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, args->coll.connIndex));
|
||||
return args;
|
||||
}
|
||||
@@ -249,8 +250,6 @@ ncclResult_t ncclGroupEnd() {
|
||||
struct ncclComm* comm = args->coll.comm;
|
||||
int rank = comm->rank;
|
||||
int nRanks = comm->nRanks;
|
||||
struct ncclP2Plist* p2pSends = comm->p2pSends;
|
||||
struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
|
||||
|
||||
// Compute how much to split operations
|
||||
// Natural step size matching buffer steps.
|
||||
@@ -273,8 +272,8 @@ ncclResult_t ncclGroupEnd() {
|
||||
sched_delta:
|
||||
uint32_t from = (rank+nRanks-delta)%nRanks;
|
||||
uint32_t to = (rank+delta)%nRanks;
|
||||
struct ncclP2Pinfo* recv = p2pRecvs[from].head;
|
||||
struct ncclP2Pinfo* send = p2pSends[to].head;
|
||||
struct ncclP2Pinfo* recv = comm->p2pRecvs[from] ? comm->p2pRecvs[from]->getNext() : NULL;
|
||||
struct ncclP2Pinfo* send = comm->p2pSends[to] ? comm->p2pSends[to]->getNext() : NULL;
|
||||
if (recv != NULL || send != NULL) {
|
||||
ssize_t totRecvBytes = -1, totSendBytes = -1;
|
||||
if (recv != NULL) totRecvBytes = recv->nbytes;
|
||||
@@ -311,15 +310,11 @@ sched_delta:
|
||||
sendOffset += sendChunkSize;
|
||||
chunk++;
|
||||
} while (sendRemaining || recvRemaining);
|
||||
if (recv) {
|
||||
NCCLCHECKGOTO(dequeueP2pInfo(p2pRecvs+from), ret, group_cleanup);
|
||||
comm->p2pRecvCount--;
|
||||
}
|
||||
if (send) {
|
||||
NCCLCHECKGOTO(dequeueP2pInfo(p2pSends+to), ret, group_cleanup);
|
||||
comm->p2pSendCount--;
|
||||
}
|
||||
if (recv) comm->p2pRecvCount--;
|
||||
if (send) comm->p2pSendCount--;
|
||||
}
|
||||
if (recv == NULL && comm->p2pRecvs[from]) comm->p2pRecvs[from]->recycle();
|
||||
if (send == NULL && comm->p2pSends[to]) comm->p2pSends[to]->recycle();
|
||||
index++;
|
||||
if (index == 1 && deltas[1] == deltas[0]) index++;
|
||||
if (index == 2 && deltas[2] == deltas[0]) index++;
|
||||
@@ -419,11 +414,9 @@ group_cleanup:
|
||||
comm->asyncTotalSize = 0;
|
||||
// Dequeue p2p lists
|
||||
if (comm->p2pSendCount > 0 || comm->p2pRecvCount > 0) {
|
||||
struct ncclP2Plist* p2pSends = comm->p2pSends;
|
||||
struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
|
||||
for (int peer=0; peer<comm->nRanks; peer++) {
|
||||
while (p2pSends[peer].head != NULL) dequeueP2pInfo(p2pSends+peer);
|
||||
while (p2pRecvs[peer].head != NULL) dequeueP2pInfo(p2pRecvs+peer);
|
||||
if (comm->p2pSends[peer]) comm->p2pSends[peer]->recycle();
|
||||
if (comm->p2pRecvs[peer]) comm->p2pRecvs[peer]->recycle();
|
||||
}
|
||||
comm->p2pSendCount = comm->p2pRecvCount = 0;
|
||||
}
|
||||
|
||||
@@ -16,4 +16,29 @@
|
||||
#define ALIGN_SIZE(size, align) \
|
||||
size = ((size + (align) - 1) / (align)) * (align);
|
||||
|
||||
#if !__CUDA_ARCH__
|
||||
#ifndef __host__
|
||||
#define __host__
|
||||
#endif
|
||||
#ifndef __device__
|
||||
#define __device__
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template<typename X, typename Y, typename Z = decltype(X()+Y())>
|
||||
__host__ __device__ constexpr Z divUp(X x, Y y) {
|
||||
return (x+y-1)/y;
|
||||
}
|
||||
|
||||
template<typename X, typename Y, typename Z = decltype(X()+Y())>
|
||||
__host__ __device__ constexpr Z roundUp(X x, Y y) {
|
||||
return (x+y-1) - (x+y-1)%y;
|
||||
}
|
||||
|
||||
// assumes second argument is a power of 2
|
||||
template<typename X, typename Z = decltype(X()+int())>
|
||||
__host__ __device__ constexpr Z alignUp(X x, int a) {
|
||||
return (x+a-1) & Z(-a);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
#include <sys/mman.h>
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaHostCalloc(T** ptr, size_t nelem) {
|
||||
static ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
CUDACHECK(hipHostMalloc(ptr, nelem*sizeof(T), hipHostMallocMapped));
|
||||
memset(*ptr, 0, nelem*sizeof(T));
|
||||
INFO(NCCL_ALLOC, "Cuda Host Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Cuda Host Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCudaHostCalloc(...) ncclCudaHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
static inline ncclResult_t ncclCudaHostFree(void* ptr) {
|
||||
CUDACHECK(hipHostFree(ptr));
|
||||
@@ -27,7 +28,7 @@ static inline ncclResult_t ncclCudaHostFree(void* ptr) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
|
||||
static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
void* p = malloc(nelem*sizeof(T));
|
||||
if (p == NULL) {
|
||||
WARN("Failed to malloc %ld bytes", nelem*sizeof(T));
|
||||
@@ -35,9 +36,10 @@ static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
|
||||
}
|
||||
memset(p, 0, nelem*sizeof(T));
|
||||
*ptr = (T*)p;
|
||||
INFO(NCCL_ALLOC, "Mem Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Mem Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCalloc(...) ncclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
struct __attribute__ ((aligned(64))) allocationTracker {
|
||||
union {
|
||||
@@ -53,7 +55,7 @@ static_assert(sizeof(struct allocationTracker) == 64, "allocationTracker must be
|
||||
extern struct allocationTracker allocTracker[];
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem, bool isFineGrain = false) {
|
||||
static ncclResult_t ncclCudaCallocDebug(const char *filefunc, int line, T** ptr, size_t nelem, bool isFineGrain = false) {
|
||||
// Need async stream for P2P pre-connect + CUDA Graph
|
||||
hipStream_t stream;
|
||||
CUDACHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
|
||||
@@ -64,7 +66,7 @@ static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem, bool isFineGrain = fal
|
||||
CUDACHECK(hipMemsetAsync(*ptr, 0, nelem*sizeof(T), stream));
|
||||
CUDACHECK(hipStreamSynchronize(stream));
|
||||
CUDACHECK(hipStreamDestroy(stream));
|
||||
INFO(NCCL_ALLOC, "Cuda Alloc Size %ld pointer %p", nelem*sizeof(T), *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Cuda Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
int dev;
|
||||
CUDACHECK(hipGetDevice(&dev));
|
||||
if (dev < MAX_ALLOC_TRACK_NGPU) {
|
||||
@@ -73,6 +75,7 @@ static ncclResult_t ncclCudaCalloc(T** ptr, size_t nelem, bool isFineGrain = fal
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCudaCalloc(...) ncclCudaCallocDebug(__FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaMemcpy(T* dst, T* src, size_t nelem) {
|
||||
@@ -93,7 +96,7 @@ static bool hasFineGrainVramPcie() {
|
||||
// Allocate memory to be potentially ibv_reg_mr'd. This needs to be
|
||||
// allocated on separate pages as those pages will be marked DONTFORK
|
||||
// and if they are shared, that could cause a crash in a child process
|
||||
static ncclResult_t ncclIbMalloc(void** ptr, size_t size) {
|
||||
static ncclResult_t ncclIbMallocDebug(void** ptr, size_t size, const char *filefunc, int line) {
|
||||
size_t page_size = sysconf(_SC_PAGESIZE);
|
||||
void* p;
|
||||
int size_aligned = ROUNDUP(size, page_size);
|
||||
@@ -101,8 +104,9 @@ static ncclResult_t ncclIbMalloc(void** ptr, size_t size) {
|
||||
if (ret != 0) return ncclSystemError;
|
||||
memset(p, 0, size);
|
||||
*ptr = p;
|
||||
INFO(NCCL_ALLOC, "Ib Alloc Size %ld pointer %p", size, *ptr);
|
||||
INFO(NCCL_ALLOC, "%s:%d Ib Alloc Size %ld pointer %p", filefunc, line, size, *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclIbMalloc(...) ncclIbMallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -16,6 +17,7 @@ ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commSt
|
||||
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
||||
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
|
||||
ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size);
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks);
|
||||
ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, hipIpcMemHandle_t* ipc, void** ptr);
|
||||
ncclResult_t bootstrapRemFree(int id, int rank, void* commState);
|
||||
ncclResult_t bootstrapClose(void* commState);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
@@ -25,6 +25,9 @@
|
||||
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first); \
|
||||
|
||||
//extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(); \
|
||||
//extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem c); \
|
||||
|
||||
#define DECL4(func, algo, redop, type) \
|
||||
DECL5(func, algo, SIMPLE, redop, type) \
|
||||
DECL5(func, algo, LL, redop, type) \
|
||||
@@ -35,6 +38,19 @@
|
||||
DECL4(func, TREE, redop, type) \
|
||||
DECL4(func, COLLNET, redop, type)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#define DECL2(func, redop) \
|
||||
DECL3(func, redop, int8_t) \
|
||||
DECL3(func, redop, uint8_t) \
|
||||
DECL3(func, redop, int32_t) \
|
||||
DECL3(func, redop, uint32_t) \
|
||||
DECL3(func, redop, int64_t) \
|
||||
DECL3(func, redop, uint64_t) \
|
||||
DECL3(func, redop, half) \
|
||||
DECL3(func, redop, float) \
|
||||
DECL3(func, redop, double) \
|
||||
DECL3(func, redop, __nv_bfloat16)
|
||||
#else
|
||||
#define DECL2(func, redop) \
|
||||
DECL3(func, redop, int8_t) \
|
||||
DECL3(func, redop, uint8_t) \
|
||||
@@ -46,12 +62,14 @@
|
||||
DECL3(func, redop, float) \
|
||||
DECL3(func, redop, double) \
|
||||
DECL3(func, redop, rccl_bfloat16)
|
||||
#endif
|
||||
|
||||
#define DECL(func) \
|
||||
DECL2(func, Sum) \
|
||||
DECL2(func, Prod) \
|
||||
DECL2(func, Min) \
|
||||
DECL2(func, Max)
|
||||
DECL2(func, Max) \
|
||||
DECL2(func, Avg) \
|
||||
|
||||
#define DECL_ALL \
|
||||
DECL2(Broadcast, Sum) \
|
||||
|
||||
@@ -82,6 +82,7 @@ struct ncclComm {
|
||||
int nRanks; // number of GPUs in communicator
|
||||
int cudaDev; // my cuda device index
|
||||
int64_t busId; // my PCI bus ID in int format
|
||||
cpu_set_t cpuAffinity; // CPU affinity of the GPU
|
||||
|
||||
int node;
|
||||
int nNodes;
|
||||
@@ -162,11 +163,13 @@ struct ncclComm {
|
||||
struct ncclInfo* asyncOps;
|
||||
int asyncOpCount;
|
||||
size_t asyncTotalSize;
|
||||
ssize_t channelSize;
|
||||
int lastChannel;
|
||||
enum { ROUND_ROBIN, SHORTEST_QUEUE } asyncAllocMode;
|
||||
|
||||
//list of async p2p operation queued in a group semantics
|
||||
struct ncclP2Plist* p2pSends;
|
||||
struct ncclP2Plist* p2pRecvs;
|
||||
ncclP2Plist** p2pSends;
|
||||
ncclP2Plist** p2pRecvs;
|
||||
int p2pSendCount;
|
||||
int p2pRecvCount;
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
/*************************************************************************
|
||||
<<<<<<< HEAD
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
=======
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
>>>>>>> nccl/master
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
|
||||
|
||||
#define NCCL_NUM_FUNCTIONS 5 // SendRecv not included for now
|
||||
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv} ncclFunc_t;
|
||||
typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclNumFuncs} ncclFunc_t;
|
||||
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1];
|
||||
|
||||
#define NCCL_NUM_ALGORITHMS 3 // Tree/Ring/CollNet
|
||||
@@ -142,6 +142,8 @@ struct ncclRing {
|
||||
// devices. Ordered from current device.
|
||||
int* userRanks;
|
||||
int* devUserRanks;
|
||||
|
||||
int index; // This rank's index in the ring
|
||||
};
|
||||
|
||||
|
||||
@@ -173,7 +175,7 @@ struct ncclPeer {
|
||||
struct ncclDevComm;
|
||||
|
||||
#pragma pack(push) /* push current alignment to stack */
|
||||
#pragma pack(4) /* set alignment to 4 bytes boundary */
|
||||
#pragma pack(8) /* set alignment to 4 bytes boundary */
|
||||
#define NCCL_MAX_WORK_ELEMENTS 1
|
||||
#define NCCL_MAX_GROUPS (NCCL_MAX_NTHREADS/WARP_SIZE)
|
||||
|
||||
@@ -255,6 +257,7 @@ struct ncclChannel {
|
||||
// Operation list for aggregation
|
||||
struct ncclWork* workFifo;
|
||||
int workCount;
|
||||
size_t totalSize;
|
||||
uint64_t workFifoTail; // Only used by CPU
|
||||
|
||||
#ifdef ENABLE_PROFILING
|
||||
@@ -285,7 +288,6 @@ struct ncclProf {
|
||||
struct {
|
||||
uint64_t total_cycle;
|
||||
uint64_t wait_cycle[MAXCHANNELS]; // total wait cycle
|
||||
uint64_t wait_recv_cycle[MAXCHANNELS]; // recv wait cycle
|
||||
// primtive cycles
|
||||
uint64_t send_cycle;
|
||||
uint64_t directSend_cycle;
|
||||
@@ -376,4 +378,9 @@ struct ncclDevComm {
|
||||
#endif
|
||||
};
|
||||
|
||||
struct ncclDevCommAndChannels {
|
||||
ncclDevComm comm;
|
||||
ncclChannel channels[MAXCHANNELS];
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
#include "group.h"
|
||||
#include "collectives.h"
|
||||
|
||||
#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
|
||||
#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
|
||||
|
||||
size_t ncclKernMaxLocalSize();
|
||||
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info);
|
||||
ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast);
|
||||
@@ -31,39 +34,22 @@ ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph);
|
||||
struct ncclQueueElem {
|
||||
struct ncclWorkElem work;
|
||||
struct ncclProxyArgs proxyArgs;
|
||||
struct ncclQueueElem* next;
|
||||
};
|
||||
|
||||
// Store enqueue elements in a list
|
||||
struct ncclQueueElemList {
|
||||
struct ncclQueueElem* head;
|
||||
struct ncclQueueElem* tail;
|
||||
};
|
||||
typedef ncclRecyclableList<struct ncclQueueElem> ncclQueueElemList;
|
||||
|
||||
// Structure passed to CUDA graph
|
||||
struct ncclQueueInfo {
|
||||
ncclComm_t comm;
|
||||
int maxChannels; // Dynamic version of gridDim
|
||||
ncclResult_t ret; // Return value of host setup call
|
||||
struct ncclQueueElemList elemList;
|
||||
ncclQueueElemList* elemList;
|
||||
};
|
||||
|
||||
// Get next element from enqueue list
|
||||
static ncclResult_t ncclAddQueueElem(struct ncclQueueInfo* eqInfo, struct ncclQueueElem** elemOut) {
|
||||
if (eqInfo == NULL) return ncclInternalError;
|
||||
struct ncclQueueElemList* list = &eqInfo->elemList;
|
||||
if (list->tail != NULL) {
|
||||
*elemOut = list->tail;
|
||||
memset(*elemOut, 0, sizeof(struct ncclWorkElem) + sizeof(struct ncclProxyArgs));
|
||||
} else {
|
||||
NCCLCHECK(ncclCalloc(&list->tail, 1));
|
||||
*elemOut = list->tail;
|
||||
list->head = list->tail;
|
||||
}
|
||||
if (list->tail->next == NULL) {
|
||||
NCCLCHECK(ncclCalloc(&list->tail->next, 1));
|
||||
}
|
||||
list->tail = list->tail->next;
|
||||
static ncclResult_t ncclCreateQueueInfo(struct ncclQueueInfo** eqInfo, ncclComm_t comm) {
|
||||
NCCLCHECK(ncclCalloc(eqInfo, 1));
|
||||
(*eqInfo)->comm = comm;
|
||||
(*eqInfo)->elemList = new ncclQueueElemList();
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -72,7 +58,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
|
||||
if (eqInfo == NULL) return ncclInternalError;
|
||||
eqInfo->maxChannels = 0;
|
||||
eqInfo->ret = ncclSuccess;
|
||||
eqInfo->elemList.tail = eqInfo->elemList.head;
|
||||
eqInfo->elemList->recycle();
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -81,12 +67,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
|
||||
static void ncclDestroyQueueInfo(void* ptr) {
|
||||
if (ptr == NULL) return;
|
||||
struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)ptr;
|
||||
struct ncclQueueElem* head = eqInfo->elemList.head;
|
||||
while (head != NULL) {
|
||||
struct ncclQueueElem* temp = head;
|
||||
head = head->next;
|
||||
free(temp);
|
||||
}
|
||||
delete eqInfo->elemList;
|
||||
free(eqInfo);
|
||||
}
|
||||
#endif // End include guard
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <ctype.h>
|
||||
#include <stdio.h>
|
||||
#include <sched.h>
|
||||
|
||||
ncclResult_t ncclTopoCudaPath(int cudaDev, char** path);
|
||||
|
||||
@@ -36,8 +37,8 @@ ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int64_t busId, int ne
|
||||
ncclResult_t ncclTopoGetIntraNetDev(struct ncclTopoSystem* system, int rank, struct ncclTopoGraph* graph, int channelId, int type, int* dev);
|
||||
ncclResult_t ncclTopoGetLinkType(struct ncclTopoSystem* system, int cudaDev1, int cudaDev2, bool* isXGMI, bool direct_only=false);
|
||||
|
||||
// Set CPU affinity
|
||||
ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank);
|
||||
// Find CPU affinity
|
||||
ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu_set_t* affinity);
|
||||
|
||||
#define NCCL_TOPO_CPU_ARCH_X86 1
|
||||
#define NCCL_TOPO_CPU_ARCH_POWER 2
|
||||
@@ -107,6 +108,6 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
|
||||
|
||||
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph);
|
||||
#include "info.h"
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time);
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -12,32 +12,16 @@
|
||||
struct ncclP2Pinfo {
|
||||
void* buff;
|
||||
ssize_t nbytes;
|
||||
struct ncclP2Pinfo* next;
|
||||
};
|
||||
|
||||
struct ncclP2Plist {
|
||||
struct ncclP2Pinfo *head;
|
||||
struct ncclP2Pinfo *tail;
|
||||
};
|
||||
typedef ncclRecyclableList<struct ncclP2Pinfo> ncclP2Plist;
|
||||
|
||||
static ncclResult_t enqueueP2pInfo(ncclP2Plist* p2p, void* buff, ssize_t nBytes) {
|
||||
if (p2p == NULL) return ncclInternalError;
|
||||
static ncclResult_t ncclSaveP2pInfo(ncclP2Plist* &p2p, void* buff, ssize_t nBytes) {
|
||||
if (p2p == NULL) p2p = new ncclP2Plist();
|
||||
struct ncclP2Pinfo* next;
|
||||
NCCLCHECK(ncclCalloc(&next, 1));
|
||||
NCCLCHECK(p2p->getNewElem(&next));
|
||||
next->buff = buff;
|
||||
next->nbytes = nBytes;
|
||||
if (p2p->tail != NULL) p2p->tail->next = next;
|
||||
p2p->tail = next;
|
||||
if (p2p->head == NULL) p2p->head = next;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t dequeueP2pInfo(ncclP2Plist* p2p) {
|
||||
if (p2p == NULL) return ncclInternalError;
|
||||
struct ncclP2Pinfo* temp = p2p->head;
|
||||
p2p->head = p2p->head->next;
|
||||
if (p2p->tail == temp) p2p->tail = NULL;
|
||||
free(temp);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -31,12 +31,13 @@ union socketAddress {
|
||||
struct sockaddr_in6 sin6;
|
||||
};
|
||||
|
||||
/* Format a string representation of a (struct sockaddr *) socket address using getnameinfo()
|
||||
/* Format a string representation of a (union socketAddress *) socket address using getnameinfo()
|
||||
*
|
||||
* Output: "IPv4/IPv6 address<port>"
|
||||
*/
|
||||
static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
|
||||
if (buf == NULL || saddr == NULL) return NULL;
|
||||
static inline const char *socketToString(union socketAddress *addr, char *buf) {
|
||||
if (buf == NULL || addr == NULL) return NULL;
|
||||
struct sockaddr *saddr = &addr->sa;
|
||||
if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; }
|
||||
char host[NI_MAXHOST], service[NI_MAXSERV];
|
||||
(void) getnameinfo(saddr, sizeof(union socketAddress), host, NI_MAXHOST, service, NI_MAXSERV, NI_NUMERICHOST|NI_NUMERICSERV);
|
||||
@@ -44,8 +45,9 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
|
||||
return buf;
|
||||
}
|
||||
|
||||
static inline uint16_t socketToPort(struct sockaddr *saddr) {
|
||||
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
|
||||
static inline uint16_t socketToPort(union socketAddress *addr) {
|
||||
struct sockaddr *saddr = &addr->sa;
|
||||
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
|
||||
}
|
||||
|
||||
/* Allow the user to force the IPv4/IPv6 interface selection */
|
||||
@@ -86,7 +88,7 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre
|
||||
if (family != AF_INET && family != AF_INET6)
|
||||
continue;
|
||||
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString((union socketAddress *)interface->ifa_addr, line));
|
||||
|
||||
/* Allow the caller to force the socket family type */
|
||||
if (sock_family != -1 && family != sock_family)
|
||||
@@ -195,13 +197,13 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
|
||||
// Store the interface name
|
||||
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
|
||||
|
||||
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(localAddrs+found, line), socketToString(remoteAddr, line_a));
|
||||
found++;
|
||||
if (found == maxIfs) break;
|
||||
}
|
||||
|
||||
if (found == 0) {
|
||||
WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a));
|
||||
WARN("Net : No interface found in the same subnet as remote address %s", socketToString(remoteAddr, line_a));
|
||||
}
|
||||
freeifaddrs(interfaces);
|
||||
return found;
|
||||
@@ -339,7 +341,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
||||
localAddr->sin.sin_port = htons(port++);
|
||||
#endif
|
||||
|
||||
if (socketToPort(&localAddr->sa)) {
|
||||
if (socketToPort(localAddr)) {
|
||||
// Port is forced by env. Make sure we get the port.
|
||||
int opt = 1;
|
||||
#if defined(SO_REUSEPORT)
|
||||
@@ -358,7 +360,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
||||
|
||||
#ifdef ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(localAddr, line));
|
||||
#endif
|
||||
|
||||
/* Put the socket in listen mode
|
||||
@@ -370,10 +372,12 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
|
||||
}
|
||||
|
||||
static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
/* IPv4/IPv6 support */
|
||||
int family = remoteAddr->sa.sa_family;
|
||||
if (family != AF_INET && family != AF_INET6) {
|
||||
WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)", family, AF_INET, AF_INET6);
|
||||
WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
|
||||
socketToString(remoteAddr, line), family, AF_INET, AF_INET6);
|
||||
return ncclInternalError;
|
||||
}
|
||||
int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
|
||||
@@ -392,8 +396,7 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
|
||||
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
|
||||
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/
|
||||
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line));
|
||||
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(remoteAddr, line));
|
||||
|
||||
int ret;
|
||||
int timedout_retries = 0;
|
||||
@@ -409,25 +412,26 @@ retry:
|
||||
goto retry;
|
||||
}
|
||||
}
|
||||
WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(errno));
|
||||
WARN("Net : Connect to %s failed : %s", socketToString(remoteAddr, line), strerror(errno));
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
#define NCCL_SOCKET_SEND 0
|
||||
#define NCCL_SOCKET_RECV 1
|
||||
static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
|
||||
static ncclResult_t socketProgressOpt(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset, int block) {
|
||||
int bytes = 0;
|
||||
char* data = (char*)ptr;
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
do {
|
||||
if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == NCCL_SOCKET_RECV && bytes == 0) {
|
||||
WARN("Net : Connection closed by remote peer");
|
||||
WARN("Net : Connection closed by remote peer %s", socketToString(addr, line));
|
||||
return ncclSystemError;
|
||||
}
|
||||
if (bytes == -1) {
|
||||
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
|
||||
WARN("Call to recv failed : %s", strerror(errno));
|
||||
WARN("Net : Call to recv from %s failed : %s", socketToString(addr, line), strerror(errno));
|
||||
return ncclSystemError;
|
||||
} else {
|
||||
bytes = 0;
|
||||
@@ -438,25 +442,25 @@ static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int*
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
|
||||
return socketProgressOpt(op, fd, ptr, size, offset, 0);
|
||||
static ncclResult_t socketProgress(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset) {
|
||||
return socketProgressOpt(op, fd, addr, ptr, size, offset, 0);
|
||||
}
|
||||
|
||||
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
|
||||
static ncclResult_t socketWait(int op, int fd, union socketAddress *addr, void* ptr, int size, int* offset) {
|
||||
while (*offset < size)
|
||||
NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
|
||||
NCCLCHECK(socketProgressOpt(op, fd, addr, ptr, size, offset, 1));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t socketSend(int fd, void* ptr, int size) {
|
||||
static ncclResult_t socketSend(int fd, union socketAddress *addr, void* ptr, int size) {
|
||||
int offset = 0;
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, ptr, size, &offset));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, addr, ptr, size, &offset));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t socketRecv(int fd, void* ptr, int size) {
|
||||
static ncclResult_t socketRecv(int fd, union socketAddress *addr, void* ptr, int size) {
|
||||
int offset = 0;
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, addr, ptr, size, &offset));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
||||
@@ -37,4 +37,76 @@ static long log2i(long n) {
|
||||
return l;
|
||||
}
|
||||
|
||||
// Recyclable list that avoids frequent malloc/free
|
||||
template<typename T>
|
||||
struct ncclListElem {
|
||||
T data;
|
||||
struct ncclListElem* next;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class ncclRecyclableList {
|
||||
private:
|
||||
struct ncclListElem<T>* head;
|
||||
struct ncclListElem<T>* tail;
|
||||
struct ncclListElem<T>* cursor;
|
||||
int n;
|
||||
|
||||
public:
|
||||
ncclRecyclableList() {
|
||||
tail = cursor = head = NULL;
|
||||
n = 0;
|
||||
}
|
||||
|
||||
int count() const { return n; }
|
||||
|
||||
// Get a new element from the list and return pointer
|
||||
ncclResult_t getNewElem(T** dataOut) {
|
||||
if (tail != NULL) {
|
||||
*dataOut = &tail->data;
|
||||
memset(*dataOut, 0, sizeof(T));
|
||||
} else {
|
||||
NCCLCHECK(ncclCalloc(&tail, 1));
|
||||
*dataOut = &tail->data;
|
||||
cursor = head = tail;
|
||||
}
|
||||
if (tail->next == NULL) {
|
||||
NCCLCHECK(ncclCalloc(&tail->next, 1));
|
||||
}
|
||||
tail = tail->next;
|
||||
n += 1;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
T* begin() {
|
||||
if (head == NULL || head == tail) return NULL;
|
||||
cursor = head->next;
|
||||
return &head->data;
|
||||
}
|
||||
|
||||
// Get next element from the list during an iteration
|
||||
T* getNext() {
|
||||
// tail always points to the next element to be enqueued
|
||||
// hence does not contain valid data
|
||||
if (cursor == NULL || cursor == tail) return NULL;
|
||||
T* rv = &cursor->data;
|
||||
cursor = cursor->next;
|
||||
return rv;
|
||||
}
|
||||
|
||||
// Recycle the list without freeing the space
|
||||
void recycle() {
|
||||
tail = cursor = head;
|
||||
n = 0;
|
||||
}
|
||||
|
||||
~ncclRecyclableList() {
|
||||
while (head != NULL) {
|
||||
struct ncclListElem<T>* temp = head;
|
||||
head = head->next;
|
||||
free(temp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -92,21 +92,17 @@ ncclResult_t initNetPlugin(ncclNet_t** net, ncclCollNet_t** collnet) {
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
ncclNet_t* extNet = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
|
||||
if (extNet == NULL) {
|
||||
*net = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
|
||||
if (*net == NULL) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_PLUGIN_SYMBOL) " symbol.");
|
||||
} else if (initNet(extNet) == ncclSuccess) {
|
||||
*net = extNet;
|
||||
// Check for CollNet
|
||||
ncclCollNet_t* extCollNet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
|
||||
if (extCollNet == NULL) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
|
||||
} else if (initCollNet(extCollNet) == ncclSuccess) {
|
||||
*collnet = extCollNet;
|
||||
}
|
||||
if (netPluginLib != NULL) dlclose(netPluginLib);
|
||||
return ncclSuccess;
|
||||
}
|
||||
if (netPluginLib != NULL) dlclose(netPluginLib);
|
||||
// Check for CollNet
|
||||
*collnet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
|
||||
if (*collnet == NULL) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -114,13 +110,27 @@ ncclResult_t initNet() {
|
||||
// Always initialize bootstrap network
|
||||
NCCLCHECK(bootstrapNetInit());
|
||||
|
||||
NCCLCHECK(initNetPlugin(&ncclNet, &ncclCollNet));
|
||||
if (ncclNet != NULL) return ncclSuccess;
|
||||
if (initNet(&ncclNetIb) == ncclSuccess) {
|
||||
ncclNet = &ncclNetIb;
|
||||
} else {
|
||||
NCCLCHECK(initNet(&ncclNetSocket));
|
||||
ncclNet = &ncclNetSocket;
|
||||
// Initialize main communication network
|
||||
ncclNet_t* nets[3] = { NULL, &ncclNetIb, &ncclNetSocket };
|
||||
ncclCollNet_t* collNets[3] = { NULL, NULL, NULL };
|
||||
NCCLCHECK(initNetPlugin(nets+0, collNets+0));
|
||||
char* netName = getenv("NCCL_NET");
|
||||
|
||||
for (int i=0; i<3; i++) {
|
||||
if (nets[i] == NULL) continue;
|
||||
if (netName && strcmp(netName, nets[i]->name) != 0) continue;
|
||||
// net plugin is already initialized
|
||||
if (initNet(nets[i]) != ncclSuccess) continue;
|
||||
ncclNet = nets[i];
|
||||
if (collNets[i] && initCollNet(collNets[i]) == ncclSuccess) {
|
||||
ncclCollNet = collNets[i];
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (ncclNet == NULL) {
|
||||
WARN("Error: network %s not found.", netName ? netName : "");
|
||||
return ncclInvalidUsage;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -291,6 +301,10 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
return ncclSuccess;
|
||||
free(comm->connectSend);
|
||||
free(comm->connectRecv);
|
||||
for (int peer=0; peer<comm->nRanks; peer++) {
|
||||
delete comm->p2pSends[peer];
|
||||
delete comm->p2pRecvs[peer];
|
||||
}
|
||||
free(comm->p2pSends);
|
||||
free(comm->p2pRecvs);
|
||||
free(comm->asyncOps);
|
||||
@@ -298,20 +312,18 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
#ifdef ENABLE_PROFILING
|
||||
struct ncclProf* prof = (struct ncclProf*)malloc(sizeof(struct ncclProf));
|
||||
CUDACHECK(hipMemcpy(prof, comm->hostDevComm.devProf, sizeof(struct ncclProf), hipMemcpyDeviceToHost));
|
||||
uint64_t wait_cycle = 0, wait_recv_cycle = 0;
|
||||
uint64_t wait_cycle = 0;
|
||||
for (int chan=0; chan<comm->nChannels; chan++) {
|
||||
wait_cycle += prof->wait_cycle[chan];
|
||||
wait_recv_cycle += prof->wait_recv_cycle[chan];
|
||||
}
|
||||
#define VEGA_GPU_RTC_FREQUENCY 2.5E7
|
||||
if (comm->rank == 0) {
|
||||
INFO(NCCL_INIT, "# %4s %6s %6s %6s %6s %6s %7s %6s %6s %6s %6s %6s", "Rank", "total", " wait", "w_recv", "send", "rcRdS", "dRcRdCS", "dRcCS", "dRc", "cS", "rc", "rcCS");
|
||||
INFO(NCCL_INIT, "# %4s %6s %6s %6s %6s %6s %7s %6s %6s %6s %6s %6s", "", "(s)", "(s)", "(s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)");
|
||||
INFO(NCCL_INIT, "# %4s %6s %6s %6s %6s %7s %6s %6s %6s %6s %6s", "Rank", "total", " wait", "send", "rcRdS", "dRcRdCS", "dRcCS", "dRc", "cS", "rc", "rcCS");
|
||||
INFO(NCCL_INIT, "# %4s %6s %6s %6s %6s %7s %6s %6s %6s %6s %6s", "", "(s)", "(s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)", "(GB/s)");
|
||||
}
|
||||
INFO(NCCL_INIT, "# %4d %6.4f %6.4f %6.4f %6.2f %6.2f %7.2f %6.2f %6.2f %6.2f %6.2f %6.2f",
|
||||
INFO(NCCL_INIT, "# %4d %6.4f %6.4f %6.2f %6.2f %7.2f %6.2f %6.2f %6.2f %6.2f %6.2f",
|
||||
comm->rank, (double)prof->total_cycle/VEGA_GPU_RTC_FREQUENCY/comm->nChannels,
|
||||
(double)wait_cycle/VEGA_GPU_RTC_FREQUENCY/comm->nChannels,
|
||||
(double)wait_recv_cycle/VEGA_GPU_RTC_FREQUENCY/comm->nChannels,
|
||||
(prof->send_cycle) ? (double)prof->send_byte*comm->nChannels/((double)prof->send_cycle/VEGA_GPU_RTC_FREQUENCY*1.0E9) : 0,
|
||||
(prof->recvReduceSend_cycle) ? (double)prof->recvReduceSend_byte*comm->nChannels/((double)prof->recvReduceSend_cycle/VEGA_GPU_RTC_FREQUENCY*1.0E9) : 0,
|
||||
(prof->directRecvReduceCopySend_cycle) ? (double)prof->directRecvReduceCopySend_byte*comm->nChannels/((double)prof->directRecvReduceCopySend_cycle/VEGA_GPU_RTC_FREQUENCY*1.0E9) : 0,
|
||||
@@ -348,8 +360,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
if (comm->bootstrap)
|
||||
NCCLCHECK(bootstrapClose(comm->bootstrap));
|
||||
|
||||
CUDACHECK(hipFree(comm->hostDevComm.channels));
|
||||
CUDACHECK(hipFree(comm->devComm));
|
||||
CUDACHECK(hipFree((ncclDevCommAndChannels*)comm->devComm));
|
||||
|
||||
for (int channel=0; channel<MAXCHANNELS; channel++)
|
||||
NCCLCHECK(freeChannel(comm->channels+channel, comm->nRanks));
|
||||
@@ -387,6 +398,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
|
||||
RCCL_PARAM(CliqueIgnoreTopo, "CLIQUE_IGNORE_TOPO", 0);
|
||||
RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0);
|
||||
NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2);
|
||||
|
||||
static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
if (ndev < 1) {
|
||||
@@ -427,7 +439,7 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
|
||||
NCCLCHECK(ncclCudaHostCalloc((uint32_t**)&comm->abortFlag, 1));
|
||||
comm->hostDevComm.abortFlag = comm->abortFlag;
|
||||
STORE(comm->abortFlag, 0);
|
||||
*comm->abortFlag = 0;
|
||||
|
||||
comm->collOpCount = 0;
|
||||
comm->p2pOpCount = 0x8000;
|
||||
@@ -452,9 +464,15 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
NCCLCHECK(ncclCalloc(&comm->asyncOps, NCCL_MAX_OPS));
|
||||
comm->asyncOpCount = 0;
|
||||
comm->asyncTotalSize = 0;
|
||||
comm->channelSize = ncclParamAggChannelSize();
|
||||
comm->asyncAllocMode = ncclComm::ROUND_ROBIN;
|
||||
char* str = getenv("NCCL_AGG_ALLOC_MODE");
|
||||
if (str) INFO(NCCL_ENV, "NCCL_AGG_ALLOC_MODE set by environment to %s", str);
|
||||
if (str && strcmp(str, "SHORTEST_QUEUE") == 0) {
|
||||
comm->asyncAllocMode = ncclComm::SHORTEST_QUEUE;
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclCalloc(&comm->enqueueInfo, 1));
|
||||
comm->enqueueInfo->comm = comm;
|
||||
NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm));
|
||||
comm->lastSetupNode = NULL;
|
||||
comm->lastCudaGraphId = -1;
|
||||
|
||||
@@ -477,10 +495,14 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
}
|
||||
|
||||
static ncclResult_t devCommSetup(ncclComm_t comm) {
|
||||
ncclDevCommAndChannels *devCommAndChans;
|
||||
NCCLCHECK(ncclCudaCalloc(&devCommAndChans, 1));
|
||||
comm->devComm = &devCommAndChans->comm;
|
||||
comm->hostDevComm.channels = devCommAndChans->channels;
|
||||
|
||||
// Duplicate the channels on the device
|
||||
int nChannels = std::max(comm->nChannels, comm->p2pnChannels);
|
||||
NCCLCHECK(ncclCudaCalloc(&comm->hostDevComm.channels, std::max(comm->nChannels, comm->p2pnChannels)));
|
||||
NCCLCHECK(ncclCudaMemcpy(comm->hostDevComm.channels, comm->channels, std::max(comm->nChannels, comm->p2pnChannels)));
|
||||
NCCLCHECK(ncclCudaMemcpy(comm->hostDevComm.channels, comm->channels, nChannels));
|
||||
|
||||
// Copy userRanks and peers
|
||||
for (int r=0; r<comm->nChannels; r++) {
|
||||
@@ -488,7 +510,6 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
|
||||
}
|
||||
|
||||
// Duplicate the dev comm on the device
|
||||
NCCLCHECK(ncclCudaCalloc(&comm->devComm, 1));
|
||||
NCCLCHECK(ncclCudaMemcpy(comm->devComm, &comm->hostDevComm, 1));
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -534,23 +555,23 @@ static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank,
|
||||
NCCLCHECK(initChannel(comm, channelId));
|
||||
|
||||
struct ncclRing* ring = &comm->channels[channelId].ring;
|
||||
// Reorganize ranks to start with rank.
|
||||
int shift;
|
||||
for (shift = 0; shift<nranks; shift++) {
|
||||
if (ringRanks[shift] == rank) {
|
||||
break;
|
||||
}
|
||||
// Find our ring-distance from rank zero and reorganize ranks to start with rank.
|
||||
int ixZero=0, ixRank=0;
|
||||
for (int i=0; i < nranks; i++) {
|
||||
if (ringRanks[i] == 0) ixZero = i;
|
||||
if (ringRanks[i] == rank) ixRank = i;
|
||||
}
|
||||
ring->index = (ixRank-ixZero + nranks)%nranks;
|
||||
for (int i=0; i<nranks; i++) {
|
||||
ring->userRanks[i] = ringRanks[(i+shift)%nranks];
|
||||
ring->userRanks[i] = ringRanks[(i+ixRank)%nranks];
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
void* waitForNonNullPtr(void* p) {
|
||||
volatile void** ptr = (volatile void**) p;
|
||||
while (LOAD(ptr) == NULL) sched_yield();
|
||||
return (void*)(LOAD(ptr));
|
||||
while (*ptr == NULL) sched_yield();
|
||||
return (void*)*ptr;
|
||||
}
|
||||
|
||||
ncclResult_t initParams(struct ncclComm* comm) {
|
||||
@@ -564,7 +585,7 @@ ncclResult_t initParams(struct ncclComm* comm) {
|
||||
}
|
||||
|
||||
// Allocate/Set Intra Process Structures and set CG options
|
||||
ncclResult_t ncclCommSetIntra(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
|
||||
ncclResult_t ncclCommSetIntraProc(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
|
||||
comm->intraRank = rank;
|
||||
comm->intraRanks = ranks;
|
||||
comm->intraPhase = 0;
|
||||
@@ -694,37 +715,45 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
}
|
||||
|
||||
// Compute intra ranks and minimum CUDA Compute capabilities of intra-node GPUs and all GPUs
|
||||
int intraRank0 = -1, intraRank = -1, intraRanks = 0;
|
||||
int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0;
|
||||
int intraNodeRank0 = -1, intraNodeRank = -1, intraNodeRanks = 0;
|
||||
int myCompCap = allGather1Data[rank].cudaCompCap;
|
||||
int minCompCap = myCompCap, maxCompCap = myCompCap;
|
||||
uint64_t otherHostHash;
|
||||
int tmpNnodes = 1;
|
||||
int intraNodeGlobalRanks[256];
|
||||
for (int i = 0; i < nranks; i++) {
|
||||
if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) {
|
||||
// Rank is on same node
|
||||
if (intraNodeRanks == 0) intraNodeRank0 = i;
|
||||
if (i == rank) intraNodeRank = intraNodeRanks;
|
||||
intraNodeGlobalRanks[intraNodeRanks++] = i;
|
||||
if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) {
|
||||
if (intraRanks == 0) intraRank0 = i;
|
||||
if (i == rank) intraRank = intraRanks;
|
||||
intraRanks++;
|
||||
}
|
||||
} else { // Determine whether number of nodes is 2 (for use in tree pattern determination)
|
||||
if (tmpNnodes == 1) {
|
||||
otherHostHash = allGather1Data[i].peerInfo.hostHash;
|
||||
tmpNnodes = 2;
|
||||
} else if (tmpNnodes == 2 && otherHostHash != allGather1Data[i].peerInfo.hostHash) {
|
||||
tmpNnodes = 3;
|
||||
// Rank is in same process
|
||||
if (intraProcRanks == 0) intraProcRank0 = i;
|
||||
if (i == rank) intraProcRank = intraProcRanks;
|
||||
intraProcRanks++;
|
||||
}
|
||||
}
|
||||
minCompCap = std::min(allGather1Data[i].cudaCompCap, minCompCap);
|
||||
maxCompCap = std::max(allGather1Data[i].cudaCompCap, maxCompCap);
|
||||
}
|
||||
TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
|
||||
if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) {
|
||||
WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
|
||||
TRACE(NCCL_INIT,"hostHash[%d] %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraNodeRank, intraNodeRanks, intraNodeRank0);
|
||||
TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.pidHash, intraProcRank, intraProcRanks, intraProcRank0);
|
||||
if (intraProcRank == -1 || intraProcRank0 == -1 || allGather1Data[intraProcRank0].comm == NULL) {
|
||||
WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
|
||||
intraProcRank, intraProcRanks, intraProcRank0);
|
||||
return ncclInternalError;
|
||||
}
|
||||
struct ncclComm* intraRank0Comm = allGather1Data[intraRank0].comm;
|
||||
if (intraNodeRank == -1 || intraNodeRank0 == -1 || intraNodeRanks == 0) {
|
||||
WARN("Failed to determine intra node ranks rank %d hostHash %lx pidHash %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
|
||||
intraNodeRank, intraNodeRanks, intraNodeRank0);
|
||||
return ncclInternalError;
|
||||
}
|
||||
struct ncclComm* intraProcRank0Comm = allGather1Data[intraProcRank0].comm;
|
||||
uint64_t intraNodeRank0pidHash = allGather1Data[intraNodeRank0].peerInfo.pidHash;
|
||||
|
||||
// AllGather1 - end
|
||||
|
||||
@@ -756,7 +785,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
|
||||
struct ncclTopoGraph treeGraph;
|
||||
treeGraph.id = 1;
|
||||
treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
|
||||
treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE;
|
||||
treeGraph.crossNic = ncclParamCrossNic();
|
||||
treeGraph.collNet = 0;
|
||||
treeGraph.minChannels = comm->topo->nodes[NET].count != 0 ? 1 : ringGraph.nChannels;
|
||||
@@ -802,7 +831,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
}
|
||||
if (hasPeerAccess)
|
||||
{
|
||||
if (intraRanks == nranks)
|
||||
if (intraProcRanks == nranks)
|
||||
cliqueMode = CliqueManager::CLIQUE_SINGLE_PROCESS;
|
||||
else
|
||||
cliqueMode = CliqueManager::CLIQUE_SINGLE_NODE;
|
||||
@@ -825,8 +854,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
}
|
||||
|
||||
// Determine local CollNet support before all-gather
|
||||
if (tmpNnodes > 1 && ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
|
||||
if (intraRanks > 8) {
|
||||
if (ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
|
||||
if (intraNodeRanks > 8) {
|
||||
if (comm->collNetSupport == 1) WARN("CollNet currently only supports up to 8 GPUs per node");
|
||||
comm->collNetSupport = 0;
|
||||
}
|
||||
@@ -993,9 +1022,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
|
||||
// Set Affinity to a CPU local the our GPU, so that all memory we allocate
|
||||
// on the host is local.
|
||||
NCCLCHECK(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity));
|
||||
cpu_set_t affinitySave;
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
NCCLCHECK(ncclTopoSetAffinity(comm->topo, comm->rank));
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
}
|
||||
ncclResult_t ret;
|
||||
|
||||
NCCLCHECK(computeBuffSizes(comm));
|
||||
@@ -1046,10 +1078,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
struct ncclChannel* channel = comm->channels+c;
|
||||
for (int h=0; h<nHeads; h++) {
|
||||
const int head = heads[h];
|
||||
if (ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetRecv) != 1)
|
||||
collNetSetupFail = 1;
|
||||
else if (ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetSend) != 1)
|
||||
collNetSetupFail = 1;
|
||||
collNetSetupFail = ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetRecv);
|
||||
if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetSend);
|
||||
}
|
||||
// Verify CollNet setup across ranks after trying the first channel
|
||||
if (c == 0) {
|
||||
@@ -1115,14 +1145,17 @@ collnet_cleanup:
|
||||
free(nvbPeers);
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, intraRank0Comm));
|
||||
NCCLCHECK(ncclCommSetIntraProc(comm, intraProcRank, intraProcRanks, intraProcRank0Comm));
|
||||
|
||||
/* Local intra-node barrier */
|
||||
NCCLCHECK(bootstrapBarrier(comm->bootstrap, intraNodeGlobalRanks, (int)intraNodeRank0pidHash, intraNodeRank, intraNodeRanks));
|
||||
|
||||
if (comm->nNodes) NCCLCHECK(ncclProxyCreate(comm));
|
||||
|
||||
// We should have allocated all buffers, collective fifos, ... we can
|
||||
// restore the affinity.
|
||||
affinity_restore:
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
if (ret != ncclSuccess) return ret;
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
|
||||
@@ -1270,7 +1303,7 @@ ncclResult_t ncclCommAbort(ncclComm_t comm) {
|
||||
return ncclSuccess;
|
||||
|
||||
// Ask anything that might still be running on the device to quit
|
||||
STORE(comm->abortFlag, 1);
|
||||
*comm->abortFlag = 1;
|
||||
|
||||
// do not destroy comm because kernel maybe still running
|
||||
// return commDestroy(comm);
|
||||
|
||||
@@ -145,7 +145,8 @@ typedef enum { ncclSum = 0,
|
||||
ncclProd = 1,
|
||||
ncclMax = 2,
|
||||
ncclMin = 3,
|
||||
ncclNumOps = 4 } ncclRedOp_t;
|
||||
ncclAvg = 4,
|
||||
ncclNumOps = 5 } ncclRedOp_t;
|
||||
|
||||
/*! @brief Data types */
|
||||
typedef enum { ncclInt8 = 0, ncclChar = 0,
|
||||
|
||||
@@ -42,9 +42,19 @@ static ncclResult_t allocateArgs(struct ncclComm* comm, struct ncclProxyArgs** a
|
||||
state->poolReturned = NULL;
|
||||
pthread_mutex_unlock(&state->poolMutex);
|
||||
} else {
|
||||
// Allocate a new pool of elements
|
||||
// Allocate a new pool of elements. Make sure we allocate the memory close
|
||||
// to the network thread
|
||||
struct ncclProxyPool* newPool;
|
||||
cpu_set_t affinitySave;
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
}
|
||||
NCCLCHECK(ncclCalloc(&newPool, 1));
|
||||
if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
}
|
||||
|
||||
struct ncclProxyArgs* newElems = newPool->elems;
|
||||
// Chain newly allocated elements
|
||||
for (int i=0; i<PROXYARGS_ALLOCATE_SIZE; i++) {
|
||||
@@ -420,11 +430,11 @@ void* persistentThread(void *comm_) {
|
||||
|
||||
struct ncclProxyArgs** opsPtr = &state->ops;
|
||||
while (1) {
|
||||
if (LOAD(comm->abortFlag)) {
|
||||
if (*comm->abortFlag) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
while (LOAD(opsPtr) == NULL) {
|
||||
while (*opsPtr == NULL) {
|
||||
if (state->stop) {
|
||||
// No more commands to process and proxy has been requested to stop
|
||||
return NULL;
|
||||
|
||||
@@ -155,14 +155,13 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
|
||||
extern struct ncclTransport collNetTransport;
|
||||
|
||||
// All ranks must participate in collNetSetup call
|
||||
// return: 0 - unsupported, 1 - supported
|
||||
// We do not NCCLCHECK this call because we would fall back to P2P network in case CollNet setup fails
|
||||
int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type) {
|
||||
int fail = 1;
|
||||
int rank = comm->rank;
|
||||
int nranks = comm->nRanks;
|
||||
int nMasters = comm->nNodes;
|
||||
int rankInCollNet = -1;
|
||||
int supported = 0;
|
||||
int isMaster = (rank == masterRank) ? 1 : 0;
|
||||
struct {
|
||||
int collNetRank;
|
||||
@@ -172,9 +171,9 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
// check if we can connect to collnet, whose root is the nranks-th rank
|
||||
struct ncclPeerInfo *myInfo = comm->peerInfo+rank, *peerInfo = comm->peerInfo+nranks;
|
||||
peerInfo->rank = nranks;
|
||||
int ret = 1;
|
||||
int support = 1;
|
||||
if (isMaster) {
|
||||
NCCLCHECK(collNetTransport.canConnect(&ret, comm->topo, collNetGraph, myInfo, peerInfo));
|
||||
NCCLCHECK(collNetTransport.canConnect(&support, comm->topo, collNetGraph, myInfo, peerInfo));
|
||||
}
|
||||
|
||||
// send master receives connect info from peer recv master
|
||||
@@ -192,7 +191,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
conn->transportComm = transportComm;
|
||||
// setup
|
||||
struct ncclConnect myConnect;
|
||||
if (isMaster && ret > 0) {
|
||||
if (isMaster && support) {
|
||||
NCCLCHECK(transportComm->setup(comm, collNetGraph, myInfo, peerInfo, &myConnect, conn, collNetGraphChannelId, type));
|
||||
}
|
||||
// prepare connect handles
|
||||
@@ -222,7 +221,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
if (isMaster) memcpy(masterConnects+rankInCollNet, &(sendrecvExchange.connect), sizeof(struct ncclConnect));
|
||||
}
|
||||
// connect
|
||||
if (isMaster && ret > 0) {
|
||||
if (isMaster && support) {
|
||||
NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
|
||||
struct ncclPeer* devRoot = channel->devPeers+nranks;
|
||||
struct ncclConnector* devConn = (type == collNetRecv) ? devRoot->recv+type : devRoot->send+type;
|
||||
@@ -235,13 +234,11 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, masterPeer, collNetGraph->id, &sendrecvExchange, sizeof(sendrecvExchange)), res, cleanup);
|
||||
TRACE(NCCL_INIT, "CollNet [recv] : rank %d collNetRank %d collNetNranks %d sent connect to rank %d", rank, rankInCollNet, nMasters, masterPeer);
|
||||
}
|
||||
if (ret > 0) {
|
||||
supported = 1;
|
||||
}
|
||||
if (support) fail = 0;
|
||||
cleanup:
|
||||
if (allConnects != NULL) free(allConnects);
|
||||
if (masterConnects != NULL) free(masterConnects);
|
||||
return supported;
|
||||
return fail;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail) {
|
||||
|
||||
@@ -470,11 +470,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
|
||||
int buffSlot = (sub->base+sub->posted)%NCCL_STEPS;
|
||||
char* ptr;
|
||||
int sharedBuffSlot = sub->posted%NCCL_STEPS;
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, 0, &ptr));
|
||||
args->sharedBuff[sharedBuffSlot] = ptr;
|
||||
int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
|
||||
reqFifo[group][buffSlot].recvBuff = args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize;
|
||||
TRACE(NCCL_NET, "recvProxy [%lu/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
|
||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &ptr));
|
||||
reqFifo[group][buffSlot].recvBuff = ptr;
|
||||
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
|
||||
sub->posted += args->sliceSteps;
|
||||
args->idle = 0;
|
||||
continue;
|
||||
@@ -489,9 +488,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
|
||||
TRACE(NCCL_NET, "recvProxy [%lu/%d/%d] received, size %d", sub->received, group, buffSlot, totalSize);
|
||||
sub->received += args->sliceSteps;
|
||||
if (reqFifo[group][buffSlot].size > 0 && p == NCCL_PROTO_SIMPLE && resources->useGdr) {
|
||||
int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
|
||||
char* recvAddress = (char*)args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize;
|
||||
NCCLCHECK(collNetIflush(resources->collNetComm, recvAddress, totalSize, mhandle, sub->requests+buffSlot));
|
||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||
char* groupRecvAddress;
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
|
||||
NCCLCHECK(collNetIflush(resources->collNetComm, groupRecvAddress, totalSize, mhandle, sub->requests+buffSlot));
|
||||
} else {
|
||||
for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps;
|
||||
}
|
||||
@@ -516,8 +516,10 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
|
||||
int group = s / COLLNET_GROUP_NSUBS;
|
||||
int buffSlot = (sub->base + sub->transmitted)%NCCL_STEPS;
|
||||
int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
|
||||
int slotSize = sub->connector->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
|
||||
char* ptr = args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*slotSize + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
|
||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||
char* groupRecvAddress;
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
|
||||
char* ptr = groupRecvAddress + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
|
||||
if (p == NCCL_PROTO_SIMPLE) {
|
||||
volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo;
|
||||
ptrsFifo[buffSlot] = ptr;
|
||||
|
||||
@@ -207,7 +207,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
|
||||
}
|
||||
line[1023] = '\0';
|
||||
char addrline[SOCKET_NAME_MAXLEN+1];
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr.sa, addrline));
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr, addrline));
|
||||
}
|
||||
pthread_mutex_unlock(&ncclIbLock);
|
||||
}
|
||||
@@ -262,10 +262,12 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) {
|
||||
|
||||
#define MAX_REQUESTS NCCL_NET_MAX_REQUESTS
|
||||
|
||||
#define NCCL_IB_MAX_QPS 128
|
||||
|
||||
struct ncclIbQpInfo {
|
||||
uint32_t lid;
|
||||
uint8_t ib_port;
|
||||
uint32_t qpn;
|
||||
uint32_t qpn[NCCL_IB_MAX_QPS];
|
||||
|
||||
// For RoCE
|
||||
uint64_t spn;
|
||||
@@ -287,6 +289,7 @@ struct ncclIbRequest {
|
||||
struct ncclIbVerbs* verbs;
|
||||
int events;
|
||||
int size;
|
||||
union socketAddress *addr;
|
||||
};
|
||||
|
||||
struct ncclIbVerbs {
|
||||
@@ -315,8 +318,10 @@ struct ncclIbSendComm {
|
||||
struct ncclIbSendFifo fifo[MAX_REQUESTS];
|
||||
uint32_t fifoHead;
|
||||
int fd;
|
||||
union socketAddress addr;
|
||||
int ready;
|
||||
struct ibv_qp* qp;
|
||||
struct ibv_qp* qps[NCCL_IB_MAX_QPS];
|
||||
int nqps;
|
||||
struct ibv_mr* fifoMr;
|
||||
};
|
||||
// The SendFifo needs to be 32-byte aligned and each element needs
|
||||
@@ -347,16 +352,20 @@ struct ncclIbRecvComm {
|
||||
struct ncclIbVerbs verbs;
|
||||
struct ncclIbRemFifo remFifo;
|
||||
int fd;
|
||||
union socketAddress addr;
|
||||
int ready;
|
||||
struct ibv_qp* qp;
|
||||
struct ibv_qp* qps[NCCL_IB_MAX_QPS];
|
||||
int nqps;
|
||||
struct ncclIbGpuFlush gpuFlush;
|
||||
};
|
||||
static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
|
||||
|
||||
NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1);
|
||||
|
||||
ncclResult_t ncclIbInitVerbs(ibv_context* ctx, struct ncclIbVerbs* verbs) {
|
||||
NCCLCHECK(wrap_ibv_alloc_pd(&verbs->pd, ctx));
|
||||
// Recv requests can generate 2 completions (one for the post FIFO, one for the Recv).
|
||||
NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS, NULL, NULL, 0));
|
||||
NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -389,12 +398,12 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int acce
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclIbRtrQp(ibv_qp* qp, struct ncclIbQpInfo* info) {
|
||||
ncclResult_t ncclIbRtrQp(ibv_qp* qp, uint32_t qpn, struct ncclIbQpInfo* info) {
|
||||
struct ibv_qp_attr qpAttr;
|
||||
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
|
||||
qpAttr.qp_state = IBV_QPS_RTR;
|
||||
qpAttr.path_mtu = info->mtu;
|
||||
qpAttr.dest_qp_num = info->qpn;
|
||||
qpAttr.dest_qp_num = qpn;
|
||||
qpAttr.rq_psn = 0;
|
||||
qpAttr.max_dest_rd_atomic = 1;
|
||||
qpAttr.min_rnr_timer = 12;
|
||||
@@ -451,18 +460,23 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
|
||||
NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
|
||||
*sendComm = comm;
|
||||
|
||||
comm->addr = handle->connectAddr;
|
||||
|
||||
// IB Setup
|
||||
ibv_context* ctx = ncclIbDevs[dev].context;
|
||||
NCCLCHECK(ncclIbInitVerbs(ctx, &comm->verbs));
|
||||
uint8_t ib_port = ncclIbDevs[dev].port;
|
||||
NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, &comm->qp));
|
||||
comm->nqps = ncclParamIbQpsPerConn();
|
||||
for (int q=0; q<comm->nqps; q++) {
|
||||
NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q));
|
||||
}
|
||||
|
||||
// Send my QP Info to receiver through the socket. Hope this won't block.
|
||||
struct ibv_port_attr portAttr;
|
||||
NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr));
|
||||
struct ncclIbQpInfo qpInfo;
|
||||
qpInfo.ib_port = ib_port;
|
||||
qpInfo.qpn = comm->qp->qp_num;
|
||||
for (int q=0; q<comm->nqps; q++) qpInfo.qpn[q] = comm->qps[q]->qp_num;
|
||||
qpInfo.mtu = portAttr.active_mtu;
|
||||
|
||||
// Prepare my fifo
|
||||
@@ -473,16 +487,18 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
|
||||
// RoCE support
|
||||
qpInfo.lid = portAttr.lid;
|
||||
if (qpInfo.lid) { // IB
|
||||
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn, qpInfo.mtu, qpInfo.lid);
|
||||
for (int q=0; q<comm->nqps; q++)
|
||||
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid);
|
||||
} else { // RoCE
|
||||
union ibv_gid gid;
|
||||
NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &gid));
|
||||
qpInfo.spn = gid.global.subnet_prefix;
|
||||
qpInfo.iid = gid.global.interface_id;
|
||||
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn, qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
|
||||
for (int q=0; q<comm->nqps; q++)
|
||||
INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid);
|
||||
}
|
||||
|
||||
NCCLCHECK(socketSend(comm->fd, &qpInfo, sizeof(qpInfo)));
|
||||
NCCLCHECK(socketSend(comm->fd, &comm->addr, &qpInfo, sizeof(qpInfo)));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -493,11 +509,10 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
|
||||
struct ncclIbRecvComm* rComm;
|
||||
NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm)));
|
||||
|
||||
struct sockaddr_in sockaddr;
|
||||
socklen_t socklen = sizeof(struct sockaddr_in);
|
||||
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
|
||||
socklen_t socklen = sizeof(union socketAddress);
|
||||
SYSCHECKVAL(accept(lComm->fd, &rComm->addr.sa, &socklen), "accept", rComm->fd);
|
||||
struct ncclIbQpInfo remQpInfo;
|
||||
NCCLCHECK(socketRecv(rComm->fd, &remQpInfo, sizeof(remQpInfo)));
|
||||
NCCLCHECK(socketRecv(rComm->fd, &rComm->addr, &remQpInfo, sizeof(remQpInfo)));
|
||||
|
||||
// IB setup
|
||||
ibv_context* ctx = ncclIbDevs[lComm->dev].context;
|
||||
@@ -509,15 +524,20 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
|
||||
|
||||
// QP Creation
|
||||
NCCLCHECK(ncclIbInitVerbs(ctx, &rComm->verbs));
|
||||
NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, &rComm->qp));
|
||||
rComm->nqps = ncclParamIbQpsPerConn();
|
||||
for (int q=0; q<rComm->nqps; q++) {
|
||||
NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, rComm->qps+q));
|
||||
}
|
||||
|
||||
// Adjust the MTU
|
||||
remQpInfo.mtu = (enum ibv_mtu)std::min(remQpInfo.mtu, portAttr.active_mtu);
|
||||
|
||||
// Setup QP
|
||||
struct ibv_qp* qp = rComm->qp;
|
||||
NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
|
||||
NCCLCHECK(ncclIbRtsQp(qp));
|
||||
for (int q=0; q<rComm->nqps; q++) {
|
||||
struct ibv_qp* qp = rComm->qps[q];
|
||||
NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
|
||||
NCCLCHECK(ncclIbRtsQp(qp));
|
||||
}
|
||||
|
||||
// Retain remote fifo info and prepare my RDMA ops
|
||||
rComm->remFifo.rkey = remQpInfo.fifoRkey;
|
||||
@@ -535,29 +555,26 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
|
||||
rComm->gpuFlush.sge.length = 1;
|
||||
rComm->gpuFlush.sge.lkey = rComm->gpuFlush.hostMr->lkey;
|
||||
NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rComm->gpuFlush.qp));
|
||||
struct ncclIbQpInfo localQpInfo = {
|
||||
.lid=portAttr.lid,
|
||||
.ib_port=ib_port,
|
||||
.qpn=rComm->gpuFlush.qp->qp_num,
|
||||
.spn=gid.global.subnet_prefix,
|
||||
.iid=gid.global.interface_id,
|
||||
.mtu=portAttr.active_mtu
|
||||
};
|
||||
NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, &localQpInfo));
|
||||
struct ncclIbQpInfo localQpInfo;
|
||||
localQpInfo.lid=portAttr.lid;
|
||||
localQpInfo.ib_port=ib_port;
|
||||
localQpInfo.spn=gid.global.subnet_prefix;
|
||||
localQpInfo.iid=gid.global.interface_id;
|
||||
localQpInfo.mtu=portAttr.active_mtu;
|
||||
NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo));
|
||||
NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp));
|
||||
}
|
||||
|
||||
// Fill Handle
|
||||
struct ncclIbQpInfo qpInfo = {
|
||||
.lid=portAttr.lid,
|
||||
.ib_port=ib_port,
|
||||
.qpn=qp->qp_num,
|
||||
.spn=gid.global.subnet_prefix,
|
||||
.iid=gid.global.interface_id,
|
||||
.mtu=remQpInfo.mtu
|
||||
};
|
||||
struct ncclIbQpInfo qpInfo;
|
||||
qpInfo.lid=portAttr.lid;
|
||||
qpInfo.ib_port=ib_port;
|
||||
for (int q=0; q<rComm->nqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num;
|
||||
qpInfo.spn=gid.global.subnet_prefix;
|
||||
qpInfo.iid=gid.global.interface_id;
|
||||
qpInfo.mtu=remQpInfo.mtu;
|
||||
|
||||
NCCLCHECK(socketSend(rComm->fd, &qpInfo, sizeof(qpInfo)));
|
||||
NCCLCHECK(socketSend(rComm->fd, &rComm->addr, &qpInfo, sizeof(qpInfo)));
|
||||
*recvComm = rComm;
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -571,6 +588,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest**
|
||||
r->verbs = verbs;
|
||||
r->events = 1;
|
||||
r->size = -1;
|
||||
r->addr = NULL;
|
||||
*req = r;
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -586,19 +604,21 @@ ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) {
|
||||
|
||||
ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
|
||||
struct ncclIbQpInfo remQpInfo;
|
||||
struct ibv_qp* qp = comm->qp;
|
||||
|
||||
// Do not block on this receive, return if not ready.
|
||||
int bytes = 0;
|
||||
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes));
|
||||
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &remQpInfo, sizeof(remQpInfo), &bytes));
|
||||
if (bytes == 0) return ncclSuccess; // Try again later
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &remQpInfo, sizeof(remQpInfo), &bytes));
|
||||
|
||||
NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
|
||||
NCCLCHECK(ncclIbRtsQp(qp));
|
||||
for (int q=0; q<comm->nqps; q++) {
|
||||
struct ibv_qp* qp = comm->qps[q];
|
||||
NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
|
||||
NCCLCHECK(ncclIbRtsQp(qp));
|
||||
}
|
||||
comm->ready = 1;
|
||||
// Block until this is done. It *should* not block indefinitely.
|
||||
NCCLCHECK(socketSend(comm->fd, &comm->ready, sizeof(int)));
|
||||
NCCLCHECK(socketSend(comm->fd, &comm->addr, &comm->ready, sizeof(int)));
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -606,9 +626,9 @@ ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
|
||||
ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) {
|
||||
// Do not block on this receive, return if not ready.
|
||||
int bytes = 0;
|
||||
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes));
|
||||
NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &comm->ready, sizeof(int), &bytes));
|
||||
if (bytes == 0) return ncclSuccess; // Try again later
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->addr, &comm->ready, sizeof(int), &bytes));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -648,25 +668,20 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
|
||||
// Wait for the receiver to have posted the corresponding receive
|
||||
volatile struct ncclIbSendFifo* slot = comm->fifo + (comm->fifoHead%MAX_REQUESTS);
|
||||
volatile uint32_t * readyPtr = &slot->ready;
|
||||
if (LOAD(readyPtr) == 0) { *request = NULL; return ncclSuccess; }
|
||||
if (*readyPtr == 0) { *request = NULL; return ncclSuccess; }
|
||||
|
||||
struct ncclIbRequest* req;
|
||||
NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
|
||||
req->size = size;
|
||||
req->addr = &comm->addr;
|
||||
|
||||
struct ibv_send_wr wr[2];
|
||||
memset(&wr[0], 0, sizeof(wr[0]));
|
||||
wr[0].wr_id = (uint64_t)req;
|
||||
|
||||
struct ibv_sge sge;
|
||||
if (size == 0) {
|
||||
wr[0].sg_list = NULL;
|
||||
wr[0].num_sge = 0;
|
||||
} else {
|
||||
sge.addr=(uintptr_t)data; sge.length=(unsigned int)size; sge.lkey=mr->lkey;
|
||||
wr[0].sg_list = &sge;
|
||||
wr[0].num_sge = 1;
|
||||
}
|
||||
sge.addr=(uintptr_t)data; sge.lkey=mr->lkey;
|
||||
|
||||
#if USE_RDMA_WRITE == 0
|
||||
wr[0].opcode = IBV_WR_SEND;
|
||||
wr[0].send_flags = IBV_SEND_SIGNALED;
|
||||
@@ -674,9 +689,10 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
|
||||
__sync_synchronize(); // order the readyPtr load against rkey load below
|
||||
// Sanity checks to catch user collective call count/size mismatches
|
||||
// plus any potential programming errors
|
||||
if (size > LOAD(&slot->size) || LOAD(&slot->size) <= 0 || LOAD(&slot->addr) == 0 || LOAD(&slot->rkey) == 0 || LOAD(&slot->seq) != comm->fifoHead) {
|
||||
WARN("NET/IB : collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
|
||||
size, LOAD(&slot->size), LOAD(&slot->addr), LOAD(&slot->rkey), LOAD(&slot->seq), comm->fifoHead);
|
||||
if (size > slot->size || slot->size < 0 || slot->addr == 0 || slot->rkey == 0 || slot->seq != comm->fifoHead) {
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
WARN("NET/IB : peer %s collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
|
||||
socketToString(req->addr, line), size, slot->size, slot->addr, slot->rkey, slot->seq, comm->fifoHead);
|
||||
return ncclInternalError;
|
||||
}
|
||||
wr[0].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
@@ -688,9 +704,9 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
|
||||
#endif
|
||||
// We must clear slot->ready, but reset other fields to aid
|
||||
// debugging and sanity checks
|
||||
STORE(&slot->ready, 0);
|
||||
STORE(&slot->addr, 0);
|
||||
STORE(&slot->rkey, 0); STORE(&slot->size, 0); STORE(&slot->seq, 0);
|
||||
slot->ready = 0;
|
||||
slot->addr = 0ULL;
|
||||
slot->rkey = slot->size = slot->seq = 0;
|
||||
comm->fifoHead++;
|
||||
|
||||
|
||||
@@ -713,8 +729,26 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
|
||||
}
|
||||
#endif
|
||||
|
||||
struct ibv_send_wr* bad_wr;
|
||||
NCCLCHECK(wrap_ibv_post_send(comm->qp, wr, &bad_wr));
|
||||
int chunkSize = std::max(8, DIVUP(size, comm->nqps));
|
||||
|
||||
int offset = 0;
|
||||
for (int q=0; q<comm->nqps; q++) {
|
||||
int length = std::min(size-offset, chunkSize);
|
||||
if (length <= 0) {
|
||||
wr[0].sg_list = NULL;
|
||||
wr[0].num_sge = 0;
|
||||
} else {
|
||||
sge.length = length;
|
||||
wr[0].sg_list = &sge;
|
||||
wr[0].num_sge = 1;
|
||||
}
|
||||
struct ibv_send_wr* bad_wr;
|
||||
NCCLCHECK(wrap_ibv_post_send(comm->qps[q], wr, &bad_wr));
|
||||
offset += chunkSize;
|
||||
sge.addr += chunkSize;
|
||||
wr[0].wr.rdma.remote_addr += chunkSize;
|
||||
}
|
||||
req->events = comm->nqps;
|
||||
|
||||
*request = req;
|
||||
return ncclSuccess;
|
||||
@@ -767,7 +801,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t
|
||||
}
|
||||
|
||||
struct ibv_send_wr* bad_wr;
|
||||
NCCLCHECK(wrap_ibv_post_send(comm->qp, &wr, &bad_wr));
|
||||
NCCLCHECK(wrap_ibv_post_send(comm->qps[0], &wr, &bad_wr));
|
||||
comm->remFifo.tail++;
|
||||
|
||||
return ncclSuccess;
|
||||
@@ -783,23 +817,22 @@ ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, void* mhandle, vo
|
||||
struct ncclIbRequest* req;
|
||||
NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
|
||||
req->size = size;
|
||||
req->addr = &comm->addr;
|
||||
|
||||
struct ibv_recv_wr wr;
|
||||
memset(&wr, 0, sizeof(wr));
|
||||
wr.wr_id = (uint64_t)req;
|
||||
|
||||
struct ibv_sge sge;
|
||||
if (size == 0) {
|
||||
wr.sg_list = NULL;
|
||||
wr.num_sge = 0;
|
||||
} else {
|
||||
sge.addr=(uintptr_t)data; sge.length=(unsigned int)size; sge.lkey=mr->lkey;
|
||||
wr.sg_list = &sge;
|
||||
wr.num_sge = 1;
|
||||
}
|
||||
wr.sg_list = NULL;
|
||||
wr.num_sge = 0;
|
||||
|
||||
for (int q=0; q<comm->nqps; q++) {
|
||||
struct ibv_qp* qp = comm->qps[q];
|
||||
struct ibv_recv_wr* bad_wr;
|
||||
NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr));
|
||||
}
|
||||
req->events = comm->nqps;
|
||||
|
||||
struct ibv_recv_wr* bad_wr;
|
||||
NCCLCHECK(wrap_ibv_post_recv(comm->qp, &wr, &bad_wr));
|
||||
*request = req;
|
||||
|
||||
// Post to FIFO to notify sender
|
||||
@@ -813,6 +846,7 @@ ncclResult_t ncclIbIflush(void* recvComm, void* data, int size, void* mhandle, v
|
||||
|
||||
struct ncclIbRequest* req;
|
||||
NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
|
||||
req->addr = &comm->addr;
|
||||
struct ibv_mr* mr = (struct ibv_mr*)mhandle;
|
||||
|
||||
struct ibv_send_wr wr;
|
||||
@@ -853,7 +887,9 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
|
||||
for (int w=0; w<wrDone; w++) {
|
||||
struct ibv_wc *wc = wcs+w;
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("NET/IB : Got completion with error %d, opcode %d, len %d, vendor err %d", wc->status, wc->opcode, wc->byte_len, wc->vendor_err);
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d",
|
||||
socketToString(r->addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err);
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
@@ -863,7 +899,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
|
||||
doneReq->size = wc->byte_len;
|
||||
#if USE_RDMA_WRITE
|
||||
} else if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
|
||||
doneReq->size = wc->imm_data;
|
||||
if (doneReq->size == -1)
|
||||
doneReq->size = wc->imm_data;
|
||||
else
|
||||
doneReq->size += wc->imm_data;
|
||||
#endif
|
||||
}
|
||||
doneReq->events--;
|
||||
@@ -876,7 +915,8 @@ ncclResult_t ncclIbCloseSend(void* sendComm) {
|
||||
struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm;
|
||||
if (comm) {
|
||||
close(comm->fd);
|
||||
if (comm->qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qp));
|
||||
for (int q=0; q<comm->nqps; q++)
|
||||
if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
|
||||
if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr));
|
||||
NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs));
|
||||
free(comm);
|
||||
@@ -888,7 +928,8 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) {
|
||||
struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm;
|
||||
if (comm) {
|
||||
close(comm->fd);
|
||||
if (comm->qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qp));
|
||||
for (int q=0; q<comm->nqps; q++)
|
||||
if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
|
||||
if (comm->gpuFlush.enabled) {
|
||||
if (comm->gpuFlush.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->gpuFlush.qp));
|
||||
if (comm->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr));
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -57,7 +56,7 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) {
|
||||
memcpy(&ncclSocketDevs[i].addr, addrs+i, sizeof(union socketAddress));
|
||||
NCCLCHECK(ncclSocketGetPciPath(ncclSocketDevs[i].devName, &ncclSocketDevs[i].pciPath));
|
||||
snprintf(line+strlen(line), MAX_LINE_LEN-strlen(line), " [%d]%s:%s", i, names+i*MAX_IF_NAME_SIZE,
|
||||
socketToString(&addrs[i].sa, addrline));
|
||||
socketToString(&addrs[i], addrline));
|
||||
}
|
||||
line[MAX_LINE_LEN] = '\0';
|
||||
INFO(NCCL_INIT|NCCL_NET,"NET/Socket : Using%s", line);
|
||||
@@ -130,6 +129,7 @@ struct ncclSocketTask {
|
||||
void* data;
|
||||
int size;
|
||||
int fd;
|
||||
union socketAddress *addr;
|
||||
int offset;
|
||||
int used;
|
||||
ncclResult_t result;
|
||||
@@ -140,6 +140,7 @@ struct ncclSocketRequest {
|
||||
void* data;
|
||||
int size;
|
||||
int ctrlFd;
|
||||
union socketAddress *addr;
|
||||
int offset;
|
||||
int used;
|
||||
struct ncclSocketComm* comm;
|
||||
@@ -171,6 +172,7 @@ struct ncclSocketListenComm {
|
||||
|
||||
struct ncclSocketComm {
|
||||
int ctrlFd;
|
||||
union socketAddress addr;
|
||||
int fds[MAX_SOCKETS];
|
||||
int nSocks;
|
||||
int nThreads;
|
||||
@@ -196,7 +198,7 @@ void* persistentSocketThread(void *args_) {
|
||||
for (int j=0; j<nSocksPerThread; j++) {
|
||||
struct ncclSocketTask* r = myQueue->tasks+i+j;
|
||||
if (r != NULL && r->used == 1 && r->offset < r->size) {
|
||||
r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset);
|
||||
r->result = socketProgress(r->op, r->fd, r->addr, r->data, r->size, &r->offset);
|
||||
if (r->result != ncclSuccess) {
|
||||
WARN("NET/Socket : socket progress error");
|
||||
return NULL;
|
||||
@@ -209,12 +211,12 @@ void* persistentSocketThread(void *args_) {
|
||||
}
|
||||
if (idle) {
|
||||
pthread_mutex_lock(&resource->threadLock);
|
||||
while (mark == myQueue->next && LOAD(state) != stop) { // no new tasks, wait
|
||||
while (mark == myQueue->next && *state != stop) { // no new tasks, wait
|
||||
pthread_cond_wait(&resource->threadCond, &resource->threadLock);
|
||||
}
|
||||
pthread_mutex_unlock(&resource->threadLock);
|
||||
}
|
||||
if (LOAD(state) == stop) return NULL;
|
||||
if (*state == stop) return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,11 +314,12 @@ ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
|
||||
for (int i=0; i<comm->nSocks+1; i++) {
|
||||
int tmpFd, offset=0;
|
||||
NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset));
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &handle->connectAddr, &i, sizeof(int), &offset));
|
||||
if (i == comm->nSocks) comm->ctrlFd = tmpFd;
|
||||
else comm->fds[i] = tmpFd;
|
||||
}
|
||||
*sendComm = comm;
|
||||
comm->addr = handle->connectAddr;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -328,10 +331,9 @@ ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
|
||||
rComm->nThreads = lComm->nThreads;
|
||||
for (int i=0; i<rComm->nSocks+1; i++) {
|
||||
int tmpFd, sendSockIdx, offset=0;
|
||||
struct sockaddr_in sockaddr;
|
||||
socklen_t socklen = sizeof(struct sockaddr_in);
|
||||
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd);
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset));
|
||||
socklen_t socklen = sizeof(union socketAddress);
|
||||
SYSCHECKVAL(accept(lComm->fd, &rComm->addr.sa, &socklen), "accept", tmpFd);
|
||||
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &rComm->addr, &sendSockIdx, sizeof(int), &offset));
|
||||
if (sendSockIdx == rComm->nSocks) rComm->ctrlFd = tmpFd;
|
||||
else rComm->fds[sendSockIdx] = tmpFd;
|
||||
}
|
||||
@@ -347,6 +349,7 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* dat
|
||||
r->data = data;
|
||||
r->size = size;
|
||||
r->ctrlFd = comm->ctrlFd;
|
||||
r->addr = &comm->addr;
|
||||
r->used = 1;
|
||||
r->comm = comm;
|
||||
r->nSubs = 0;
|
||||
@@ -381,6 +384,7 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
|
||||
r->data = data;
|
||||
r->size = size;
|
||||
r->fd = comm->fds[comm->nextFd];
|
||||
r->addr = &comm->addr;
|
||||
r->offset = 0;
|
||||
r->result = ncclSuccess;
|
||||
comm->nextFd = (comm->nextFd + 1) % comm->nSocks;
|
||||
@@ -407,16 +411,17 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
|
||||
if (r->used == 1) { /* try to send/recv size */
|
||||
int data = r->size;
|
||||
int offset = 0;
|
||||
NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset));
|
||||
NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->addr, &data, sizeof(int), &offset));
|
||||
|
||||
if (offset == 0) return ncclSuccess; /* Not ready -- retry later */
|
||||
|
||||
// Not sure we could ever receive less than 4 bytes, but just in case ...
|
||||
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset));
|
||||
if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, r->addr, &data, sizeof(int), &offset));
|
||||
|
||||
// Check size is less or equal to the size provided by the user
|
||||
if (r->op == NCCL_SOCKET_RECV && data > r->size) {
|
||||
WARN("NET/Socket : message truncated : receiving %d bytes instead of %d", data, r->size);
|
||||
char line[SOCKET_NAME_MAXLEN+1];
|
||||
WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d", socketToString(r->addr, line), data, r->size);
|
||||
return ncclInternalError;
|
||||
}
|
||||
r->size = data;
|
||||
@@ -454,7 +459,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
|
||||
}
|
||||
} else { // progress request using main thread
|
||||
if (r->offset < r->size) {
|
||||
NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->data, r->size, &r->offset));
|
||||
NCCLCHECK(socketProgress(r->op, r->ctrlFd, r->addr, r->data, r->size, &r->offset));
|
||||
}
|
||||
if (r->offset == r->size) {
|
||||
if (size) *size = r->size;
|
||||
|
||||
@@ -62,8 +62,8 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
|
||||
}
|
||||
#endif
|
||||
|
||||
// Rule out different nodes
|
||||
if (info1->hostHash != info2->hostHash) {
|
||||
// Rule out different nodes / isolated containers
|
||||
if (info1->hostHash != info2->hostHash || info1->shmDev != info2->shmDev) {
|
||||
*ret = 0;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ NCCL_PARAM(CrossNic, "CROSS_NIC", 2);
|
||||
NCCL_PARAM(CollNetEnable, "COLLNET_ENABLE", 0);
|
||||
NCCL_PARAM(GraphDumpFileRank, "GRAPH_DUMP_FILE_RANK", 0);
|
||||
RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0);
|
||||
NCCL_PARAM(CollNetNodeThreshold, "COLLNET_NODE_THRESHOLD", 2);
|
||||
|
||||
thread_local int ncclDebugNoWarn = 0;
|
||||
ncclCollNet_t* ncclCollNet = NULL;
|
||||
@@ -120,7 +121,7 @@ void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *file
|
||||
ncclResult_t ncclTopoGetSystem(const char* xmlTopoFile, struct ncclTopoSystem** system) {
|
||||
struct ncclXml* xml;
|
||||
NCCLCHECK(ncclCalloc(&xml, 1));
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml));
|
||||
NCCLCHECK(ncclTopoGetXmlFromFile(xmlTopoFile, xml, 0));
|
||||
NCCLCHECK(ncclTopoGetSystemFromXml(xml, system));
|
||||
free(xml);
|
||||
return ncclSuccess;
|
||||
@@ -141,17 +142,197 @@ ncclResult_t bootstrapAllGather(struct ncclComm* comm, struct allGather1Data_t *
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
void initCollNet() {
|
||||
if (ncclParamCollNetEnable() == 1 && ncclCollNet == 0)
|
||||
ncclCollNet = (ncclCollNet_t*)0x12345678;
|
||||
}
|
||||
|
||||
ncclResult_t initChannel(struct ncclComm* comm, int channelid) {
|
||||
struct ncclChannel* channel = comm->channels+channelid;
|
||||
if (channel->id != -1) return ncclSuccess;
|
||||
channel->id = channelid;
|
||||
|
||||
// Ring index to user rank table.
|
||||
//NCCLCHECK(ncclCudaCalloc(&channel->ring.devUserRanks, comm->nRanks));
|
||||
NCCLCHECK(ncclCalloc(&channel->ring.userRanks, comm->nRanks));
|
||||
|
||||
// Communication structures with peers.
|
||||
//NCCLCHECK(ncclCudaCalloc(&channel->devPeers, comm->nRanks+1)); // The extra one rank is for collnet root (i.e. network)
|
||||
NCCLCHECK(ncclCalloc(&channel->peers, comm->nRanks+1));
|
||||
for (size_t i=0; i<comm->nRanks+1; ++i) {
|
||||
for (int b=0; b<NCCL_MAX_CONNS; b++) {
|
||||
channel->peers[i].send[b].comm = comm;
|
||||
channel->peers[i].recv[b].comm = comm;
|
||||
}
|
||||
}
|
||||
|
||||
// Per-channel operation list.
|
||||
//NCCLCHECK(ncclCudaHostCalloc(&channel->workFifo, NCCL_MAX_OPS));
|
||||
//if (ncclGdrCopy != NULL && ncclParamGdrCopyFifoEnable() == 1) {
|
||||
// GDRCOPY support
|
||||
// We allocate a workFifo in GDR mapped CUDA memory
|
||||
// But we still allocate the Host workFifo so that we
|
||||
// can copy the work elements to CUDA memory on kernel launch
|
||||
//NCCLCHECK(ncclGdrCudaCalloc(&channel->workFifoGdr, &channel->workFifoDev, NCCL_MAX_OPS, &channel->gdrMemDesc));
|
||||
//} else {
|
||||
// The device workFifo is the Host one
|
||||
//channel->workFifoDev = channel->workFifo;
|
||||
//}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank, int nranks, int* ringRanks) {
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks);
|
||||
NCCLCHECK(initChannel(comm, channelId));
|
||||
|
||||
struct ncclRing* ring = &comm->channels[channelId].ring;
|
||||
// Find our ring-distance from rank zero and reorganize ranks to start with rank.
|
||||
int ixZero=0, ixRank=0;
|
||||
for (int i=0; i < nranks; i++) {
|
||||
if (ringRanks[i] == 0) ixZero = i;
|
||||
if (ringRanks[i] == rank) ixRank = i;
|
||||
}
|
||||
ring->index = (ixRank-ixZero + nranks)%nranks;
|
||||
for (int i=0; i<nranks; i++) {
|
||||
ring->userRanks[i] = ringRanks[(i+ixRank)%nranks];
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t connectedByXGMI(int* ret, struct ncclTopoSystem* system, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) {
|
||||
*ret = 0;
|
||||
if (info1->hostHash != info2->hostHash) return ncclSuccess;
|
||||
int g1, g2;
|
||||
NCCLCHECK(ncclTopoRankToIndex(system, info1->rank, &g1));
|
||||
NCCLCHECK(ncclTopoRankToIndex(system, info2->rank, &g2));
|
||||
if (system->nodes[GPU].nodes[g1].paths[GPU][g2].type == PATH_NVL) *ret = 1;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
template <int type>
|
||||
static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex) {
|
||||
struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank;
|
||||
struct ncclPeerInfo* peerInfo = comm->peerInfo+peer;
|
||||
struct ncclConnector* connector = (type == 1) ? comm->channels[channelId].peers[peer].send + connIndex :
|
||||
comm->channels[channelId].peers[peer].recv + connIndex;
|
||||
|
||||
// handle intra-node network connections
|
||||
int n1 = -1, n2 = -1;
|
||||
if (connIndex == NCCL_CONN_IDX_P2P_NET) {
|
||||
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, comm->rank, graph, channelId, (type == 1) ? 1 : 0, &n1));
|
||||
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2));
|
||||
}
|
||||
|
||||
int xgmi;
|
||||
NCCLCHECK(connectedByXGMI(&xgmi, comm->topo, myInfo, peerInfo));
|
||||
for (int t=0; t<NTRANSPORTS; t++) {
|
||||
if (graph == NULL && connIndex == NCCL_CONN_IDX_P2P_NET && (t == TRANSPORT_SHM || (!xgmi && t == TRANSPORT_P2P))) continue;
|
||||
if (graph && n1 >= 0 && n2 >= 0 && t != TRANSPORT_NET) continue;
|
||||
struct ncclTransport *transport = ncclTransports+t;
|
||||
struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv;
|
||||
int ret = 0;
|
||||
NCCLCHECK(transport->canConnect(&ret, comm->topo, graph, myInfo, peerInfo));
|
||||
if (ret) {
|
||||
connector->transportComm = transportComm;
|
||||
NCCLCHECK(transportComm->setup(comm, graph, myInfo, peerInfo, connect, connector, channelId, connIndex));
|
||||
return ncclSuccess;
|
||||
}
|
||||
}
|
||||
WARN("No transport found !");
|
||||
return ncclInternalError;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex) {
|
||||
TRACE(NCCL_INIT, "nsend %d nrecv %d", nsend, nrecv);
|
||||
uint32_t mask = 1 << channel->id;
|
||||
for (int i=0; i<nrecv; i++) {
|
||||
int peer = peerRecv[i];
|
||||
if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].recv[connIndex].connected) continue;
|
||||
comm->connectRecv[peer+comm->nRanks*connIndex] |= mask;
|
||||
}
|
||||
for (int i=0; i<nsend; i++) {
|
||||
int peer = peerSend[i];
|
||||
if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].send[connIndex].connected) continue;
|
||||
comm->connectSend[peer+comm->nRanks*connIndex] |= mask;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex) {
|
||||
// Stream used during transport setup; need for P2P pre-connect + CUDA Graph
|
||||
//hipStream_t transportSetupStream;
|
||||
//CUDACHECK(hipStreamCreateWithFlags(&transportSetupStream, hipStreamNonBlocking));
|
||||
|
||||
struct ncclConnect data[2*MAXCHANNELS];
|
||||
for (int i=1; i<comm->nRanks; i++) {
|
||||
int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0);
|
||||
int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks;
|
||||
int sendPeer = (comm->rank + i) % comm->nRanks;
|
||||
uint32_t recvMask = comm->connectRecv[recvPeer+comm->nRanks*connIndex];
|
||||
uint32_t sendMask = comm->connectSend[sendPeer+comm->nRanks*connIndex];
|
||||
|
||||
struct ncclConnect* recvData = data;
|
||||
int sendChannels = 0, recvChannels = 0;
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (recvMask & (1<<c)) {
|
||||
NCCLCHECK(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex));
|
||||
}
|
||||
}
|
||||
struct ncclConnect* sendData = recvData+recvChannels;
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (sendMask & (1<<c)) {
|
||||
NCCLCHECK(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex));
|
||||
}
|
||||
}
|
||||
|
||||
if (sendPeer == recvPeer) {
|
||||
if (recvChannels+sendChannels) {
|
||||
//NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)));
|
||||
//NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)));
|
||||
sendData = data;
|
||||
recvData = data+sendChannels;
|
||||
}
|
||||
} else {
|
||||
//if (recvChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels));
|
||||
//if (sendChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels));
|
||||
//if (sendChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels));
|
||||
//if (recvChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels));
|
||||
}
|
||||
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (sendMask & (1<<c)) {
|
||||
struct ncclConnector* conn = comm->channels[c].peers[sendPeer].send + connIndex;
|
||||
//NCCLCHECK(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn));
|
||||
conn->connected = 1;
|
||||
//CUDACHECK(hipMemcpyAsync(comm->channels[c].devPeers[sendPeer].send+connIndex, conn, sizeof(struct ncclConnector), hipMemcpyHostToDevice, transportSetupStream));
|
||||
}
|
||||
}
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (recvMask & (1<<c)) {
|
||||
struct ncclConnector* conn = comm->channels[c].peers[recvPeer].recv + connIndex;
|
||||
//NCCLCHECK(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn));
|
||||
conn->connected = 1;
|
||||
//CUDACHECK(hipMemcpyAsync(comm->channels[c].devPeers[recvPeer].recv+connIndex, conn, sizeof(struct ncclConnector), hipMemcpyHostToDevice, transportSetupStream));
|
||||
}
|
||||
}
|
||||
comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0;
|
||||
}
|
||||
//CUDACHECK(hipStreamSynchronize(transportSetupStream));
|
||||
//CUDACHECK(hipStreamDestroy(transportSetupStream));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
extern struct ncclTransport collNetTransport;
|
||||
|
||||
// All ranks must participate in collNetSetup call
|
||||
// return: 0 - unsupported, 1 - supported
|
||||
// We do not NCCLCHECK this call because we would fall back to P2P network in case CollNet setup fails
|
||||
int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type) {
|
||||
int fail = 1;
|
||||
int rank = comm->rank;
|
||||
int nranks = comm->nRanks;
|
||||
int nMasters = comm->nNodes;
|
||||
int rankInCollNet = -1;
|
||||
int supported = 0;
|
||||
int isMaster = (rank == masterRank) ? 1 : 0;
|
||||
struct {
|
||||
int collNetRank;
|
||||
@@ -161,9 +342,9 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
// check if we can connect to collnet, whose root is the nranks-th rank
|
||||
struct ncclPeerInfo *myInfo = comm->peerInfo+rank, *peerInfo = comm->peerInfo+nranks;
|
||||
peerInfo->rank = nranks;
|
||||
int ret = 1;
|
||||
int support = 1;
|
||||
if (isMaster) {
|
||||
NCCLCHECK(collNetTransport.canConnect(&ret, comm->topo, collNetGraph, myInfo, peerInfo));
|
||||
NCCLCHECK(collNetTransport.canConnect(&support, comm->topo, collNetGraph, myInfo, peerInfo));
|
||||
}
|
||||
|
||||
// send master receives connect info from peer recv master
|
||||
@@ -181,7 +362,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
conn->transportComm = transportComm;
|
||||
// setup
|
||||
struct ncclConnect myConnect;
|
||||
if (isMaster && ret > 0) {
|
||||
if (isMaster && support) {
|
||||
NCCLCHECK(transportComm->setup(comm, collNetGraph, myInfo, peerInfo, &myConnect, conn, collNetGraphChannelId, type));
|
||||
}
|
||||
// prepare connect handles
|
||||
@@ -211,7 +392,7 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
//if (isMaster) memcpy(masterConnects+rankInCollNet, &(sendrecvExchange.connect), sizeof(struct ncclConnect));
|
||||
}
|
||||
// connect
|
||||
if (isMaster && ret > 0) {
|
||||
if (isMaster && support) {
|
||||
//NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
|
||||
struct ncclPeer* devRoot = channel->devPeers+nranks;
|
||||
struct ncclConnector* devConn = (type == collNetRecv) ? devRoot->recv+type : devRoot->send+type;
|
||||
@@ -224,18 +405,11 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN
|
||||
//NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, masterPeer, collNetGraph->id, &sendrecvExchange, sizeof(sendrecvExchange)), res, cleanup);
|
||||
TRACE(NCCL_INIT, "CollNet [recv] : rank %d collNetRank %d collNetNranks %d sent connect to rank %d", rank, rankInCollNet, nMasters, masterPeer);
|
||||
}
|
||||
if (ret > 0) {
|
||||
supported = 1;
|
||||
}
|
||||
if (support) fail = 0;
|
||||
cleanup:
|
||||
if (allConnects != NULL) free(allConnects);
|
||||
if (masterConnects != NULL) free(masterConnects);
|
||||
return supported;
|
||||
}
|
||||
|
||||
void initCollNet() {
|
||||
if (ncclParamCollNetEnable() == 1 && ncclCollNet == 0)
|
||||
ncclCollNet = (ncclCollNet_t*)0x12345678;
|
||||
return fail;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail) {
|
||||
@@ -311,37 +485,45 @@ ncclResult_t initTransportsRank_1(struct ncclComm* comm, struct allGather1Data_t
|
||||
}
|
||||
|
||||
// Compute intra ranks and minimum CUDA Compute capabilities of intra-node GPUs and all GPUs
|
||||
int intraRank0 = -1, intraRank = -1, intraRanks = 0;
|
||||
int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0;
|
||||
int intraNodeRank0 = -1, intraNodeRank = -1, intraNodeRanks = 0;
|
||||
int myCompCap = allGather1Data[rank].cudaCompCap;
|
||||
int minCompCap = myCompCap, maxCompCap = myCompCap;
|
||||
uint64_t otherHostHash;
|
||||
int tmpNnodes = 1;
|
||||
int intraNodeGlobalRanks[256];
|
||||
for (int i = 0; i < nranks; i++) {
|
||||
if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) {
|
||||
// Rank is on same node
|
||||
if (intraNodeRanks == 0) intraNodeRank0 = i;
|
||||
if (i == rank) intraNodeRank = intraNodeRanks;
|
||||
intraNodeGlobalRanks[intraNodeRanks++] = i;
|
||||
if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) {
|
||||
if (intraRanks == 0) intraRank0 = i;
|
||||
if (i == rank) intraRank = intraRanks;
|
||||
intraRanks++;
|
||||
}
|
||||
} else { // Determine whether number of nodes is 2 (for use in tree pattern determination)
|
||||
if (tmpNnodes == 1) {
|
||||
otherHostHash = allGather1Data[i].peerInfo.hostHash;
|
||||
tmpNnodes = 2;
|
||||
} else if (tmpNnodes == 2 && otherHostHash != allGather1Data[i].peerInfo.hostHash) {
|
||||
tmpNnodes = 3;
|
||||
// Rank is in same process
|
||||
if (intraProcRanks == 0) intraProcRank0 = i;
|
||||
if (i == rank) intraProcRank = intraProcRanks;
|
||||
intraProcRanks++;
|
||||
}
|
||||
}
|
||||
minCompCap = std::min(allGather1Data[i].cudaCompCap, minCompCap);
|
||||
maxCompCap = std::max(allGather1Data[i].cudaCompCap, maxCompCap);
|
||||
}
|
||||
TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
|
||||
if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) {
|
||||
WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
|
||||
TRACE(NCCL_INIT,"hostHash[%d] %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, intraNodeRank, intraNodeRanks, intraNodeRank0);
|
||||
TRACE(NCCL_INIT,"pidHash[%d] %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.pidHash, intraProcRank, intraProcRanks, intraProcRank0);
|
||||
if (intraProcRank == -1 || intraProcRank0 == -1 || allGather1Data[intraProcRank0].comm == NULL) {
|
||||
WARN("Failed to determine intra proc ranks rank %d hostHash %lx pidHash %lx intraProcRank %d intraProcRanks %d intraProcRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
|
||||
intraProcRank, intraProcRanks, intraProcRank0);
|
||||
return ncclInternalError;
|
||||
}
|
||||
struct ncclComm* intraRank0Comm = allGather1Data[intraRank0].comm;
|
||||
if (intraNodeRank == -1 || intraNodeRank0 == -1 || intraNodeRanks == 0) {
|
||||
WARN("Failed to determine intra node ranks rank %d hostHash %lx pidHash %lx intraNodeRank %d intraNodeRanks %d intraNodeRank0 %d",
|
||||
rank, allGather1Data[rank].peerInfo.hostHash, allGather1Data[rank].peerInfo.pidHash,
|
||||
intraNodeRank, intraNodeRanks, intraNodeRank0);
|
||||
return ncclInternalError;
|
||||
}
|
||||
struct ncclComm* intraProcRank0Comm = allGather1Data[intraProcRank0].comm;
|
||||
uint64_t intraNodeRank0pidHash = allGather1Data[intraNodeRank0].peerInfo.pidHash;
|
||||
|
||||
// AllGather1 - end
|
||||
|
||||
@@ -373,7 +555,7 @@ ncclResult_t initTransportsRank_1(struct ncclComm* comm, struct allGather1Data_t
|
||||
|
||||
//struct ncclTopoGraph treeGraph;
|
||||
treeGraph.id = 1;
|
||||
treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
|
||||
treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE;
|
||||
treeGraph.crossNic = ncclParamCrossNic();
|
||||
treeGraph.collNet = 0;
|
||||
treeGraph.minChannels = comm->topo->nodes[NET].count != 0 ? 1 : ringGraph.nChannels;
|
||||
@@ -441,9 +623,9 @@ ncclResult_t initTransportsRank_1(struct ncclComm* comm, struct allGather1Data_t
|
||||
NCCLCHECK(ncclTopoDumpGraphs(comm->topo, 3, graphs));
|
||||
}
|
||||
|
||||
// Determine CollNet support
|
||||
if (tmpNnodes > 1 && ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
|
||||
if (intraRanks > 8) {
|
||||
// Determine local CollNet support before all-gather
|
||||
if (ncclParamCollNetEnable() == 1 && collNetSupport() == 1 && collNetGraph.nChannels > 0) comm->collNetSupport = 1;
|
||||
if (intraNodeRanks > 8) {
|
||||
if (comm->collNetSupport == 1) WARN("CollNet currently only supports up to 8 GPUs per node");
|
||||
comm->collNetSupport = 0;
|
||||
}
|
||||
@@ -519,182 +701,6 @@ ncclResult_t initTransportsRank_1(struct ncclComm* comm, struct allGather1Data_t
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t initChannel(struct ncclComm* comm, int channelid) {
|
||||
struct ncclChannel* channel = comm->channels+channelid;
|
||||
if (channel->id != -1) return ncclSuccess;
|
||||
channel->id = channelid;
|
||||
|
||||
// Ring index to user rank table.
|
||||
//NCCLCHECK(ncclCudaCalloc(&channel->ring.devUserRanks, comm->nRanks));
|
||||
NCCLCHECK(ncclCalloc(&channel->ring.userRanks, comm->nRanks));
|
||||
|
||||
// Communication structures with peers.
|
||||
//NCCLCHECK(ncclCudaCalloc(&channel->devPeers, comm->nRanks+1)); // The extra one rank is for collnet root (i.e. network)
|
||||
NCCLCHECK(ncclCalloc(&channel->peers, comm->nRanks+1));
|
||||
for (size_t i=0; i<comm->nRanks+1; ++i) {
|
||||
for (int b=0; b<NCCL_MAX_CONNS; b++) {
|
||||
channel->peers[i].send[b].comm = comm;
|
||||
channel->peers[i].recv[b].comm = comm;
|
||||
}
|
||||
}
|
||||
|
||||
// Per-channel operation list.
|
||||
//NCCLCHECK(ncclCudaHostCalloc(&channel->workFifo, NCCL_MAX_OPS));
|
||||
//if (ncclGdrCopy != NULL && ncclParamGdrCopyFifoEnable() == 1) {
|
||||
// GDRCOPY support
|
||||
// We allocate a workFifo in GDR mapped CUDA memory
|
||||
// But we still allocate the Host workFifo so that we
|
||||
// can copy the work elements to CUDA memory on kernel launch
|
||||
//NCCLCHECK(ncclGdrCudaCalloc(&channel->workFifoGdr, &channel->workFifoDev, NCCL_MAX_OPS, &channel->gdrMemDesc));
|
||||
//} else {
|
||||
// The device workFifo is the Host one
|
||||
//channel->workFifoDev = channel->workFifo;
|
||||
//}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t setupChannel(struct ncclComm* comm, int channelId, int rank, int nranks, int* ringRanks) {
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks);
|
||||
NCCLCHECK(initChannel(comm, channelId));
|
||||
|
||||
struct ncclRing* ring = &comm->channels[channelId].ring;
|
||||
// Reorganize ranks to start with rank.
|
||||
int shift;
|
||||
for (shift = 0; shift<nranks; shift++) {
|
||||
if (ringRanks[shift] == rank) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int i=0; i<nranks; i++) {
|
||||
ring->userRanks[i] = ringRanks[(i+shift)%nranks];
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t connectedByXGMI(int* ret, struct ncclTopoSystem* system, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) {
|
||||
*ret = 0;
|
||||
if (info1->hostHash != info2->hostHash) return ncclSuccess;
|
||||
int g1, g2;
|
||||
NCCLCHECK(ncclTopoRankToIndex(system, info1->rank, &g1));
|
||||
NCCLCHECK(ncclTopoRankToIndex(system, info2->rank, &g2));
|
||||
if (system->nodes[GPU].nodes[g1].paths[GPU][g2].type == PATH_NVL) *ret = 1;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
template <int type>
|
||||
static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex) {
|
||||
struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank;
|
||||
struct ncclPeerInfo* peerInfo = comm->peerInfo+peer;
|
||||
struct ncclConnector* connector = (type == 1) ? comm->channels[channelId].peers[peer].send + connIndex :
|
||||
comm->channels[channelId].peers[peer].recv + connIndex;
|
||||
// handle intra-node network connections
|
||||
int n1 = -1, n2 = -1;
|
||||
if (connIndex == NCCL_CONN_IDX_P2P_NET) {
|
||||
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, comm->rank, graph, channelId, (type == 1) ? 1 : 0, &n1));
|
||||
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2));
|
||||
}
|
||||
|
||||
int xgmi;
|
||||
NCCLCHECK(connectedByXGMI(&xgmi, comm->topo, myInfo, peerInfo));
|
||||
for (int t=0; t<NTRANSPORTS; t++) {
|
||||
if (graph == NULL && connIndex == NCCL_CONN_IDX_P2P_NET && (t == TRANSPORT_SHM || (!xgmi && t == TRANSPORT_P2P))) continue;
|
||||
if (graph && n1 >= 0 && n2 >= 0 && t != TRANSPORT_NET) continue;
|
||||
struct ncclTransport *transport = ncclTransports+t;
|
||||
struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv;
|
||||
int ret = 0;
|
||||
NCCLCHECK(transport->canConnect(&ret, comm->topo, graph, myInfo, peerInfo));
|
||||
if (ret) {
|
||||
connector->transportComm = transportComm;
|
||||
NCCLCHECK(transportComm->setup(comm, graph, myInfo, peerInfo, connect, connector, channelId, connIndex));
|
||||
return ncclSuccess;
|
||||
}
|
||||
}
|
||||
WARN("No transport found !");
|
||||
return ncclInternalError;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex) {
|
||||
TRACE(NCCL_INIT, "nsend %d nrecv %d", nsend, nrecv);
|
||||
uint32_t mask = 1 << channel->id;
|
||||
for (int i=0; i<nrecv; i++) {
|
||||
int peer = peerRecv[i];
|
||||
if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].recv[connIndex].connected) continue;
|
||||
comm->connectRecv[peer] |= mask;
|
||||
}
|
||||
for (int i=0; i<nsend; i++) {
|
||||
int peer = peerSend[i];
|
||||
if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].send[connIndex].connected) continue;
|
||||
comm->connectSend[peer] |= mask;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex) {
|
||||
// Stream used during transport setup; need for P2P pre-connect + CUDA Graph
|
||||
//hipStream_t transportSetupStream;
|
||||
//CUDACHECK(hipStreamCreateWithFlags(&transportSetupStream, hipStreamNonBlocking));
|
||||
|
||||
struct ncclConnect data[2*MAXCHANNELS];
|
||||
for (int i=1; i<comm->nRanks; i++) {
|
||||
int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0);
|
||||
int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks;
|
||||
int sendPeer = (comm->rank + i) % comm->nRanks;
|
||||
uint32_t recvMask = comm->connectRecv[recvPeer];
|
||||
uint32_t sendMask = comm->connectSend[sendPeer];
|
||||
|
||||
struct ncclConnect* recvData = data;
|
||||
int sendChannels = 0, recvChannels = 0;
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (recvMask & (1<<c)) {
|
||||
NCCLCHECK(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex));
|
||||
}
|
||||
}
|
||||
struct ncclConnect* sendData = recvData+recvChannels;
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (sendMask & (1<<c)) {
|
||||
NCCLCHECK(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex));
|
||||
}
|
||||
}
|
||||
|
||||
if (sendPeer == recvPeer) {
|
||||
if (recvChannels+sendChannels) {
|
||||
//NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)));
|
||||
//NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)));
|
||||
sendData = data;
|
||||
recvData = data+sendChannels;
|
||||
}
|
||||
} else {
|
||||
//if (recvChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels));
|
||||
//if (sendChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels));
|
||||
//if (sendChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels));
|
||||
//if (recvChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels));
|
||||
}
|
||||
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (sendMask & (1<<c)) {
|
||||
struct ncclConnector* conn = comm->channels[c].peers[sendPeer].send + connIndex;
|
||||
//NCCLCHECK(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn));
|
||||
conn->connected = 1;
|
||||
//CUDACHECK(hipMemcpyAsync(comm->channels[c].devPeers[sendPeer].send+connIndex, conn, sizeof(struct ncclConnector), hipMemcpyHostToDevice, transportSetupStream));
|
||||
}
|
||||
}
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (recvMask & (1<<c)) {
|
||||
struct ncclConnector* conn = comm->channels[c].peers[recvPeer].recv + connIndex;
|
||||
//NCCLCHECK(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn));
|
||||
conn->connected = 1;
|
||||
//CUDACHECK(hipMemcpyAsync(comm->channels[c].devPeers[recvPeer].recv+connIndex, conn, sizeof(struct ncclConnector), hipMemcpyHostToDevice, transportSetupStream));
|
||||
}
|
||||
}
|
||||
comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0;
|
||||
}
|
||||
//CUDACHECK(hipStreamSynchronize(transportSetupStream));
|
||||
//CUDACHECK(hipStreamDestroy(transportSetupStream));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
||||
ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t *allGather3Data,
|
||||
struct ncclTopoGraph& treeGraph, struct ncclTopoGraph& ringGraph, struct ncclTopoGraph& collNetGraph) {
|
||||
int rank = comm->rank;
|
||||
@@ -758,6 +764,14 @@ ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t
|
||||
for (int i=0; i<comm->nChannels; i++) memcpy(comm->channels+comm->nChannels+i, comm->channels+nChannelsOrig+i, sizeof(struct ncclChannel));
|
||||
}
|
||||
|
||||
// Determine CollNet support after all-gather now that we know nNodes
|
||||
int collNetNodeThreshold = ncclParamCollNetNodeThreshold();
|
||||
if (comm->nNodes < collNetNodeThreshold) {
|
||||
if (comm->collNetSupport == 1)
|
||||
INFO(NCCL_INIT, "Communicator has %d nodes which is less than CollNet node threshold %d, disabling CollNet", comm->nNodes, collNetNodeThreshold);
|
||||
comm->collNetSupport = 0;
|
||||
}
|
||||
|
||||
int *rings;
|
||||
NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS));
|
||||
NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph, nc));
|
||||
@@ -784,9 +798,12 @@ ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t
|
||||
|
||||
// Set Affinity to a CPU local the our GPU, so that all memory we allocate
|
||||
// on the host is local.
|
||||
cpu_set_t affinitySave;
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
NCCLCHECK(ncclTopoSetAffinity(comm->topo, comm->rank));
|
||||
//NCCLCHECK(ncclTopoGetCpuAffinity(comm->topo, comm->rank, &comm->cpuAffinity));
|
||||
//cpu_set_t affinitySave;
|
||||
//if (CPU_COUNT(&comm->cpuAffinity)) {
|
||||
// sched_getaffinity(0, sizeof(cpu_set_t), &affinitySave);
|
||||
// sched_setaffinity(0, sizeof(cpu_set_t), &comm->cpuAffinity);
|
||||
//}
|
||||
ncclResult_t ret;
|
||||
|
||||
//NCCLCHECK(computeBuffSizes(comm));
|
||||
@@ -837,10 +854,8 @@ ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t
|
||||
struct ncclChannel* channel = comm->channels+c;
|
||||
for (int h=0; h<nHeads; h++) {
|
||||
const int head = heads[h];
|
||||
if (ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetRecv) != 1)
|
||||
collNetSetupFail = 1;
|
||||
else if (ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetSend) != 1)
|
||||
collNetSetupFail = 1;
|
||||
collNetSetupFail = ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetRecv);
|
||||
if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, &collNetGraph, channel, head, head, h, collNetSend);
|
||||
}
|
||||
// Verify CollNet setup across ranks after trying the first channel
|
||||
if (c == 0) {
|
||||
|
||||
Ссылка в новой задаче
Block a user