Enable fp8 support (#63)
* initial checkin * rename the fp8 datatype name * update based on cr comments * resolve the build issue * resolve fp8 campability issue * fix minior bug and catch up to reflex latest develop branch change * add fp8 + operatior support * update fp8 header file * resolve merge issue from develop branch
Этот коммит содержится в:
@@ -71,6 +71,7 @@ set(COMMON_FILES
|
||||
common.h
|
||||
common.cu
|
||||
nccl1_compat.h
|
||||
rccl_bfloat8.h
|
||||
rccl_bfloat16.h
|
||||
timer.h
|
||||
timer.cc
|
||||
|
||||
+23
-3
@@ -2,11 +2,13 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "rccl_bfloat8.h"
|
||||
#include "rccl_bfloat16.h"
|
||||
#include "common.h"
|
||||
#include <pthread.h>
|
||||
@@ -28,12 +30,18 @@ int test_ncclVersion = 0; // init'd with ncclGetVersion()
|
||||
#if RCCL_BFLOAT16 == 1
|
||||
, ncclBfloat16
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
, ncclFp8E4M3, ncclFp8E5M2
|
||||
#endif
|
||||
};
|
||||
const char *test_typenames[ncclNumTypes] = {
|
||||
"int8", "uint8", "int32", "uint32", "int64", "uint64", "half", "float", "double"
|
||||
#if RCCL_BFLOAT16 == 1
|
||||
, "bfloat16"
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
, "fp8_e4m3", "fp8_e5m2"
|
||||
#endif
|
||||
};
|
||||
int test_typenum = -1;
|
||||
|
||||
@@ -100,13 +108,13 @@ static int enable_out_of_place = 1;
|
||||
static double parsesize(const char *value) {
|
||||
long long int units;
|
||||
double size;
|
||||
char size_lit;
|
||||
char size_lit[2];
|
||||
|
||||
int count = sscanf(value, "%lf %1s", &size, &size_lit);
|
||||
int count = sscanf(value, "%lf %1s", &size, size_lit);
|
||||
|
||||
switch (count) {
|
||||
case 2:
|
||||
switch (size_lit) {
|
||||
switch (size_lit[0]) {
|
||||
case 'G':
|
||||
case 'g':
|
||||
units = 1024*1024*1024;
|
||||
@@ -401,6 +409,9 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
rccl_bfloat16 bf16;
|
||||
#endif
|
||||
#if defined(RCCL_FLOAT8)
|
||||
rccl_float8 fp8_e4m3; rccl_bfloat8 fp8_e5m2;
|
||||
#endif
|
||||
};
|
||||
switch(type) {
|
||||
case ncclInt8: i8 = ncclVerifiablePremulScalar<int8_t>(rank); break;
|
||||
@@ -415,6 +426,11 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
case ncclBfloat16: bf16 = ncclVerifiablePremulScalar<rccl_bfloat16>(rank); break;
|
||||
#endif
|
||||
#if defined(RCCL_FLOAT8)
|
||||
case ncclFp8E4M3: fp8_e4m3 = ncclVerifiablePremulScalar<rccl_float8>(rank); break;
|
||||
case ncclFp8E5M2: fp8_e5m2 = ncclVerifiablePremulScalar<rccl_bfloat8>(rank); break;
|
||||
#endif
|
||||
case ncclNumTypes: break;
|
||||
}
|
||||
NCCLCHECK(ncclRedOpCreatePreMulSum(&op, &u64, type, ncclScalarHostImmediate, args->comms[i]));
|
||||
}
|
||||
@@ -753,6 +769,10 @@ int main(int argc, char* argv[]) {
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
test_typenum++; // bfloat16
|
||||
#endif
|
||||
#if defined(RCCL_FLOAT8)
|
||||
test_typenum++; // fp8_e4m3
|
||||
test_typenum++; // fp8_e5m2
|
||||
#endif
|
||||
}
|
||||
if (NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0) && test_ncclVersion >= NCCL_VERSION(2,11,0)) {
|
||||
test_opnum++; // PreMulSum
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -221,6 +222,10 @@ static size_t wordSize(ncclDataType_t type) {
|
||||
#if NCCL_MAJOR >= 2
|
||||
//case ncclInt8:
|
||||
case ncclUint8:
|
||||
#if NCCL_MAJOR >= 2 && RCCL_FLOAT8 == 1
|
||||
case ncclFp8E4M3:
|
||||
case ncclFp8E5M2:
|
||||
#endif
|
||||
#endif
|
||||
return 1;
|
||||
case ncclHalf:
|
||||
|
||||
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
@@ -1,7 +1,7 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2020-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
@@ -13,13 +13,20 @@
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "rccl/rccl.h"
|
||||
#include "rccl_bfloat8.h"
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && RCCL_BFLOAT16 ==1
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && RCCL_BFLOAT16 == 1
|
||||
#define HAVE_ncclBfloat16 1
|
||||
#else
|
||||
#define HAVE_ncclBfloat16 0
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && RCCL_FLOAT8 == 1
|
||||
#define HAVE_ncclfp8 1
|
||||
#else
|
||||
#define HAVE_ncclfp8 0
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
|
||||
#define HAVE_ncclAvg 1
|
||||
#else
|
||||
@@ -93,6 +100,12 @@ struct IsIntegral<__half>: std::false_type {};
|
||||
template<>
|
||||
struct IsIntegral<hip_bfloat16>: std::false_type {};
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
template<>
|
||||
struct IsIntegral<rccl_float8>: std::false_type {};
|
||||
template<>
|
||||
struct IsIntegral<rccl_bfloat8>: std::false_type {};
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -130,6 +143,16 @@ namespace {
|
||||
return hip_bfloat16(x);
|
||||
}
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
template<>
|
||||
__host__ __device__ rccl_float8 castTo<rccl_float8>(float x) {
|
||||
return static_cast<rccl_float8>(x);
|
||||
}
|
||||
template<>
|
||||
__host__ __device__ rccl_bfloat8 castTo<rccl_bfloat8>(float x) {
|
||||
return static_cast<rccl_bfloat8>(x);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -157,6 +180,14 @@ struct ReduceSum {
|
||||
return hip_bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
||||
}
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
__host__ __device__ rccl_float8 operator()(rccl_float8 a, rccl_float8 b) const {
|
||||
return rccl_float8(static_cast<float>(a) + static_cast<float>(b));
|
||||
}
|
||||
__host__ __device__ rccl_bfloat8 operator()(rccl_bfloat8 a, rccl_bfloat8 b) const {
|
||||
return rccl_bfloat8(static_cast<float>(a) + static_cast<float>(b));
|
||||
}
|
||||
#endif
|
||||
template<typename T>
|
||||
__host__ __device__ T postOp(T x) const { return x; }
|
||||
};
|
||||
@@ -173,6 +204,20 @@ struct ReduceProd {
|
||||
return hip_bfloat16(static_cast<float>(a) * static_cast<float>(b));
|
||||
}
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
__host__ __device__ rccl_float8 operator()(rccl_float8 a, rccl_float8 b) const {
|
||||
return static_cast<rccl_float8>(a * b);
|
||||
}
|
||||
__host__ __device__ rccl_float8 operator()(rccl_float8 a, float b) const {
|
||||
return static_cast<rccl_float8>(a * b);
|
||||
}
|
||||
__host__ __device__ rccl_bfloat8 operator()(rccl_bfloat8 a, rccl_bfloat8 b) const {
|
||||
return static_cast<rccl_bfloat8>(a * b);
|
||||
}
|
||||
__host__ __device__ rccl_bfloat8 operator()(rccl_bfloat8 a, float b) const {
|
||||
return static_cast<rccl_bfloat8>(a * b);
|
||||
}
|
||||
#endif
|
||||
template<typename T>
|
||||
__host__ __device__ T postOp(T x) const { return x; }
|
||||
};
|
||||
@@ -189,6 +234,14 @@ struct ReduceMin {
|
||||
return static_cast<float>(a) < static_cast<float>(b) ? a : b;
|
||||
}
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
__host__ __device__ rccl_float8 operator()(rccl_float8 a, rccl_float8 b) const {
|
||||
return static_cast<float>(a) < static_cast<float>(b) ? a : b;
|
||||
}
|
||||
__host__ __device__ rccl_bfloat8 operator()(rccl_bfloat8 a, rccl_bfloat8 b) const {
|
||||
return static_cast<float>(a) < static_cast<float>(b) ? a : b;
|
||||
}
|
||||
#endif
|
||||
template<typename T>
|
||||
__host__ __device__ T postOp(T x) const { return x; }
|
||||
};
|
||||
@@ -205,6 +258,14 @@ struct ReduceMax {
|
||||
return static_cast<float>(a) > static_cast<float>(b) ? a : b;
|
||||
}
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
__host__ __device__ rccl_float8 operator()(rccl_float8 a, rccl_float8 b) const {
|
||||
return static_cast<float>(a) > static_cast<float>(b) ? a : b;
|
||||
}
|
||||
__host__ __device__ rccl_bfloat8 operator()(rccl_bfloat8 a, rccl_bfloat8 b) const {
|
||||
return static_cast<float>(a) > static_cast<float>(b) ? a : b;
|
||||
}
|
||||
#endif
|
||||
template<typename T>
|
||||
__host__ __device__ T postOp(T x) const { return x; }
|
||||
};
|
||||
@@ -285,6 +346,18 @@ struct FloatLayout<hip_bfloat16> {
|
||||
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
|
||||
};
|
||||
#endif
|
||||
#if RCCL_FLOAT8 == 1
|
||||
template<>
|
||||
struct FloatLayout<rccl_float8> {
|
||||
static constexpr int exponent_bits = 4, mantissa_bits = 3;
|
||||
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
|
||||
};
|
||||
template<>
|
||||
struct FloatLayout<rccl_bfloat8> {
|
||||
static constexpr int exponent_bits = 5, mantissa_bits = 2;
|
||||
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
__host__ __device__ T makeFloat(int sign, int exp, uint64_t mant) {
|
||||
@@ -816,6 +889,10 @@ void prepareInput1(
|
||||
#if HAVE_ncclBfloat16
|
||||
case ncclBfloat16: CASE_TY(hip_bfloat16)
|
||||
#endif
|
||||
#if HAVE_ncclfp8
|
||||
case ncclFp8E4M3: CASE_TY(rccl_float8)
|
||||
case ncclFp8E5M2: CASE_TY(rccl_bfloat8)
|
||||
#endif
|
||||
case ncclFloat32: CASE_TY(float)
|
||||
case ncclFloat64: CASE_TY(double)
|
||||
default: assert(0);
|
||||
@@ -892,6 +969,10 @@ void prepareExpected1(
|
||||
#if HAVE_ncclBfloat16
|
||||
case ncclBfloat16: CASE_TY(hip_bfloat16)
|
||||
#endif
|
||||
#if HAVE_ncclfp8
|
||||
case ncclFp8E4M3: CASE_TY(rccl_float8)
|
||||
case ncclFp8E5M2: CASE_TY(rccl_bfloat8)
|
||||
#endif
|
||||
case ncclFloat32: CASE_TY(float)
|
||||
case ncclFloat64: CASE_TY(double)
|
||||
default: assert(0);
|
||||
@@ -962,6 +1043,13 @@ __host__ __device__ unsigned calcSumFloatTolerance(int rank_n, int elt_ty) {
|
||||
coef = .66f;
|
||||
break;
|
||||
#endif
|
||||
#if HAVE_ncclfp8
|
||||
case ncclFp8E4M3:
|
||||
case ncclFp8E5M2:
|
||||
power = .91f;
|
||||
coef = .66f;
|
||||
break;
|
||||
#endif
|
||||
}
|
||||
#if __CUDA_ARCH__
|
||||
return 1 + unsigned(coef*powf(float(rank_n), power));
|
||||
@@ -1086,6 +1174,10 @@ void ncclVerifiableVerify(
|
||||
#if HAVE_ncclBfloat16
|
||||
floating |= elt_ty == ncclBfloat16;
|
||||
#endif
|
||||
#if HAVE_ncclfp8
|
||||
floating |= elt_ty == ncclFp8E4M3;
|
||||
floating |= elt_ty == ncclFp8E5M2;
|
||||
#endif
|
||||
|
||||
unsigned tolerance = 0;
|
||||
#if HAVE_ncclAvg
|
||||
@@ -1114,6 +1206,10 @@ void ncclVerifiableVerify(
|
||||
#if HAVE_ncclBfloat16
|
||||
case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t)
|
||||
#endif
|
||||
#if HAVE_ncclfp8
|
||||
case ncclFp8E4M3: CASE_TY(rccl_float8, uint8_t)
|
||||
case ncclFp8E5M2: CASE_TY(rccl_bfloat8, uint8_t)
|
||||
#endif
|
||||
case ncclFloat32: CASE_TY(float, uint32_t)
|
||||
case ncclFloat64: CASE_TY(double, uint64_t)
|
||||
default: assert(0);
|
||||
@@ -1181,6 +1277,10 @@ __global__ void sweep() {
|
||||
#if HAVE_ncclBfloat16
|
||||
sweep1<hip_bfloat16>(ncclBfloat16, "bfloat16");
|
||||
#endif
|
||||
#if HAVE_ncclfp8
|
||||
sweep1<rccl_float8>(ncclFp8E4M3, "fp8_e4m3");
|
||||
sweep1<rccl_bfloat8>(ncclFp8E5M2, "fp8_e5m2");
|
||||
#endif
|
||||
sweep1<float>(ncclFloat32, "float");
|
||||
sweep1<double>(ncclFloat64, "double");
|
||||
}
|
||||
|
||||
@@ -21,7 +21,12 @@ ${HIPIFY_DIR}/verifiable.h: $(TEST_VERIFIABLE_SRCDIR)/verifiable.h
|
||||
@mkdir -p ${HIPIFY_DIR}
|
||||
hipify-perl -quiet-warnings $< > $@
|
||||
|
||||
$(TEST_VERIFIABLE_BUILDDIR)/verifiable.o: $(HIPIFY_DIR)/verifiable.cu.cpp $(HIPIFY_DIR)/verifiable.h
|
||||
${HIPIFY_DIR}/rccl_bfloat8.h: $(TEST_VERIFIABLE_SRCDIR)/../src/rccl_bfloat8.h
|
||||
@printf "Hipifying %-35s > %s\n" $< $@
|
||||
@mkdir -p ${HIPIFY_DIR}
|
||||
hipify-perl -quiet-warnings $< > $@
|
||||
|
||||
$(TEST_VERIFIABLE_BUILDDIR)/verifiable.o: $(HIPIFY_DIR)/verifiable.cu.cpp $(HIPIFY_DIR)/verifiable.h $(HIPIFY_DIR)/rccl_bfloat8.h
|
||||
@printf "Compiling %s\n" $@
|
||||
@mkdir -p $(TEST_VERIFIABLE_BUILDDIR)
|
||||
echo " $(HIPCC) -o $@ $(HIPCUFLAGS) -c $<"
|
||||
|
||||
Ссылка в новой задаче
Block a user