Merge pull request #154 from wenkaidu/bf16

Add bfloat16 support in RCCL
This commit is contained in:
Wenkai Du
2019-11-19 09:07:51 -08:00
committed by GitHub
21 changed files with 402 additions and 16 deletions
+5
View File
@@ -2,6 +2,11 @@
cmake_minimum_required(VERSION 2.8.12)
# We use C++14 features, this will add compile option: -std=c++14
set( CMAKE_CXX_STANDARD 14 )
# Without this line, it will add -std=gnu++14 instead, which has some issues.
set( CMAKE_CXX_EXTENSIONS OFF )
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
project(rccl CXX)
+2 -1
View File
@@ -39,7 +39,8 @@
DECL_COLL3(coll, op, u64) \
DECL_COLL3(coll, op, f16) \
DECL_COLL3(coll, op, f32) \
DECL_COLL3(coll, op, f64)
DECL_COLL3(coll, op, f64) \
DECL_COLL3(coll, op, b16)
#define DECL_COLL(coll) \
DECL_COLL2(coll, sum) \
+9 -6
View File
@@ -53,7 +53,8 @@ static inline __device__ void exitIfAbortBarrier(int abort) {
NCCL_FUNC4(coll, op, u64), \
NCCL_FUNC4(coll, op, f16), \
NCCL_FUNC4(coll, op, f32), \
NCCL_FUNC4(coll, op, f64)
NCCL_FUNC4(coll, op, f64), \
NCCL_FUNC4(coll, op, b16)
#define NCCL_FUNCS3B(coll, op) \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
@@ -63,6 +64,7 @@ static inline __device__ void exitIfAbortBarrier(int abort) {
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8)
// Must be consistent with ncclRedOp_t
@@ -121,20 +123,20 @@ struct Caller<f, f + 1>{
inline
__device__
void NCCL_CALL_FUNCTIONS(struct ncclColl* const c) noexcept {
if (c->funcIndex < 144) {
if (c->funcIndex < 160) {
if (c->funcIndex % 4 == 0) ncclBroadcastRing_copy_i8(&c->args);
else if (c->funcIndex % 4 == 1) ncclBroadcastRingLL_copy_i8(&c->args);
else if (c->funcIndex % 4 == 2) ncclBroadcastTree_copy_i8(&c->args);
else ncclBroadcastTreeLL_copy_i8(&c->args);
}
else if (c->funcIndex < 288) Caller<144, 288>::call(c);
else if (c->funcIndex < 432) {
else if (c->funcIndex < 320) Caller<160, 320>::call(c);
else if (c->funcIndex < 480) {
if (c->funcIndex % 4 == 0) ncclAllGatherRing_copy_i8(&c->args);
else if (c->funcIndex % 4 == 1) ncclAllGatherRingLL_copy_i8(&c->args);
else if (c->funcIndex % 4 == 2) ncclAllGatherTree_copy_i8(&c->args);
else ncclAllGatherTreeLL_copy_i8(&c->args);
}
else Caller<432, 720>::call(c);
else Caller<480, 800>::call(c);
}
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid, uint32_t* abortCount) {
@@ -227,7 +229,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64) \
IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16) \
IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32) \
IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64)
IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64) \
IMPL_COLL3(coll, op, ncclFunc, b16, rccl_bfloat16, ncclColl, ncclOp, ncclBfloat16)
#define COLL_UNROLL 2
+12
View File
@@ -241,6 +241,18 @@ template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
((half*)ptr)[0] = val;
}
template<> inline __device__
rccl_bfloat16 vFetch<rccl_bfloat16>(const volatile rccl_bfloat16* ptr) {
rccl_bfloat16 r;
r.data = ptr->data;
return r;
}
template<> inline __device__
void vStore<rccl_bfloat16>(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) {
ptr->data = val.data;
}
#endif
typedef ulong2 Pack128;
+80
View File
@@ -134,6 +134,86 @@ struct FuncMin : private FuncBase<T> {
}
};
template<>
struct FuncSum<rccl_bfloat16> {
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
__device__ PackType operator()(PackType x, PackType y) const
{
union converter { PackType storage; rccl_bfloat16 vec[n]; };
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
for (auto i = 0u; i != n; ++i) {
cr.vec[i] = cx.vec[i] + cy.vec[i];
}
return cr.storage;
}
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
return x + y;
}
};
template<>
struct FuncProd<rccl_bfloat16> {
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
__device__ PackType operator()(PackType x, PackType y) const
{
union converter { PackType storage; rccl_bfloat16 vec[n]; };
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
for (auto i = 0u; i != n; ++i) {
cr.vec[i] = cx.vec[i] * cy.vec[i];
}
return cr.storage;
}
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
return x * y;
}
};
template<>
struct FuncMax<rccl_bfloat16> {
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
__device__ PackType operator()(PackType x, PackType y) const
{
union converter { PackType storage; rccl_bfloat16 vec[n]; };
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
for (auto i = 0u; i != n; ++i) {
cr.vec[i] = cx.vec[i] < cy.vec[i] ? cy.vec[i] : cx.vec[i];
}
return cr.storage;
}
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
return x < y ? y : x;
}
};
template<>
struct FuncMin<rccl_bfloat16> {
static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16);
__device__ PackType operator()(PackType x, PackType y) const
{
union converter { PackType storage; rccl_bfloat16 vec[n]; };
static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter.");
converter cx, cy, cr;
cx.storage = x;
cy.storage = y;
for (auto i = 0u; i != n; ++i) {
cr.vec[i] = cx.vec[i] < cy.vec[i] ? cx.vec[i] : cy.vec[i];
}
return cr.storage;
}
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
return x < y ? x : y;
}
};
#else
template<typename T>
+3 -1
View File
@@ -30,7 +30,8 @@
NCCL_FUNC4(coll, op, u64), \
NCCL_FUNC4(coll, op, f16), \
NCCL_FUNC4(coll, op, f32), \
NCCL_FUNC4(coll, op, f64)
NCCL_FUNC4(coll, op, f64), \
NCCL_FUNC4(coll, op, b16)
#define NCCL_FUNCS3B(coll, op) \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
@@ -40,6 +41,7 @@
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8), \
NCCL_FUNC4(coll, op, i8)
// Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
+1
View File
@@ -48,6 +48,7 @@ static __inline__ int ncclTypeSize(ncclDataType_t type) {
case ncclUint8:
return 1;
case ncclFloat16:
case ncclBfloat16:
return 2;
case ncclInt32:
case ncclUint32:
+1
View File
@@ -9,6 +9,7 @@
#define NCCL_DEVICE_H_
#include "nccl.h"
#include "rccl_bfloat16.h"
#include <stdint.h>
// Convert volatile access to atomic
+253
View File
@@ -0,0 +1,253 @@
/**
* MIT License
*
* Copyright 2019 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
/*!\file
* \brief rccl_bfloat16.h provides struct for rccl_bfloat16 typedef
*/
#ifndef _RCCL_BFLOAT16_H_
#define _RCCL_BFLOAT16_H_
#if __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__))
// If this is a C compiler, C++ compiler below C++14, or a host-only compiler, we only
// include a minimal definition of rccl_bfloat16
#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number. */
typedef struct
{
uint16_t data;
} rccl_bfloat16;
#else // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__))
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <hip/hip_runtime.h>
#include <ostream>
#include <type_traits>
struct rccl_bfloat16
{
uint16_t data;
__host__ __device__ rccl_bfloat16() = default;
// round upper 16 bits of IEEE float to convert to bfloat16
explicit constexpr __host__ __device__ rccl_bfloat16(float f)
: data(float_to_bfloat16(f))
{
}
// zero extend lower 16 bits of bfloat16 to convert to IEEE float
constexpr __host__ __device__ operator float() const
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(data) << 16};
return u.fp32;
}
private:
static constexpr __host__ __device__ uint16_t float_to_bfloat16(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
};
typedef struct
{
uint16_t data;
} rccl_bfloat16_public;
static_assert(std::is_standard_layout<rccl_bfloat16>{},
"rccl_bfloat16 is not a standard layout type, and thus is "
"incompatible with C.");
static_assert(std::is_trivial<rccl_bfloat16>{},
"rccl_bfloat16 is not a trivial type, and thus is "
"incompatible with C.");
static_assert(sizeof(rccl_bfloat16) == sizeof(rccl_bfloat16_public)
&& offsetof(rccl_bfloat16, data) == offsetof(rccl_bfloat16_public, data),
"internal rccl_bfloat16 does not match public rccl_bfloat16");
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat16& bf16)
{
return os << float(bf16);
}
constexpr __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a)
{
return a;
}
constexpr __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a)
{
a.data ^= 0x8000;
return a;
}
constexpr __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) + float(b));
}
constexpr __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) - float(b));
}
constexpr __host__ __device__ rccl_bfloat16 operator*(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) * float(b));
}
constexpr __host__ __device__ rccl_bfloat16 operator/(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) / float(b));
}
constexpr __host__ __device__ bool operator<(rccl_bfloat16 a, rccl_bfloat16 b)
{
return float(a) < float(b);
}
constexpr __host__ __device__ bool operator==(rccl_bfloat16 a, rccl_bfloat16 b)
{
return float(a) == float(b);
}
constexpr __host__ __device__ bool operator>(rccl_bfloat16 a, rccl_bfloat16 b)
{
return b < a;
}
constexpr __host__ __device__ bool operator<=(rccl_bfloat16 a, rccl_bfloat16 b)
{
return !(a > b);
}
constexpr __host__ __device__ bool operator!=(rccl_bfloat16 a, rccl_bfloat16 b)
{
return !(a == b);
}
constexpr __host__ __device__ bool operator>=(rccl_bfloat16 a, rccl_bfloat16 b)
{
return !(a < b);
}
constexpr __host__ __device__ rccl_bfloat16& operator+=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a + b;
}
constexpr __host__ __device__ rccl_bfloat16& operator-=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a - b;
}
constexpr __host__ __device__ rccl_bfloat16& operator*=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a * b;
}
constexpr __host__ __device__ rccl_bfloat16& operator/=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a / b;
}
constexpr __host__ __device__ rccl_bfloat16& operator++(rccl_bfloat16& a)
{
return a += rccl_bfloat16(1.0f);
}
constexpr __host__ __device__ rccl_bfloat16& operator--(rccl_bfloat16& a)
{
return a -= rccl_bfloat16(1.0f);
}
constexpr __host__ __device__ rccl_bfloat16 operator++(rccl_bfloat16& a, int)
{
rccl_bfloat16 orig = a;
++a;
return orig;
}
constexpr __host__ __device__ rccl_bfloat16 operator--(rccl_bfloat16& a, int)
{
rccl_bfloat16 orig = a;
--a;
return orig;
}
namespace std
{
constexpr __host__ __device__ bool isinf(rccl_bfloat16 a)
{
return !(~a.data & 0x7f80) && !(a.data & 0x7f);
}
constexpr __host__ __device__ bool isnan(rccl_bfloat16 a)
{
return !(~a.data & 0x7f80) && +(a.data & 0x7f);
}
constexpr __host__ __device__ bool iszero(rccl_bfloat16 a)
{
return !(a.data & 0x7fff);
}
inline rccl_bfloat16 sin(rccl_bfloat16 a)
{
return rccl_bfloat16(sinf(float(a)));
}
inline rccl_bfloat16 cos(rccl_bfloat16 a)
{
return rccl_bfloat16(cosf(float(a)));
}
}
#endif // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__))
#endif // _RCCL_BFLOAT16_H_
+4 -1
View File
@@ -19,6 +19,8 @@
#define NCCL_VERSION_CODE ${NCCL_VERSION}
#define NCCL_VERSION(X,Y,Z) ((X) * 1000 + (Y) * 100 + (Z))
#define RCCL_BFLOAT16 1
#ifdef __cplusplus
extern "C" {
#endif
@@ -116,7 +118,8 @@ typedef enum { ncclInt8 = 0, ncclChar = 0,
ncclFloat16 = 6, ncclHalf = 6,
ncclFloat32 = 7, ncclFloat = 7,
ncclFloat64 = 8, ncclDouble = 8,
ncclNumTypes = 9 } ncclDataType_t;
ncclBfloat16 = 9,
ncclNumTypes = 10 } ncclDataType_t;
/*
* Collective communication operations
+9
View File
@@ -11,6 +11,7 @@
#include <vector>
#include <gtest/gtest.h>
#include "rccl.h"
#include "../include/rccl_bfloat16.h"
#define HIP_CALL(x) ASSERT_EQ(x, hipSuccess)
#define NCCL_CALL(x) ASSERT_EQ(x, ncclSuccess)
@@ -47,6 +48,7 @@ namespace CorrectnessTests
case ncclFloat16: return 2;
case ncclFloat32: return 4;
case ncclFloat64: return 8;
case ncclBfloat16: return 2;
default:
fprintf(stderr, "[ERROR] Unsupported datatype (%d)\n", dataType);
exit(0);
@@ -217,6 +219,7 @@ namespace CorrectnessTests
uint64_t* arrayU8 = (uint64_t *)arrayI1;
float* arrayF4 = (float *)arrayI1;
double* arrayF8 = (double *)arrayI1;
rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1;
// NOTE: Currently half-precision float tests are unsupported due to half being supported
// on GPU only and not host
@@ -241,6 +244,7 @@ namespace CorrectnessTests
case ncclUint64: arrayU8[j] = valueI; break;
case ncclFloat32: arrayF4[j] = valueF; break;
case ncclFloat64: arrayF8[j] = valueF; break;
case ncclBfloat16: arrayB2[j] = rccl_bfloat16(valueF); break;
default:
fprintf(stderr, "[ERROR] Unsupported datatype\n");
exit(0);
@@ -278,6 +282,7 @@ namespace CorrectnessTests
uint64_t* outputU8 = (uint64_t *)outputI1;
float* outputF4 = (float *)outputI1;
double* outputF8 = (double *)outputI1;
rccl_bfloat16* outputB2 = (rccl_bfloat16 *)outputI1;
bool isMatch = true;
@@ -295,6 +300,7 @@ namespace CorrectnessTests
uint64_t* expectedU8 = (uint64_t *)expectedI1;
float* expectedF4 = (float *)expectedI1;
double* expectedF8 = (double *)expectedI1;
rccl_bfloat16* expectedB2 = (rccl_bfloat16 *)expectedI1;
for (int j = 0; j < dataset.numElements && isMatch; j++)
{
@@ -308,6 +314,7 @@ namespace CorrectnessTests
case ncclUint64: isMatch &= (outputU8[j] == expectedU8[j]); break;
case ncclFloat32: isMatch &= (outputF4[j] == expectedF4[j]); break;
case ncclFloat64: isMatch &= (outputF8[j] == expectedF8[j]); break;
case ncclBfloat16: isMatch &= (outputB2[j] == expectedB2[j]); break;
default:
fprintf(stderr, "[ERROR] Unsupported datatype\n");
exit(0);
@@ -333,6 +340,8 @@ namespace CorrectnessTests
printf("Expected %f. Output %f on device %d[%d]\n", outputF4[j], expectedF4[j], i, j); break;
case ncclFloat64:
printf("Expected %lf. Output %lf on device %d[%d]\n", outputF8[j], expectedF8[j], i, j); break;
case ncclBfloat16:
printf("Expected %f. Output %f on device %d[%d]\n", (float)outputB2[j], (float)expectedB2[j], i, j); break;
default:
fprintf(stderr, "[ERROR] Unsupported datatype\n");
exit(0);
+2 -1
View File
@@ -101,7 +101,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
+2 -1
View File
@@ -50,7 +50,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
+3
View File
@@ -29,6 +29,7 @@ namespace CorrectnessTests
uint64_t* resultU8 = (uint64_t *)resultI1;
float* resultF4 = (float *)resultI1;
double* resultF8 = (double *)resultI1;
rccl_bfloat16* resultB2 = (rccl_bfloat16 *)resultI1;
// Initialize the result with the first device's array
memcpy(resultI1, dataset.expected[0], dataset.NumBytes());
@@ -44,6 +45,7 @@ namespace CorrectnessTests
uint64_t* arrayU8 = (uint64_t *)arrayI1;
float* arrayF4 = (float *)arrayI1;
double* arrayF8 = (double *)arrayI1;
rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1;
for (int j = 0; j < dataset.numElements; j++)
{
@@ -57,6 +59,7 @@ namespace CorrectnessTests
case ncclUint64: resultU8[j] = ReduceOp(op, resultU8[j], arrayU8[j]); break;
case ncclFloat32: resultF4[j] = ReduceOp(op, resultF4[j], arrayF4[j]); break;
case ncclFloat64: resultF8[j] = ReduceOp(op, resultF8[j], arrayF8[j]); break;
case ncclBfloat16: resultB2[j] = ReduceOp(op, resultB2[j], arrayB2[j]); break;
default:
fprintf(stderr, "[ERROR] Unsupported datatype\n");
exit(0);
+2 -1
View File
@@ -59,7 +59,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
+2 -1
View File
@@ -89,7 +89,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
+2 -1
View File
@@ -110,7 +110,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
+2 -1
View File
@@ -58,7 +58,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
+3
View File
@@ -29,6 +29,7 @@ namespace CorrectnessTests
uint64_t* resultU8 = (uint64_t *)resultI1;
float* resultF4 = (float *)resultI1;
double* resultF8 = (double *)resultI1;
rccl_bfloat16* resultB2 = (rccl_bfloat16 *)resultI1;
// Initialize the result with the first device's array
memcpy(resultI1, dataset.expected[0], dataset.NumBytes());
@@ -44,6 +45,7 @@ namespace CorrectnessTests
uint64_t* arrayU8 = (uint64_t *)arrayI1;
float* arrayF4 = (float *)arrayI1;
double* arrayF8 = (double *)arrayI1;
rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1;
for (int j = 0; j < dataset.numElements; j++)
{
@@ -57,6 +59,7 @@ namespace CorrectnessTests
case ncclUint64: resultU8[j] = ReduceOp(op, resultU8[j], arrayU8[j]); break;
case ncclFloat32: resultF4[j] = ReduceOp(op, resultF4[j], arrayF4[j]); break;
case ncclFloat64: resultF8[j] = ReduceOp(op, resultF8[j], arrayF8[j]); break;
case ncclBfloat16: resultB2[j] = ReduceOp(op, resultB2[j], arrayB2[j]); break;
default:
fprintf(stderr, "[ERROR] Unsupported datatype\n");
exit(0);
+2 -1
View File
@@ -57,7 +57,8 @@ namespace CorrectnessTests
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
ncclFloat64,
ncclBfloat16),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
+3
View File
@@ -29,6 +29,7 @@ namespace CorrectnessTests
uint64_t* resultU8 = (uint64_t *)resultI1;
float* resultF4 = (float *)resultI1;
double* resultF8 = (double *)resultI1;
rccl_bfloat16* resultB2 = (rccl_bfloat16 *)resultI1;
// Initialize the result with the first device's array
memcpy(resultI1, dataset.expected[0], dataset.NumBytes());
@@ -44,6 +45,7 @@ namespace CorrectnessTests
uint64_t* arrayU8 = (uint64_t *)arrayI1;
float* arrayF4 = (float *)arrayI1;
double* arrayF8 = (double *)arrayI1;
rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1;
for (int j = 0; j < dataset.numElements; j++)
{
@@ -57,6 +59,7 @@ namespace CorrectnessTests
case ncclUint64: resultU8[j] = ReduceOp(op, resultU8[j], arrayU8[j]); break;
case ncclFloat32: resultF4[j] = ReduceOp(op, resultF4[j], arrayF4[j]); break;
case ncclFloat64: resultF8[j] = ReduceOp(op, resultF8[j], arrayF8[j]); break;
case ncclBfloat16: resultB2[j] = ReduceOp(op, resultB2[j], arrayB2[j]); break;
default:
fprintf(stderr, "[ERROR] Unsupported datatype\n");
exit(0);