diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index 368f06ae38..9df068c162 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -689,15 +689,15 @@ dropback: // NOTE: Currently half-precision float tests are unsupported due to half being supported // on GPU only and not host - // Fills input data[i][j] with (i + j) % 6 + // Fills input data[i][j] with (i + j) % 256 // - Keeping range small to reduce likelihood of overflow // - Sticking with floating points values that are perfectly representable for (int i = 0; i < dataset.numDevices; i++) { for (int j = 0; j < dataset.NumBytes(ncclInputBuffer)/DataTypeToBytes(dataset.dataType); j++) { - int valueI = (i + j) % 6; - float valueF = (float)valueI; + int valueI = (i + j) % 256; + double valueF = 1.0L/((double)valueI+1.0L); switch (dataset.dataType) { @@ -812,7 +812,7 @@ dropback: case ncclUint64: isMatch &= (outputU8[j] == expectedU8[j]); break; case ncclFloat32: isMatch &= (fabs(outputF4[j] - expectedF4[j]) < 1e-5); break; case ncclFloat64: isMatch &= (fabs(outputF8[j] - expectedF8[j]) < 1e-12); break; - case ncclBfloat16: isMatch &= (fabs((float)outputB2[j] - (float)expectedB2[j]) < 1e-2); break; + case ncclBfloat16: isMatch &= (fabs((float)outputB2[j] - (float)expectedB2[j]) < 5e-2); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0);