From c60888b0cd135739477e62d400a61138333f1b08 Mon Sep 17 00:00:00 2001 From: "Dittakavi, Satyanvesh" Date: Tue, 12 Aug 2025 16:29:09 +0530 Subject: [PATCH] SWDEV-545947 - Add Implementation for hipSetValidDevices (#805) --- hipamd/src/hip_device_runtime.cpp | 26 ++++++++++++++++++++++++-- hipamd/src/hip_internal.hpp | 4 +++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/hipamd/src/hip_device_runtime.cpp b/hipamd/src/hip_device_runtime.cpp index 180ec9251e..19e96101fa 100644 --- a/hipamd/src/hip_device_runtime.cpp +++ b/hipamd/src/hip_device_runtime.cpp @@ -719,6 +719,7 @@ hipError_t hipGetDeviceFlags(unsigned int* flags) { hipError_t hipSetDevice(int device) { HIP_INIT_API_NO_RETURN(hipSetDevice, device); + hip::tls.isSetDeviceCalled = true; // Check if the device is already set if (hip::tls.device_ != nullptr && hip::tls.device_->deviceId() == device) { HIP_RETURN(hipSuccess); @@ -784,10 +785,31 @@ hipError_t hipSetDeviceFlags(unsigned int flags) { hipError_t hipSetValidDevices(int* device_arr, int len) { HIP_INIT_API(hipSetValidDevices, device_arr, len); + // HIP runtime will go ahead with the default behavior of trying devices + // from a default list sequentially, if the len passed is 0 + if (len == 0) { + HIP_RETURN(hipSuccess); + } + int count = 0; + HIP_RETURN_ONFAIL(ihipDeviceGetCount(&count)); - assert(0 && "Unimplemented"); + if (device_arr == nullptr || len < 0 || len > count) { + HIP_RETURN(hipErrorInvalidValue); + } - HIP_RETURN(hipErrorNotSupported); + for (int i = 0; i < len; ++i) { + if (device_arr[i] < 0 || device_arr[i] >= count) { + HIP_RETURN(hipErrorInvalidDevice); + } + } + + if (tls.isSetDeviceCalled) { + HIP_RETURN(hipSuccess); + } + tls.device_ = g_devices[device_arr[0]]; + uint32_t preferredNumaNode = (tls.device_)->devices()[0]->getPreferredNumaNode(); + amd::Os::setPreferredNumaNode(preferredNumaNode); + HIP_RETURN(hipSuccess); } } //namespace hip diff --git a/hipamd/src/hip_internal.hpp b/hipamd/src/hip_internal.hpp index a475c19064..c218f27a9a 100644 --- a/hipamd/src/hip_internal.hpp +++ b/hipamd/src/hip_internal.hpp @@ -619,11 +619,13 @@ public: hipStreamCaptureMode stream_capture_mode_; std::stack exec_stack_; stream_per_thread stream_per_thread_obj_; + bool isSetDeviceCalled; TlsAggregator(): device_(nullptr), last_error_(hipSuccess), last_command_error_(hipSuccess), - stream_capture_mode_(hipStreamCaptureModeGlobal) { + stream_capture_mode_(hipStreamCaptureModeGlobal), + isSetDeviceCalled(false) { } ~TlsAggregator() { }