diff --git a/src/device/generate.py b/src/device/generate.py index c03abaf422..9f540eed31 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -164,7 +164,7 @@ def calc_unroll_for_local_arch(): # Homogeneous system is required to build for only 1 varient of unroll factor if len(gfx_targets) == 1: gfx_name, cu_count = gfx_targets[0] - if "gfx908" == gfx_name or ("gfx94" in gfx_name and cu_count > 80): + if "gfx908" == gfx_name or (gfx_name in ["gfx942", "gfx950"] and cu_count > 80): return 2 else: return 4 @@ -516,4 +516,4 @@ if is_msccl_kernels: out( "MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE({redop}, {ty_cxx}, false);\n" .format(redop=redop, ty_cxx=ty_to_cxx[ty]) - ) \ No newline at end of file + )