diff --git a/projects/clr/hipamd/src/hip_hmm.cpp b/projects/clr/hipamd/src/hip_hmm.cpp
index 5801df1032..f4ecb4e886 100755
--- a/projects/clr/hipamd/src/hip_hmm.cpp
+++ b/projects/clr/hipamd/src/hip_hmm.cpp
@@ -24,6 +24,7 @@
 #include "platform/context.hpp"
 #include "platform/command.hpp"
 #include "platform/memory.hpp"
+#include "os/os.hpp"
 
 namespace hip {
 
@@ -310,9 +311,19 @@ hipError_t ihipMemPrefetchAsync(const void* dev_ptr, size_t count, hipMemLocatio
   const bool cpuAccess = isHost || isHostNuma || isHostCurrent;
 
   // Determine the target device index:
-  // - for host-prefetch and host-current, always use device 0
+  // - for host-prefetch, use default CPU agent
+  // - for host-current, query the current thread's NUMA node ID
   // - for host-NUMA or device-prefetch, use the provided id
-  int targetDevice = (isHost || isHostCurrent) ? hipCpuDeviceId : location.id;
+  int targetDevice;
+  if (isHost) {
+    targetDevice = hipCpuDeviceId;
+  } else if (isHostCurrent) {
+    uint32_t numa_node = amd::numa::getCurrentNumaNode();
+    targetDevice =
+        (numa_node == static_cast<uint32_t>(-1)) ? hipCpuDeviceId : static_cast<int>(numa_node);
+  } else {
+    targetDevice = location.id;
+  }
 
   amd::Device* dev = nullptr;
   if (cpuAccess == false) {
@@ -378,10 +389,16 @@ hipError_t ihipMemAdvise(const void* dev_ptr, size_t count, hipMemoryAdvise advi
       use_cpu = true;
       break;
     case hipMemLocationTypeHost:
-    case hipMemLocationTypeHostNumaCurrent:
       targetDevice = hipCpuDeviceId;
       use_cpu = true;
       break;
+    case hipMemLocationTypeHostNumaCurrent: {
+      uint32_t numa_node = amd::numa::getCurrentNumaNode();
+      targetDevice =
+          (numa_node == static_cast<uint32_t>(-1)) ? hipCpuDeviceId : static_cast<int>(numa_node);
+      use_cpu = true;
+      break;
+    }
     default:
       return hipErrorInvalidValue;
   }
diff --git a/projects/clr/rocclr/os/os.hpp b/projects/clr/rocclr/os/os.hpp
index 7e61f815ac..cec7f482cd 100644
--- a/projects/clr/rocclr/os/os.hpp
+++ b/projects/clr/rocclr/os/os.hpp
@@ -299,6 +299,10 @@ namespace numa {
 
 static constexpr uint32_t kBitsPerUInt64 = 8 * sizeof(uint64_t);
 
+//! Get the NUMA node ID of the current thread
+//! \return the node ID, or static_cast<uint32_t>(-1) if the OS query fails
+uint32_t getCurrentNumaNode();
+
 /*! \brief Manage Numa policy.
  *
  * \note Works in Linux only, dummy in Windows.
diff --git a/projects/clr/rocclr/os/os_posix.cpp b/projects/clr/rocclr/os/os_posix.cpp
index ddc060b0c7..7030d81b89 100644
--- a/projects/clr/rocclr/os/os_posix.cpp
+++ b/projects/clr/rocclr/os/os_posix.cpp
@@ -954,6 +954,15 @@ void Os::CxaDemangle(const std::string& name, std::string* result) {
 
 namespace numa {
 
+// ================================================================================================
+uint32_t getCurrentNumaNode() {
+  unsigned cpu, node;
+  if (syscall(__NR_getcpu, &cpu, &node, nullptr) < 0) {
+    return static_cast<uint32_t>(-1);
+  }
+  return static_cast<uint32_t>(node);
+}
+
 // ================================================================================================
 NumaPolicy::NumaPolicy(const uint32_t numa_node_count)
     : node_map_((numa_node_count + kBitsPerUInt64 - 1) / kBitsPerUInt64, 0) { }
diff --git a/projects/clr/rocclr/os/os_win32.cpp b/projects/clr/rocclr/os/os_win32.cpp
index 4055420d73..b5c8d6ea58 100644
--- a/projects/clr/rocclr/os/os_win32.cpp
+++ b/projects/clr/rocclr/os/os_win32.cpp
@@ -704,6 +704,19 @@ void Os::CxaDemangle(const std::string& name, std::string* result) { *result = n
 
 namespace numa {
 
+// ================================================================================================
+uint32_t getCurrentNumaNode() {
+  PROCESSOR_NUMBER procNumber{};
+  GetCurrentProcessorNumberEx(&procNumber);
+
+  USHORT numa_node = static_cast<USHORT>(-1);
+  if (!GetNumaProcessorNodeEx(&procNumber, &numa_node)) {
+    return static_cast<uint32_t>(-1);
+  }
+
+  return static_cast<uint32_t>(numa_node);
+}
+
 // ================================================================================================
 NumaPolicy::NumaPolicy(const uint32_t numa_node_count) { }
 