diff --git a/src/rccl_float8.h b/src/rccl_float8.h index 76bd4f35a1..5540f1e1e3 100644 --- a/src/rccl_float8.h +++ b/src/rccl_float8.h @@ -40,7 +40,7 @@ typedef struct } rccl_bfloat8; // __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__)) -#elif HIP_VERSION >= 60200000 +#elif HIP_VERSION >= 60300000 #include diff --git a/verifiable/verifiable.cu b/verifiable/verifiable.cu index 7611a6b491..e875c3238b 100644 --- a/verifiable/verifiable.cu +++ b/verifiable/verifiable.cu @@ -392,7 +392,7 @@ struct FloatLayout { }; #endif #if RCCL_FLOAT8 == 1 -#if __HIP_DEVICE_COMPILE__ +#if __HIP_DEVICE_COMPILE__ || HIP_VERSION < 60300000 template<> struct FloatLayout { static constexpr bool is_floating_point = true; @@ -993,11 +993,10 @@ cudaError_t prepareInput1( #if HAVE_ncclBfloat16 case ncclBfloat16: fn = (void const*)&prepareInput2; break; #endif - #if HAVE_ncclfp8_DEVICE + #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 case ncclFloat8e4m3: fn = (void const*)&prepareInput2; break; case ncclFloat8e5m2: fn = (void const*)&prepareInput2; break; - #endif - #if HAVE_ncclfp8_HOST + #elif HAVE_ncclfp8_HOST case ncclFloat8e4m3: if (rccl_float8_useFnuz) { fn = (void const*)&prepareInput2<__hip_fp8_e4m3_fnuz, ReduceOp>; break;} else { fn = (void const*)&prepareInput2<__hip_fp8_e4m3, ReduceOp>; break;} case ncclFloat8e5m2: if (rccl_float8_useFnuz) { fn = (void const*)&prepareInput2<__hip_fp8_e5m2_fnuz, ReduceOp>; break;} @@ -1084,11 +1083,10 @@ cudaError_t prepareExpected1( #if HAVE_ncclBfloat16 case ncclBfloat16: fn = (void const*)&prepareExpected2; break; #endif - #if HAVE_ncclfp8_DEVICE + #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 //for backward compatibility case ncclFloat8e4m3: fn = (void const*)&prepareExpected2; break; case ncclFloat8e5m2: fn = (void const*)&prepareExpected2; break; - #endif - #if HAVE_ncclfp8_HOST + #elif HAVE_ncclfp8_HOST case ncclFloat8e4m3: if (rccl_float8_useFnuz) { fn = (void const*)&prepareExpected2<__hip_fp8_e4m3_fnuz, ReduceOp>; break; } else { fn = (void const*)&prepareExpected2<__hip_fp8_e4m3, ReduceOp>; break; } case ncclFloat8e5m2: if (rccl_float8_useFnuz) { fn = (void const*)&prepareExpected2<__hip_fp8_e5m2_fnuz, ReduceOp>; break; } @@ -1323,11 +1321,10 @@ hipError_t ncclVerifiableVerify( #if HAVE_ncclBfloat16 case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t) #endif - #if HAVE_ncclfp8_DEVICE + #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 case ncclFloat8e4m3: CASE_TY(rccl_float8, uint8_t) case ncclFloat8e5m2: CASE_TY(rccl_bfloat8, uint8_t) - #endif - #if HAVE_ncclfp8_HOST + #elif HAVE_ncclfp8_HOST case ncclFloat8e4m3: if (rccl_float8_useFnuz) { CASE_TY(__hip_fp8_e4m3_fnuz, uint8_t);} else { CASE_TY(__hip_fp8_e4m3, uint8_t);} case ncclFloat8e5m2: if (rccl_float8_useFnuz) { CASE_TY(__hip_fp8_e5m2_fnuz, uint8_t);}