device: update the logic for channelId assignment
Este commit está contenido en:
+20
-14
@@ -235,23 +235,29 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct
|
||||
|
||||
switch (tid/WARP_SIZE) {
|
||||
case 0:
|
||||
ncclShmem.channelId = blockIdx.x;
|
||||
/*for (int i = 0; i < num; i++) {
|
||||
if (channelMask.masks[i] & (1ull<<x)) {
|
||||
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
|
||||
y = total + y;
|
||||
if (blockIdx.x == y) ncclShmem.channelId = x;
|
||||
}
|
||||
if (WARP_SIZE < 64) {
|
||||
x = WARP_SIZE + tid;
|
||||
//ncclShmem.channelId = blockIdx.x;
|
||||
for (int i = 0; i < num; i++) {
|
||||
if (channelMask.masks[i] & (1ull<<x)) {
|
||||
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
|
||||
y = y + total;
|
||||
if (blockIdx.x == y) ncclShmem.channelId = x;
|
||||
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
|
||||
y = total + y;
|
||||
if (blockIdx.x == y) {
|
||||
ncclShmem.channelId = y;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (WARP_SIZE < 64) {
|
||||
x = WARP_SIZE + tid;
|
||||
if (channelMask.masks[i] & (1ull<<x)) {
|
||||
y = __popcll(channelMask.masks[i] & ((1ull<<x)-1));
|
||||
y = y + total;
|
||||
if (blockIdx.x == y) {
|
||||
ncclShmem.channelId = y;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
total = total + __popcll(channelMask.masks[i]);
|
||||
}
|
||||
total = __popcll(channelMask.masks[i]);
|
||||
}*/
|
||||
break;
|
||||
case 1:
|
||||
if (tid < WARP_SIZE + NCCL_MAX_GROUPS)
|
||||
|
||||
Referencia en una nueva incidencia
Block a user