Use bit reversal based mapping for multi-node (#1572)
[ROCm/rccl commit: 85eb1f16bc]
Цей коміт міститься в:
@@ -163,7 +163,7 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
struct ncclDevWorkP2p* work = &works[workIx];
|
||||
size_t bytes = isSend ? work->sendBytes : work->recvBytes;
|
||||
int nParts = isSend ? work->nSendChannels : work->nRecvChannels;
|
||||
int part = ncclP2pChannelToPart(work->nP2pChannels, work->channelBase, ncclShmem.channelId, ncclShmem.comm.p2pnChannelsPerPeer);
|
||||
int part = ncclP2pChannelToPart(work->nP2pChannels, work->channelBase, ncclShmem.channelId, ncclShmem.comm.p2pnChannelsPerPeer, ncclShmem.comm.nNodes);
|
||||
hasWork = (part < nParts);
|
||||
if (nParts != 0) {
|
||||
size_t partBeg, partEnd;
|
||||
|
||||
@@ -890,7 +890,7 @@ static ncclResult_t addP2pToPlan(
|
||||
|
||||
if (!selfSend) {
|
||||
for (int part=0; part < nChannelsMax; part++) {
|
||||
int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax);
|
||||
int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax, comm->nNodes);
|
||||
struct ncclChannelPeer** channelPeers = comm->channels[channelId].peers;
|
||||
for (int dir=0; dir <= 1; dir++) {
|
||||
int peerRank = dir ? sendRank : recvRank;
|
||||
@@ -1006,7 +1006,7 @@ static ncclResult_t addP2pToPlan(
|
||||
|
||||
nChannelsMax = std::max(nChannels[0], nChannels[1]);
|
||||
for (int part=0; part < nChannelsMax; part++) {
|
||||
int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, comm->p2pnChannelsPerPeer);
|
||||
int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, comm->p2pnChannelsPerPeer, comm->nNodes);
|
||||
plan->channelMask.masks[channelId/64] |= uint64_t(1)<<(channelId%64);
|
||||
// Add batch first.
|
||||
int funcIdx = ncclDevFuncId_P2p();
|
||||
@@ -2133,7 +2133,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) {
|
||||
}
|
||||
uint8_t base = ncclP2pChannelBaseForRound(comm, round);
|
||||
for (int c=0; c < comm->p2pnChannelsPerPeer; c++) {
|
||||
int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, c, comm->p2pnChannelsPerPeer);
|
||||
int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, c, comm->p2pnChannelsPerPeer, comm->nNodes);
|
||||
if (isSendNotRecv) {
|
||||
if (comm->channels[channelId].peers[peer]->send[1].connected == 0) { // P2P uses only 1 connector
|
||||
//comm->connectSend[peer] |= (1UL<<channelId);
|
||||
|
||||
@@ -267,11 +267,25 @@ inline __host__ uint8_t ncclP2pChannelBaseForRound(struct ncclComm* comm, int p2
|
||||
|
||||
// ncclP2pChannelToPart and ncclP2pChannelForPart are inverses. The device code
|
||||
// uses ncclP2pChannelToPart to determine which part "this" channel is responsible for.
|
||||
inline __host__ int ncclP2pChannelForPart(int nP2pChannels, int base, int part, int nParts) {
|
||||
return (base * nParts + part) & (nP2pChannels-1);
|
||||
// Host-side mapping from a P2P "part" index to the channel id that carries it.
//
//   nP2pChannels  total number of P2P channels; used as a bit mask
//                 (nP2pChannels-1), so it must be a power of two.
//   base          channel base for the current P2P round.
//   part          index of the part being placed.
//   nParts        number of parts per peer; only read on the single-node path.
//   nNodes        number of nodes; selects which mapping is used.
//
// Returns a channel id in [0, nP2pChannels).
//
// This must stay the exact inverse of ncclP2pChannelToPart (the device-side
// counterpart): any change to one formula has to be mirrored in the other.
inline __host__ int ncclP2pChannelForPart(int nP2pChannels, int base, int part, int nParts, int nNodes) {
|
||||
if (nNodes > 1) {
|
||||
// Multi-node: offset from base by the bit-reversed part index.
// Only works because nP2pChannels is pow2: nP2pChannels-1 is then an all-ones
// mask, so counting its set bits gives log2(nP2pChannels).
|
||||
int nChannelsLog2 = countOneBits(nP2pChannels-1);
|
||||
// NOTE(review): presumably reverseBits(x, k) reverses the low k bits of x, so
// consecutive parts land on widely separated channels - confirm against the
// helper's definition. nParts is not used on this path.
int delta = reverseBits(part, nChannelsLog2);
|
||||
return (base + delta) & (nP2pChannels-1);
|
||||
} else {
|
||||
// Single node: keep the original contiguous mapping, where each base owns a
// run of nParts consecutive channels (wrapping at nP2pChannels).
return (base * nParts + part) & (nP2pChannels-1);
|
||||
}
|
||||
}
|
||||
inline __device__ int ncclP2pChannelToPart(int nP2pChannels, int base, int channel, int nParts) {
|
||||
return (channel - base * nParts) & (nParts-1);
|
||||
// Device-side inverse of ncclP2pChannelForPart: given a channel id, returns the
// part index that this channel is responsible for.
//
//   nP2pChannels  total number of P2P channels; used as a bit mask
//                 (nP2pChannels-1), so it must be a power of two.
//   base          channel base for the current P2P round.
//   channel       channel id to translate.
//   nParts        number of parts per peer; only read on the single-node path.
//   nNodes        number of nodes; must match the value the host used when it
//                 picked the channel, otherwise the two mappings disagree.
inline __device__ int ncclP2pChannelToPart(int nP2pChannels, int base, int channel, int nParts, int nNodes) {
|
||||
if (nNodes > 1) {
|
||||
// Multi-node: undo the host's "base + bit-reversed part" mapping.
// Only works because nP2pChannels is pow2: nP2pChannels-1 is then an all-ones
// mask, so counting its set bits gives log2(nP2pChannels).
|
||||
int nChannelsLog2 = countOneBits(nP2pChannels-1);
|
||||
// Distance of this channel from base, wrapped into [0, nP2pChannels).
int delta = (channel-base) & (nP2pChannels-1);
|
||||
// NOTE(review): relies on reverseBits being its own inverse over
// nChannelsLog2 bits - confirm against the helper's definition. The result can
// be any value in [0, nP2pChannels); nParts is not used on this path.
return reverseBits(delta, nChannelsLog2);
|
||||
} else {
|
||||
// Single node: original contiguous mapping. The (nParts-1) mask assumes
// nParts is a power of two - TODO confirm with callers.
return (channel - base * nParts) & (nParts-1);
|
||||
}
|
||||
}
|
||||
|
||||
struct alignas(16) ncclDevWorkColl {
|
||||
|
||||
@@ -1677,11 +1677,11 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
|
||||
uint8_t recvBase = ncclP2pChannelBaseForRound(comm, recvRound);
|
||||
for (int c=0; c<comm->p2pnChannelsPerPeer; c++) {
|
||||
int channelId;
|
||||
channelId = ncclP2pChannelForPart(comm->p2pnChannels, sendBase, c, comm->p2pnChannelsPerPeer);
|
||||
channelId = ncclP2pChannelForPart(comm->p2pnChannels, sendBase, c, comm->p2pnChannelsPerPeer, comm->nNodes);
|
||||
if (comm->channels[channelId].peers[peer]->send[1].connected == 0) {
|
||||
comm->connectSend[peer].masks[channelId/64] |= (1UL<<(channelId%64));
|
||||
}
|
||||
channelId = ncclP2pChannelForPart(comm->p2pnChannels, recvBase, c, comm->p2pnChannelsPerPeer);
|
||||
channelId = ncclP2pChannelForPart(comm->p2pnChannels, recvBase, c, comm->p2pnChannelsPerPeer, comm->nNodes);
|
||||
if (comm->channels[channelId].peers[peer]->recv[1].connected == 0) {
|
||||
comm->connectRecv[peer].masks[channelId/64] |= (1UL<<(channelId%64));
|
||||
}
|
||||
|
||||
Посилання в новій задачі
Заблокувати користувача