diff --git a/projects/rccl/src/collectives/device/prims_ll128.h b/projects/rccl/src/collectives/device/prims_ll128.h
index 48f7796df8..0529464f36 100644
--- a/projects/rccl/src/collectives/device/prims_ll128.h
+++ b/projects/rccl/src/collectives/device/prims_ll128.h
@@ -12,6 +12,14 @@
 
 #define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1)
 
+#ifndef RCCL_USE_WBINVL1_VOL
+#if defined(__GFX8__) || defined(__GFX9__)
+#define RCCL_USE_WBINVL1_VOL 1
+#else
+#define RCCL_USE_WBINVL1_VOL 0
+#endif
+#endif
+
 template
 class Primitives:
   public PrimitivesWithoutDirect> {
@@ -304,7 +312,7 @@ private:
       }
     }
 
-#if !defined(__gfx1030__) && !defined(__gfx1100__) && !defined(__gfx1101__) && !defined(__gfx1102__)
+#if RCCL_USE_WBINVL1_VOL
     if (tid == 0) __asm__ __volatile__("buffer_wbinvl1_vol");
 #endif
     /************************ Send **************************/
diff --git a/projects/rccl/src/include/devcomm.h b/projects/rccl/src/include/devcomm.h
index c569a39ede..1c80a3f32c 100644
--- a/projects/rccl/src/include/devcomm.h
+++ b/projects/rccl/src/include/devcomm.h
@@ -53,11 +53,7 @@ union ncclLLFifoLine {
   int4 i4;
 };
 
-#if defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE 64
-#endif
+#define WARP_SIZE warpSize
 #define MAXCHANNELS 32
 #define NCCL_MAX_NTHREADS 256
 #define NCCL_SIMPLE_MAX_NTHREADS NCCL_MAX_NTHREADS