diff --git a/CMakeLists.txt b/CMakeLists.txt
index 114ef58930..511ba8099e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -119,6 +119,32 @@ execute_process(
   OUTPUT_VARIABLE hipcc_version_string)
 message(STATUS "hipcc version: ${hipcc_version_string}")
 
+## Check for ROCm version
+# Reads ${ROCM_PATH}/.info/version and, when it contains MAJOR.MINOR.PATCH,
+# exposes it to the sources as the integer macro ROCM_VERSION
+# (10000 * MAJOR + 100 * MINOR + PATCH).
+execute_process(
+  COMMAND         bash "-c" "cat ${ROCM_PATH}/.info/version"
+  OUTPUT_VARIABLE rocm_version_string
+)
+# Quote the input: if the file could not be read the variable is empty, and an
+# unquoted empty expansion would drop the argument and make string() abort
+# configuration instead of reaching the warning below.
+string(REGEX MATCH "([0-9]+)\\.([0-9]+)\\.([0-9]+)" rocm_version_matches "${rocm_version_string}")
+if (rocm_version_matches)
+  set(ROCM_MAJOR_VERSION ${CMAKE_MATCH_1})
+  set(ROCM_MINOR_VERSION ${CMAKE_MATCH_2})
+  set(ROCM_PATCH_VERSION ${CMAKE_MATCH_3})
+
+  message(STATUS "ROCm version: ${ROCM_MAJOR_VERSION}.${ROCM_MINOR_VERSION}.${ROCM_PATCH_VERSION}")
+
+  # Convert the version components to int for comparison
+  math(EXPR ROCM_VERSION "(10000 * ${ROCM_MAJOR_VERSION}) + (100 * ${ROCM_MINOR_VERSION}) + ${ROCM_PATCH_VERSION}")
+  add_definitions("-DROCM_VERSION=${ROCM_VERSION}")
+else()
+  message(WARNING "Failed to extract ROCm version.")
+endif()
+
 ### Check for hipEventDisableSystemFence support
 check_symbol_exists("hipEventDisableSystemFence" "hip/hip_runtime_api.h" HIP_EVENT_DISABLE_FENCE)
 
diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc
index 84f8a6f0d5..7689acd922 100644
--- a/src/misc/argcheck.cc
+++ b/src/misc/argcheck.cc
@@ -14,7 +14,14 @@ static ncclResult_t CudaPtrCheck(const void* pointer, struct ncclComm* comm, con
     WARN("%s : %s %p is not a valid pointer", opname, ptrname, pointer);
     return ncclInvalidArgument;
   }
+  // Builds with ROCM_VERSION below 50500 read attr.memoryType; all others read attr.type.
+  // ROCM_VERSION must be defined for the legacy branch: an undefined macro evaluates to 0
+  // in #if and would otherwise pick attr.memoryType whenever version detection failed.
+#if defined(ROCM_VERSION) && ROCM_VERSION < 50500
+  if (attr.memoryType == cudaMemoryTypeDevice && attr.device != comm->cudaDev) {
+#else
   if (attr.type == cudaMemoryTypeDevice && attr.device != comm->cudaDev) {
+#endif
     WARN("%s : %s allocated on device %d mismatchs with NCCL device %d", opname, ptrname, attr.device, comm->cudaDev);
     return ncclInvalidArgument;
   }