SWDEV-545947 - Add Implementation for hipSetValidDevices (#805)

[ROCm/clr commit: c60888b0cd]
This commit is contained in:
Dittakavi, Satyanvesh
2025-08-12 16:29:09 +05:30
committed by GitHub
parent 263a0bc57d
commit 04c935d29c
2 changed files with 27 additions and 3 deletions
+24 -2
View File
@@ -719,6 +719,7 @@ hipError_t hipGetDeviceFlags(unsigned int* flags) {
hipError_t hipSetDevice(int device) {
HIP_INIT_API_NO_RETURN(hipSetDevice, device);
hip::tls.isSetDeviceCalled = true;
// Check if the device is already set
if (hip::tls.device_ != nullptr && hip::tls.device_->deviceId() == device) {
HIP_RETURN(hipSuccess);
@@ -784,10 +785,31 @@ hipError_t hipSetDeviceFlags(unsigned int flags) {
hipError_t hipSetValidDevices(int* device_arr, int len) {
HIP_INIT_API(hipSetValidDevices, device_arr, len);
// HIP runtime will go ahead with the default behavior of trying devices
// from a default list sequentially, if the len passed is 0
if (len == 0) {
HIP_RETURN(hipSuccess);
}
int count = 0;
HIP_RETURN_ONFAIL(ihipDeviceGetCount(&count));
assert(0 && "Unimplemented");
if (device_arr == nullptr || len < 0 || len > count) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(hipErrorNotSupported);
for (int i = 0; i < len; ++i) {
if (device_arr[i] < 0 || device_arr[i] >= count) {
HIP_RETURN(hipErrorInvalidDevice);
}
}
if (tls.isSetDeviceCalled) {
HIP_RETURN(hipSuccess);
}
tls.device_ = g_devices[device_arr[0]];
uint32_t preferredNumaNode = (tls.device_)->devices()[0]->getPreferredNumaNode();
amd::Os::setPreferredNumaNode(preferredNumaNode);
HIP_RETURN(hipSuccess);
}
} //namespace hip
+3 -1
View File
@@ -619,11 +619,13 @@ public:
hipStreamCaptureMode stream_capture_mode_;
std::stack<ihipExec_t> exec_stack_;
stream_per_thread stream_per_thread_obj_;
bool isSetDeviceCalled;
TlsAggregator(): device_(nullptr),
last_error_(hipSuccess),
last_command_error_(hipSuccess),
stream_capture_mode_(hipStreamCaptureModeGlobal) {
stream_capture_mode_(hipStreamCaptureModeGlobal),
isSetDeviceCalled(false) {
}
~TlsAggregator() {
}