From fa366ac03fbd0e965ed0877b1db816b02a2eb19f Mon Sep 17 00:00:00 2001 From: prasanna-amd Date: Tue, 20 Jan 2026 14:07:20 -0800 Subject: [PATCH] fix bug in reduce kernel bfloat16 for ROCm >= 6.0 (#2139) Co-authored-by: Prasannakumar Murugesan As part of an earlier commit, bfloat16 handling in reduce kernel for FuncMinMax fell into generic/default template when there is no SPECIALIZE_REDUCE for a particular type, this generic template does a bitwise integer comparison and it broke bfloat16 ops. change the else-if statement to else statement, that way it covers both ROCm version < 6.0 and >= 6.0 (with ROCm > 6.0, device.h already typedefs __hip_bfloat16 to hip_bfloat16, so no special case is needed here). --- src/device/reduce_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/reduce_kernel.h b/src/device/reduce_kernel.h index b0635f2f21..593f868eae 100755 --- a/src/device/reduce_kernel.h +++ b/src/device/reduce_kernel.h @@ -414,7 +414,7 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 1, __nv_bfloat16, fn.isMinNotMax ? __hmin(x, y) : __hmax(x, y)) // coverity[copy_constructor_call] SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 2, __nv_bfloat162, fn.isMinNotMax ? __hmin2(x, y) : __hmax2(x, y)) -#elif ROCM_VERSION < 60000 +#else SPECIALIZE_REDUCE(FuncSum, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) + (float)(y))) SPECIALIZE_REDUCE(FuncProd, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) * (float)(y))) SPECIALIZE_REDUCE(FuncMinMax, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y))))