device: update the logic for channelId assignment

Este commit está contenido en:
Nusrat Islam
2024-05-22 13:37:15 -05:00
padre 506f16c506
commit 48859a97b1
+20 -14
Ver fichero
@@ -235,23 +235,29 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct
switch (tid/WARP_SIZE) {
case 0:
ncclShmem.channelId = blockIdx.x;
/*for (int i = 0; i < num; i++) {
if (channelMask.masks[i] & (1ull<<x)) {
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
y = total + y;
if (blockIdx.x == y) ncclShmem.channelId = x;
}
if (WARP_SIZE < 64) {
x = WARP_SIZE + tid;
//ncclShmem.channelId = blockIdx.x;
for (int i = 0; i < num; i++) {
if (channelMask.masks[i] & (1ull<<x)) {
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
y = y + total;
if (blockIdx.x == y) ncclShmem.channelId = x;
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
y = total + y;
if (blockIdx.x == y) {
ncclShmem.channelId = y;
break;
}
}
if (WARP_SIZE < 64) {
x = WARP_SIZE + tid;
if (channelMask.masks[i] & (1ull<<x)) {
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
y = y + total;
if (blockIdx.x == y) {
ncclShmem.channelId = y;
break;
}
}
}
total = total + __popcll(channelMask.masks[i]);
}
total = __popcll(channelMask.masks[i]);
}*/
break;
case 1:
if (tid < WARP_SIZE + NCCL_MAX_GROUPS)