Switch to the old version of rccl_float8 on ROCm versions earlier than 6.3 for backward compatibility. (#128)

Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>
Этот коммит содержится в:
mberenjk
2025-05-16 09:14:46 -05:00
коммит произвёл GitHub
родитель 0abe3c80bb
Коммит 9076091602
2 изменённых файла: 8 добавлений и 11 удалений
+1 -1
Просмотреть файл
@@ -40,7 +40,7 @@ typedef struct
} rccl_bfloat8;
// __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
#elif HIP_VERSION >= 60200000
#elif HIP_VERSION >= 60300000
#include <hip/hip_fp8.h>
+7 -10
Просмотреть файл
@@ -392,7 +392,7 @@ struct FloatLayout<hip_bfloat16> {
};
#endif
#if RCCL_FLOAT8 == 1
#if __HIP_DEVICE_COMPILE__
#if __HIP_DEVICE_COMPILE__ || HIP_VERSION < 60300000
template<>
struct FloatLayout<rccl_float8> {
static constexpr bool is_floating_point = true;
@@ -993,11 +993,10 @@ cudaError_t prepareInput1(
#if HAVE_ncclBfloat16
case ncclBfloat16: fn = (void const*)&prepareInput2<hip_bfloat16, ReduceOp>; break;
#endif
#if HAVE_ncclfp8_DEVICE
#if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000
case ncclFloat8e4m3: fn = (void const*)&prepareInput2<rccl_float8, ReduceOp>; break;
case ncclFloat8e5m2: fn = (void const*)&prepareInput2<rccl_bfloat8, ReduceOp>; break;
#endif
#if HAVE_ncclfp8_HOST
#elif HAVE_ncclfp8_HOST
case ncclFloat8e4m3: if (rccl_float8_useFnuz) { fn = (void const*)&prepareInput2<__hip_fp8_e4m3_fnuz, ReduceOp>; break;}
else { fn = (void const*)&prepareInput2<__hip_fp8_e4m3, ReduceOp>; break;}
case ncclFloat8e5m2: if (rccl_float8_useFnuz) { fn = (void const*)&prepareInput2<__hip_fp8_e5m2_fnuz, ReduceOp>; break;}
@@ -1084,11 +1083,10 @@ cudaError_t prepareExpected1(
#if HAVE_ncclBfloat16
case ncclBfloat16: fn = (void const*)&prepareExpected2<hip_bfloat16, ReduceOp>; break;
#endif
#if HAVE_ncclfp8_DEVICE
#if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 //for backward compatibility
case ncclFloat8e4m3: fn = (void const*)&prepareExpected2<rccl_float8, ReduceOp>; break;
case ncclFloat8e5m2: fn = (void const*)&prepareExpected2<rccl_bfloat8, ReduceOp>; break;
#endif
#if HAVE_ncclfp8_HOST
#elif HAVE_ncclfp8_HOST
case ncclFloat8e4m3: if (rccl_float8_useFnuz) { fn = (void const*)&prepareExpected2<__hip_fp8_e4m3_fnuz, ReduceOp>; break; }
else { fn = (void const*)&prepareExpected2<__hip_fp8_e4m3, ReduceOp>; break; }
case ncclFloat8e5m2: if (rccl_float8_useFnuz) { fn = (void const*)&prepareExpected2<__hip_fp8_e5m2_fnuz, ReduceOp>; break; }
@@ -1323,11 +1321,10 @@ hipError_t ncclVerifiableVerify(
#if HAVE_ncclBfloat16
case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t)
#endif
#if HAVE_ncclfp8_DEVICE
#if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000
case ncclFloat8e4m3: CASE_TY(rccl_float8, uint8_t)
case ncclFloat8e5m2: CASE_TY(rccl_bfloat8, uint8_t)
#endif
#if HAVE_ncclfp8_HOST
#elif HAVE_ncclfp8_HOST
case ncclFloat8e4m3: if (rccl_float8_useFnuz) { CASE_TY(__hip_fp8_e4m3_fnuz, uint8_t);}
else { CASE_TY(__hip_fp8_e4m3, uint8_t);}
case ncclFloat8e5m2: if (rccl_float8_useFnuz) { CASE_TY(__hip_fp8_e5m2_fnuz, uint8_t);}