fix bug in reduce kernel bfloat16 for ROCm >= 6.0 (#2139)
Co-authored-by: Prasannakumar Murugesan <prmuruge@amd.com> As part of an earlier commit, bfloat16 handling in reduce kernel for FuncMinMax fell into generic/default template when there is no SPECIALIZE_REDUCE for a particular type, this generic template does a bitwise integer comparison and it broke bfloat16 ops. change the else-if statement to else statement, that way it covers both ROCm version < 6.0 and >= 6.0 (with ROCm > 6.0, device.h already typedefs __hip_bfloat16 to hip_bfloat16, so no special case is needed here).
This commit is contained in:
committed by
GitHub
orang tua
f38665ac9a
melakukan
fa366ac03f
@@ -414,7 +414,7 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h
|
||||
SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 1, __nv_bfloat16, fn.isMinNotMax ? __hmin(x, y) : __hmax(x, y))
|
||||
// coverity[copy_constructor_call]
|
||||
SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 2, __nv_bfloat162, fn.isMinNotMax ? __hmin2(x, y) : __hmax2(x, y))
|
||||
#elif ROCM_VERSION < 60000
|
||||
#else
|
||||
SPECIALIZE_REDUCE(FuncSum, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) + (float)(y)))
|
||||
SPECIALIZE_REDUCE(FuncProd, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) * (float)(y)))
|
||||
SPECIALIZE_REDUCE(FuncMinMax, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y))))
|
||||
|
||||
Reference in New Issue
Block a user