diff --git a/verifiable/verifiable.cu b/verifiable/verifiable.cu index 9d8e56aba9..a375809bcf 100644 --- a/verifiable/verifiable.cu +++ b/verifiable/verifiable.cu @@ -14,9 +14,6 @@ #include "rccl/rccl.h" - -#define RCCL_BFLOAT 1 - #if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && RCCL_BFLOAT16 ==1 #define HAVE_ncclBfloat16 1 #else @@ -124,7 +121,7 @@ namespace { return Y(x); } template<> - __host__ __device__ half castTo<__half>(float x) { + __host__ __device__ __half castTo<__half>(float x) { return __float2half(x); } #if RCCL_BFLOAT16 == 1 @@ -425,7 +422,7 @@ __host__ __device__ void genSumXY( // Let s be the number of ranks per partition. This is either rn/pn as we // intended, or y/p_sum if that's smaller to prevent overshooting our target y. uint32_t s = y/p_sum < rn/pn ? y/p_sum : rn/pn; - x = r/s < pn ? 1 + r/s : 0; // First s*pn ranks contribute partition index +1. + x = (s != 0 && r/s < pn) ? 1 + r/s : 0; // First s*pn ranks contribute partition index +1. x += r == rn-1 ? y - s*p_sum : 0; // Last rank contributes discrepancy. } }