// Patch summary (free text ahead of the "diff --git" header; not part of any hunk).
//
// Target file: hipamd/api/hip/hip_module.cpp
//
// What the visible hunks change:
//  * The former body of hipLaunchCooperativeKernelMultiDevice becomes the internal helper
//    ihipLaunchCooperativeKernelMultiDevice, which takes one extra parameter, uint32_t extFlags.
//    extFlags replaces the previously hard-coded
//    (CooperativeGroups | CooperativeMultiDeviceGroups) value as the last argument of the
//    per-device launch call (the callee's name lies outside the visible hunks).
//  * New argument check: returns hipErrorInvalidValue when numDevices exceeds the count
//    reported by ihipDeviceGetCount, or when launchParamsList is nullptr.
//  * Inside the loop, hipSetDevice(i) is called before each launch and the function lookup
//    uses index i instead of the device that was current on entry.
//  * After the loop, hipSetDevice(currentDevice) switches back to the device captured on entry.
//  * hipLaunchCooperativeKernelMultiDevice is re-added as a wrapper that passes
//    (CooperativeGroups | CooperativeMultiDeviceGroups) as extFlags.
//  * hipExtLaunchMultiKernelMultiDevice is added as a wrapper that passes 0 as extFlags.
//
// NOTE(review): the `func == nullptr` branch returns through HIP_RETURN(result) after
//   hipSetDevice(i) has already run and before hipSetDevice(currentDevice) is reached, so it
//   looks like the caller's device is not restored on that path -- confirm.
// NOTE(review): numDevices <= 0 is not rejected by the new check; the loop body never runs and
//   the initial hipErrorUnknown is returned -- confirm this is the intended error code.
// NOTE(review): the loop index i is used directly as the device ordinal; nothing visible ties
//   it to the device that owns launch.stream -- verify against callers.
// NOTE(review): this patch text appears to have been flattened (line breaks replaced by spaces),
//   and `reinterpret_cast(launch.stream)` seems to have lost its template argument; regenerate
//   the patch from the original commit before trying to apply it.
diff --git a/hipamd/api/hip/hip_module.cpp b/hipamd/api/hip/hip_module.cpp index edc3ba4384..eb6aec8e08 100644 --- a/hipamd/api/hip/hip_module.cpp +++ b/hipamd/api/hip/hip_module.cpp @@ -333,16 +333,23 @@ hipError_t hipLaunchCooperativeKernel(const void* f, amd::NDRangeKernelCommand::CooperativeGroups)); } -hipError_t hipLaunchCooperativeKernelMultiDevice(hipLaunchParams* launchParamsList, - int numDevices, unsigned int flags) +hipError_t ihipLaunchCooperativeKernelMultiDevice(hipLaunchParams* launchParamsList, + int numDevices, unsigned int flags, uint32_t extFlags) { - int deviceId = ihipGetDevice(); + int currentDevice = ihipGetDevice(); + int numActiveGPUs = 0; + ihipDeviceGetCount(&numActiveGPUs); + if ((numDevices > numActiveGPUs) || (launchParamsList == nullptr)) { + return hipErrorInvalidValue; + } + hipError_t result = hipErrorUnknown; for (int i = 0; i < numDevices; ++i) { + hipSetDevice(i); const hipLaunchParams& launch = launchParamsList[i]; amd::HostQueue* queue = as_amd(reinterpret_cast(launch.stream))->asHostQueue(); - hipFunction_t func = PlatformState::instance().getFunc(launch.func, deviceId); + hipFunction_t func = PlatformState::instance().getFunc(launch.func, i); if (func == nullptr) { HIP_RETURN(result); } @@ -352,8 +359,22 @@ hipError_t hipLaunchCooperativeKernelMultiDevice(hipLaunchParams* launchParamsLi launch.gridDim.z * launch.blockDim.z, launch.blockDim.x, launch.blockDim.y, launch.blockDim.z, launch.sharedMem, launch.stream, - launch.args, nullptr, nullptr, nullptr, flags, - (amd::NDRangeKernelCommand::CooperativeGroups | amd::NDRangeKernelCommand::CooperativeMultiDeviceGroups)); + launch.args, nullptr, nullptr, nullptr, flags, extFlags); } + + hipSetDevice(currentDevice); return result; } + +hipError_t hipLaunchCooperativeKernelMultiDevice(hipLaunchParams* launchParamsList, + int numDevices, unsigned int flags) +{ + return ihipLaunchCooperativeKernelMultiDevice(launchParamsList, numDevices, flags, 
(amd::NDRangeKernelCommand::CooperativeGroups | + amd::NDRangeKernelCommand::CooperativeMultiDeviceGroups)); +} + +hipError_t hipExtLaunchMultiKernelMultiDevice(hipLaunchParams* launchParamsList, + int numDevices, unsigned int flags) { + return ihipLaunchCooperativeKernelMultiDevice(launchParamsList, numDevices, flags, 0); +}