diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc index 6977c2d6ac..70a3dc341c 100644 --- a/src/misc/argcheck.cc +++ b/src/misc/argcheck.cc @@ -73,10 +73,10 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) { } } else { // Check CUDA device pointers - if (info->coll != ncclCollBroadcast || info->comm->rank == info->root) { + if ((info->coll != ncclCollBroadcast && info->coll != ncclCollScatter) || info->comm->rank == info->root) { NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName)); } - if (info->coll != ncclCollReduce || info->comm->rank == info->root) { + if ((info->coll != ncclCollReduce && info->coll != ncclCollGather) || info->comm->rank == info->root) { NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName)); } }