fix bug in reduce kernel bfloat16 for ROCm >= 6.0 (#2139)

Co-authored-by: Prasannakumar Murugesan <prmuruge@amd.com>
As part of an earlier commit, bfloat16 handling in reduce kernel for FuncMinMax fell into generic/default template when there is no SPECIALIZE_REDUCE for a particular type, this generic template does a bitwise integer comparison and it broke bfloat16 ops.
change the else-if statement to else statement, that way it covers both ROCm version < 6.0 and >= 6.0 (with ROCm > 6.0, device.h already typedefs __hip_bfloat16 to hip_bfloat16, so no special case is needed here).
This commit is contained in:
prasanna-amd
2026-01-20 14:07:20 -08:00
committed by GitHub
szülő f38665ac9a
commit fa366ac03f
+1 -1
Fájl megtekintése
@@ -414,7 +414,7 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h
SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 1, __nv_bfloat16, fn.isMinNotMax ? __hmin(x, y) : __hmax(x, y))
// coverity[copy_constructor_call]
SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 2, __nv_bfloat162, fn.isMinNotMax ? __hmin2(x, y) : __hmax2(x, y))
#elif ROCM_VERSION < 60000
#else
SPECIALIZE_REDUCE(FuncSum, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) + (float)(y)))
SPECIALIZE_REDUCE(FuncProd, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) * (float)(y)))
SPECIALIZE_REDUCE(FuncMinMax, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y))))