From db5ab33461db802bdcb7b0acad4db3ee1da1cd6f Mon Sep 17 00:00:00 2001 From: mberenjk <146776561+mberenjk@users.noreply.github.com> Date: Fri, 16 May 2025 09:14:46 -0500 Subject: [PATCH] Switching to old version of rccl_float8 for ROCm versions earlier than 6.3 for backward compatibility. (#128) Co-authored-by: Marzieh Berenjkoub [ROCm/rccl-tests commit: 90760916025b4bea1db0fd5d7f8ab2497efaefc0] --- projects/rccl-tests/src/rccl_float8.h | 2 +- projects/rccl-tests/verifiable/verifiable.cu | 17 +++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/projects/rccl-tests/src/rccl_float8.h b/projects/rccl-tests/src/rccl_float8.h index 76bd4f35a1..5540f1e1e3 100644 --- a/projects/rccl-tests/src/rccl_float8.h +++ b/projects/rccl-tests/src/rccl_float8.h @@ -40,7 +40,7 @@ typedef struct } rccl_bfloat8; // __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__)) -#elif HIP_VERSION >= 60200000 +#elif HIP_VERSION >= 60300000 #include diff --git a/projects/rccl-tests/verifiable/verifiable.cu b/projects/rccl-tests/verifiable/verifiable.cu index 7611a6b491..e875c3238b 100644 --- a/projects/rccl-tests/verifiable/verifiable.cu +++ b/projects/rccl-tests/verifiable/verifiable.cu @@ -392,7 +392,7 @@ struct FloatLayout { }; #endif #if RCCL_FLOAT8 == 1 -#if __HIP_DEVICE_COMPILE__ +#if __HIP_DEVICE_COMPILE__ || HIP_VERSION < 60300000 template<> struct FloatLayout { static constexpr bool is_floating_point = true; @@ -993,11 +993,10 @@ cudaError_t prepareInput1( #if HAVE_ncclBfloat16 case ncclBfloat16: fn = (void const*)&prepareInput2; break; #endif - #if HAVE_ncclfp8_DEVICE + #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 case ncclFloat8e4m3: fn = (void const*)&prepareInput2; break; case ncclFloat8e5m2: fn = (void const*)&prepareInput2; break; - #endif - #if HAVE_ncclfp8_HOST + #elif HAVE_ncclfp8_HOST case ncclFloat8e4m3: if (rccl_float8_useFnuz) { fn = (void const*)&prepareInput2<__hip_fp8_e4m3_fnuz, ReduceOp>; break;} else { fn = (void const*)&prepareInput2<__hip_fp8_e4m3, ReduceOp>; break;} case ncclFloat8e5m2: if (rccl_float8_useFnuz) { fn = (void const*)&prepareInput2<__hip_fp8_e5m2_fnuz, ReduceOp>; break;} @@ -1084,11 +1083,10 @@ cudaError_t prepareExpected1( #if HAVE_ncclBfloat16 case ncclBfloat16: fn = (void const*)&prepareExpected2; break; #endif - #if HAVE_ncclfp8_DEVICE + #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 //for backward compatibility case ncclFloat8e4m3: fn = (void const*)&prepareExpected2; break; case ncclFloat8e5m2: fn = (void const*)&prepareExpected2; break; - #endif - #if HAVE_ncclfp8_HOST + #elif HAVE_ncclfp8_HOST case ncclFloat8e4m3: if (rccl_float8_useFnuz) { fn = (void const*)&prepareExpected2<__hip_fp8_e4m3_fnuz, ReduceOp>; break; } else { fn = (void const*)&prepareExpected2<__hip_fp8_e4m3, ReduceOp>; break; } case ncclFloat8e5m2: if (rccl_float8_useFnuz) { fn = (void const*)&prepareExpected2<__hip_fp8_e5m2_fnuz, ReduceOp>; break; } @@ -1323,11 +1321,10 @@ hipError_t ncclVerifiableVerify( #if HAVE_ncclBfloat16 case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t) #endif - #if HAVE_ncclfp8_DEVICE + #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 case ncclFloat8e4m3: CASE_TY(rccl_float8, uint8_t) case ncclFloat8e5m2: CASE_TY(rccl_bfloat8, uint8_t) - #endif - #if HAVE_ncclfp8_HOST + #elif HAVE_ncclfp8_HOST case ncclFloat8e4m3: if (rccl_float8_useFnuz) { CASE_TY(__hip_fp8_e4m3_fnuz, uint8_t);} else { CASE_TY(__hip_fp8_e4m3, uint8_t);} case ncclFloat8e5m2: if (rccl_float8_useFnuz) { CASE_TY(__hip_fp8_e5m2_fnuz, uint8_t);}