diff --git a/projects/clr/rocclr/device/device.hpp b/projects/clr/rocclr/device/device.hpp index 9b9d57d2a6..5464b1e05d 100644 --- a/projects/clr/rocclr/device/device.hpp +++ b/projects/clr/rocclr/device/device.hpp @@ -1730,6 +1730,8 @@ class Device : public RuntimeObject { ) const { return false; }; + + virtual const uint32_t getPreferredNumaNode() const { return 0; } virtual void ReleaseGlobalSignal(void* signal) const {} //! Returns TRUE if the device is available for computations diff --git a/projects/clr/rocclr/device/rocm/rocdevice.cpp b/projects/clr/rocclr/device/rocm/rocdevice.cpp index 92b54f5c9c..5c8bbcd5e1 100644 --- a/projects/clr/rocclr/device/rocm/rocdevice.cpp +++ b/projects/clr/rocclr/device/rocm/rocdevice.cpp @@ -170,7 +170,8 @@ Device::Device(hsa_agent_t bkendDevice) , queuePool_(QueuePriority::Total) , coopHostcallBuffer_(nullptr) , queueWithCUMaskPool_(QueuePriority::Total) - , numOfVgpus_(0) { + , numOfVgpus_(0) + , preferred_numa_node_(0) { group_segment_.handle = 0; system_segment_.handle = 0; system_coarse_segment_.handle = 0; @@ -194,7 +195,7 @@ void Device::setupCpuAgent() { } } } - + preferred_numa_node_ = index; cpu_agent_ = cpu_agents_[index].agent; system_segment_ = cpu_agents_[index].fine_grain_pool; system_coarse_segment_ = cpu_agents_[index].coarse_grain_pool; diff --git a/projects/clr/rocclr/device/rocm/rocdevice.hpp b/projects/clr/rocclr/device/rocm/rocdevice.hpp index e8f4791042..ce98deb2b4 100644 --- a/projects/clr/rocclr/device/rocm/rocdevice.hpp +++ b/projects/clr/rocclr/device/rocm/rocdevice.hpp @@ -534,6 +534,8 @@ class Device : public NullDevice { virtual amd::Memory* GetArenaMemObj(const void* ptr, size_t& offset); + const uint32_t getPreferredNumaNode() const { return preferred_numa_node_; } + private: bool create(); @@ -555,6 +557,7 @@ class Device : public NullDevice { static std::vector cpu_agents_; hsa_agent_t cpu_agent_; + uint32_t preferred_numa_node_; std::vector p2p_agents_; //!< List of P2P agents available for this device std::vector enabled_p2p_devices_; //!< List of user enabled P2P devices for this device mutable std::mutex lock_allow_access_; //!< To serialize allow_access calls diff --git a/projects/clr/rocclr/os/os.hpp b/projects/clr/rocclr/os/os.hpp index 923b0f4509..b4f6816bc2 100644 --- a/projects/clr/rocclr/os/os.hpp +++ b/projects/clr/rocclr/os/os.hpp @@ -222,6 +222,9 @@ class Os : AllStatic { //! Platform-specific optimized memcpy() static void* fastMemcpy(void* dest, const void* src, size_t n); + //! NUMA related settings + static void setPreferredNumaNode(uint32_t node); + // File/Path helper routines: // diff --git a/projects/clr/rocclr/os/os_posix.cpp b/projects/clr/rocclr/os/os_posix.cpp index 04b2f71a79..7b7fcfd5ef 100644 --- a/projects/clr/rocclr/os/os_posix.cpp +++ b/projects/clr/rocclr/os/os_posix.cpp @@ -48,6 +48,10 @@ #define DT_GNU_HASH 0x6ffffef5 #endif // DT_GNU_HASH +#ifdef ROCCLR_SUPPORT_NUMA_POLICY +#include +#endif // ROCCLR_SUPPORT_NUMA_POLICY + #include #include #include @@ -60,7 +64,6 @@ #include #include - namespace amd { static struct sigaction oldSigAction; @@ -121,7 +124,6 @@ static void divisionErrorHandler(int sig, siginfo_t* info, void* ptr) { return; } - std::cerr << "Unhandled signal in divisionErrorHandler()" << std::endl; ::abort(); } @@ -306,6 +308,20 @@ void Os::currentStackInfo(address* base, size_t* size) { void Os::setCurrentThreadName(const char* name) { ::prctl(PR_SET_NAME, name); } +void Os::setPreferredNumaNode(uint32_t node) { + if (AMD_CPU_AFFINITY) { + // Set preferred node affinity mask + int num_cpus = numa_num_configured_cpus(); + bitmask* bm = numa_bitmask_alloc(num_cpus); + + numa_node_to_cpus(node, bm); + if (numa_sched_setaffinity(0, bm) < 0) { + assert(0 && "failed to set affinity"); + } + + numa_free_cpumask(bm); + } +} void* Thread::entry(Thread* thread) { sigset_t set; diff --git a/projects/clr/rocclr/os/os_win32.cpp b/projects/clr/rocclr/os/os_win32.cpp index 499aa694bc..d9eebe0bc6 100644 --- a/projects/clr/rocclr/os/os_win32.cpp +++ b/projects/clr/rocclr/os/os_win32.cpp @@ -250,6 +250,8 @@ static void SetThreadName(DWORD threadId, const char* name) { void Os::setCurrentThreadName(const char* name) { SetThreadName(GetCurrentThreadId(), name); } +void Os::setPreferredNumaNode(uint32_t node) {}; + static LONG WINAPI divExceptionFilter(struct _EXCEPTION_POINTERS* ep) { DWORD code = ep->ExceptionRecord->ExceptionCode;