Fix when more than 64 channels are used for multi-collective group calls (#1688)

* Fix when more than 64 channels are used for multi-collective group calls

* Update CHANGELOG.md

Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>

---------

Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
Este commit está contenido en:
gilbertlee-amd
2025-05-12 18:05:57 -05:00
cometido por GitHub
padre 5f6805b4f4
commit 9ef45df8f7
Se han modificado 2 ficheros con 17 adiciones y 12 borrados
+3 -4
Ver fichero
@@ -4,6 +4,9 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
## Unreleased - RCCL 2.24.3 for ROCm 6.5.0
### Resolved issues
* Resolved an issue when using more than 64 channels when multiple collectives are used in the same `ncclGroup()` call.
### Added
* Added new GPU target `gfx950`.
@@ -12,10 +15,6 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
* Compatibility with NCCL 2.24.3
### Known issue
* Using more than 64 channels can cause a segmentation fault when multiple collectives are used in the same `ncclGroup()` call.
## Unreleased - RCCL 2.23.4 for ROCm 6.4.1
### Added
+14 -8
Ver fichero
@@ -214,12 +214,16 @@ static void finishPlan(struct ncclComm* comm, struct ncclKernelPlan* plan) {
struct channelMasks hasBatchMask = plan->channelMask;
struct ncclDevWorkBatch* batchPrev[MAXCHANNELS] = {}; // {0...}
struct ncclDevWorkBatch* batchZero = (struct ncclDevWorkBatch*)(plan->kernelArgs+1);
// [RCCL] Preparing batchZero slightly different to support > 64 Channels
// Need to ensure that all channels are processed first before dealing with
// adding additional batches
int batchIx = 0;
for (int maskIdx = 0; maskIdx < MAXCHANNELS/64; maskIdx++) {
while (hasBatchMask.masks[maskIdx] != 0) {
uint64_t tmpMask = hasBatchMask.masks[maskIdx]; // channels with a batch for this round.
do {
int c = popFirstOneBit(&tmpMask) + maskIdx * 64;
int done = 0;
while (!done) {
done = 1;
for (int c = 0; c < MAXCHANNELS; c++) {
if (hasBatchMask.masks[c / 64] & (1ULL << (c%64))) {
if (!ncclIntruQueueEmpty(&wipChannels[c].workBatchQueue)) {
struct ncclWorkBatchList* batchNode = ncclIntruQueueDequeue(&wipChannels[c].workBatchQueue);
if (batchPrev[c] != nullptr) {
@@ -229,9 +233,11 @@ static void finishPlan(struct ncclComm* comm, struct ncclKernelPlan* plan) {
batchZero[batchIx++] = batchNode->batch;
}
if (ncclIntruQueueEmpty(&wipChannels[c].workBatchQueue)) {
hasBatchMask.masks[maskIdx] ^= 1ull<<(c%64);
hasBatchMask.masks[c / 64] ^= (1ULL << (c%64));
} else {
done = 0;
}
} while (tmpMask != 0);
}
}
}
@@ -2131,7 +2137,7 @@ static ncclResult_t hostToDevRedOp(
uint64_t allBits = uint64_t(-1)>>(64-nbits);
uint64_t signBit = allBits^(allBits>>1);
bool datatype_signed = false;
switch (int(op)) {
case ncclSum: opFull->op = ncclDevSum; break;
case ncclProd: opFull->op = ncclDevProd; break;