From f4858e14b2e6f0acb3c78a83951f9dce3bfea1ef Mon Sep 17 00:00:00 2001 From: akolliasAMD <99202231+akolliasAMD@users.noreply.github.com> Date: Thu, 21 Dec 2023 08:58:33 -0700 Subject: [PATCH] rearranged how the min and max functions are part of msccl (#1025) * rearranged how the min and max functions are part of msccl * added more coverage on in place graph tests --- src/collectives/device/msccl_kernel_impl.h | 8 ++++---- src/include/msccl/msccl_kernel.h | 2 +- src/misc/msccl/msccl_setup.cc | 4 ++-- test/AllReduceTests.cpp | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/collectives/device/msccl_kernel_impl.h b/src/collectives/device/msccl_kernel_impl.h index 897481e695..4b010f342b 100644 --- a/src/collectives/device/msccl_kernel_impl.h +++ b/src/collectives/device/msccl_kernel_impl.h @@ -338,21 +338,21 @@ __device__ __forceinline__ void mscclRunInterpreter( if (tid == 0) { NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } -#endif +#endif prims.template send<1>(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end. #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT) if (tid == 0) { NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } -#endif +#endif } else if (t->type == MSCCL_RECV) { #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_ENTRY) if (tid == 0) { NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } -#endif +#endif prims.template recv<1>(dstOffset, thisNelem); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_EXIT) if (tid == 0) { @@ -515,8 +515,8 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct #define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false) diff --git a/src/include/msccl/msccl_kernel.h b/src/include/msccl/msccl_kernel.h index 49b0e23def..4962fdcb69 100644 --- a/src/include/msccl/msccl_kernel.h +++ b/src/include/msccl/msccl_kernel.h @@ -39,8 +39,8 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct n #define MSCCL_DECL_KERNEL_ENTRY_FUNC() \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false) diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index c00db63d6d..d8933978b1 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -185,7 +185,7 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) { INFO(NCCL_NET,"mscclSetupProxyCallback: proxy args size: %ld\n", params->size()); for (auto &p : *params) { mscclSetupProxyImpl(p.hostAlgo, p.comm); - } + } } ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStream_t stream) { @@ -317,8 +317,8 @@ static ncclResult_t hostToDevRedOp( #define MSCCL_KERNEL_ENTRY() \ MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \ MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \ MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \ MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, false), \ MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, false) diff --git a/test/AllReduceTests.cpp b/test/AllReduceTests.cpp index e54de35351..aeb02fb4f3 100644 --- a/test/AllReduceTests.cpp +++ b/test/AllReduceTests.cpp @@ -73,7 +73,7 @@ namespace RcclUnitTesting std::vector const dataTypes = {ncclInt32}; std::vector const redOps = {ncclMax}; std::vector const roots = {0}; - std::vector const numElements = {393216}; + std::vector const numElements = {393216, 12888, 384}; std::vector const inPlaceList = {true}; std::vector const managedMemList = {false}; std::vector const useHipGraphList = {true};