From 6c7da60e28c1aaabc93d84ed572a4e952a679d17 Mon Sep 17 00:00:00 2001 From: Aryan Salmanpour Date: Mon, 16 Sep 2019 04:31:17 -0400 Subject: [PATCH] [hip] add initial support for hipLaunchCooperativeKernelMultiDevice API (#1368) * [hip] add initial support for hipLaunchCooperativeKernelMultiDevice API * fix formatting --- hipamd/src/hip_module.cpp | 110 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/hipamd/src/hip_module.cpp b/hipamd/src/hip_module.cpp index 2548660a72..4c07b9777b 100644 --- a/hipamd/src/hip_module.cpp +++ b/hipamd/src/hip_module.cpp @@ -480,6 +480,116 @@ hipError_t hipLaunchCooperativeKernel(const void* f, dim3 gridDim, return ihipLogStatus(result); } +hipError_t hipLaunchCooperativeKernelMultiDevice(hipLaunchParams* launchParamsList, + int numDevices, unsigned int flags) { + + HIP_INIT_API(hipLaunchCooperativeKernelMultiDevice, launchParamsList, numDevices, flags); + hipError_t result; + + if (numDevices > g_deviceCnt || launchParamsList == nullptr) { + return ihipLogStatus(hipErrorInvalidValue); + } + + for (int i = 0; i < numDevices; ++i) { + if (!launchParamsList[i].stream->getDevice()->_props.cooperativeMultiDeviceLaunch) { + return ihipLogStatus(hipErrorInvalidConfiguration); + } + } + + hipFunction_t* gwsKds = reinterpret_cast(malloc(sizeof(hipFunction_t) * numDevices)); + hipFunction_t* kds = reinterpret_cast(malloc(sizeof(hipFunction_t) * numDevices)); + if (kds == nullptr || gwsKds == nullptr) { + return ihipLogStatus(hipErrorNotInitialized); + } + + // prepare all kernel descriptors for initializing the GWS and the main kernels per device + for (int i = 0; i < numDevices; ++i) { + const hipLaunchParams& lp = launchParamsList[i]; + if (lp.stream == nullptr) { + free(gwsKds); + free(kds); + return ihipLogStatus(hipErrorNotInitialized); + } + + gwsKds[i] = hip_impl::get_program_state().kernel_descriptor(reinterpret_cast(&init_gws), + hip_impl::target_agent(lp.stream)); + if (gwsKds[i] == nullptr) { + free(gwsKds); + free(kds); + return ihipLogStatus(hipErrorInvalidValue); + } + hip_impl::kernargs_size_align gwsKargs = hip_impl::get_program_state().get_kernargs_size_align( + reinterpret_cast(&init_gws)); + gwsKds[i]->_kernarg_layout = *reinterpret_cast>*>( + gwsKargs.getHandle()); + + + kds[i] = hip_impl::get_program_state().kernel_descriptor(reinterpret_cast(lp.func), + hip_impl::target_agent(lp.stream)); + if (kds[i] == nullptr) { + free(gwsKds); + free(kds); + return ihipLogStatus(hipErrorInvalidValue); + } + hip_impl::kernargs_size_align kargs = hip_impl::get_program_state().get_kernargs_size_align( + reinterpret_cast(lp.func)); + kds[i]->_kernarg_layout = *reinterpret_cast>*>( + kargs.getHandle()); + } + + // lock all streams before launching the blit kernels for initializing the GWS and main kernels to each device + for (int i = 0; i < numDevices; ++i) { + LockedAccessor_StreamCrit_t streamCrit(launchParamsList[i].stream->criticalData(), false); +#if (__hcc_workweek__ >= 19213) + streamCrit->_av.acquire_locked_hsa_queue(); +#endif + } + + // launch the init_gws kernel to initialize the GWS followed by launching the main kernels for each device + for (int i = 0; i < numDevices; ++i) { + const hipLaunchParams& lp = launchParamsList[i]; + + void *gwsKernelParam[1]; + uint nwm1 = (lp.gridDim.x * lp.gridDim.y * lp.gridDim.z) - 1; + gwsKernelParam[0] = &nwm1; + + result = ihipModuleLaunchKernel(tls, gwsKds[i], 1, 1, 1, 1, 1, 1, + 0, lp.stream, gwsKernelParam, nullptr, nullptr, nullptr, 0, true); + + if (result != hipSuccess) { + for (int j = 0; j < numDevices; ++j) { + launchParamsList[j].stream->criticalData().unlock(); +#if (__hcc_workweek__ >= 19213) + launchParamsList[j].stream->criticalData()._av.release_locked_hsa_queue(); +#endif + } + return ihipLogStatus(hipErrorLaunchFailure); + } + + result = ihipModuleLaunchKernel(tls, kds[i], + lp.gridDim.x * lp.blockDim.x, + lp.gridDim.y * lp.blockDim.y, + lp.gridDim.z * lp.blockDim.z, + lp.blockDim.x, lp.blockDim.y, + lp.blockDim.z, lp.sharedMem, + lp.stream, lp.args, nullptr, nullptr, nullptr, 0, + true); + } + + // unlock all streams + for (int i = 0; i < numDevices; ++i) { + launchParamsList[i].stream->criticalData().unlock(); +#if (__hcc_workweek__ >= 19213) + launchParamsList[i].stream->criticalData()._av.release_locked_hsa_queue(); +#endif + } + + free(gwsKds); + free(kds); + + return ihipLogStatus(result); +} + namespace hip_impl { hsa_executable_t executable_for(hipModule_t hmod) { return hmod->executable;