[hip] add initial support for hipLaunchCooperativeKernelMultiDevice API (#1368)
* [hip] add initial support for hipLaunchCooperativeKernelMultiDevice API * fix formatting
Этот коммит содержится в:
коммит произвёл
Maneesh Gupta
родитель
51c7fedd36
Коммит
6c7da60e28
@@ -480,6 +480,116 @@ hipError_t hipLaunchCooperativeKernel(const void* f, dim3 gridDim,
|
||||
return ihipLogStatus(result);
|
||||
}
|
||||
|
||||
/**
 * @brief Launch one cooperative kernel per device across several devices.
 *
 * For every entry of @p launchParamsList the function first launches the
 * single work-item `init_gws` kernel on that entry's stream (passing the
 * number of work-groups minus one), and then launches the user kernel
 * described by the entry on the same stream. All participating streams are
 * locked before the first launch and unlocked after the last one.
 *
 * @param launchParamsList Array of @p numDevices launch descriptors. Every
 *                         entry must carry a non-null stream.
 * @param numDevices       Number of entries in @p launchParamsList; must be in
 *                         the range [1, g_deviceCnt].
 * @param flags            Forwarded to HIP_INIT_API only; not otherwise used
 *                         by this implementation.
 *
 * @return hipSuccess when every launch succeeded, otherwise:
 *         - hipErrorInvalidValue         null list, out-of-range device count,
 *                                        or a kernel descriptor lookup failed;
 *         - hipErrorNotInitialized       a stream is null, or the descriptor
 *                                        arrays could not be allocated;
 *         - hipErrorInvalidConfiguration a device does not report
 *                                        cooperativeMultiDeviceLaunch;
 *         - hipErrorLaunchFailure        launching `init_gws` failed;
 *         - the error returned by ihipModuleLaunchKernel when a user kernel
 *           launch failed.
 */
hipError_t hipLaunchCooperativeKernelMultiDevice(hipLaunchParams* launchParamsList,
                                                 int numDevices, unsigned int flags) {
    HIP_INIT_API(hipLaunchCooperativeKernelMultiDevice, launchParamsList, numDevices, flags);

    // A non-positive count would leave the result undefined (no launch ever
    // happens) and would turn into a huge allocation size below.
    if (launchParamsList == nullptr || numDevices <= 0 || numDevices > g_deviceCnt) {
        return ihipLogStatus(hipErrorInvalidValue);
    }

    // Validate every stream *before* it is dereferenced.
    for (int i = 0; i < numDevices; ++i) {
        const hipLaunchParams& lp = launchParamsList[i];
        if (lp.stream == nullptr) {
            return ihipLogStatus(hipErrorNotInitialized);
        }
        if (!lp.stream->getDevice()->_props.cooperativeMultiDeviceLaunch) {
            return ihipLogStatus(hipErrorInvalidConfiguration);
        }
    }

    // Owns a malloc'ed descriptor array so that every return path releases it.
    struct DescriptorArray {
        hipFunction_t* data;
        explicit DescriptorArray(int count)
            : data(static_cast<hipFunction_t*>(
                  malloc(sizeof(hipFunction_t) * static_cast<std::size_t>(count)))) {}
        ~DescriptorArray() { free(data); }
        DescriptorArray(const DescriptorArray&) = delete;
        DescriptorArray& operator=(const DescriptorArray&) = delete;
    };

    DescriptorArray gwsStorage(numDevices);
    DescriptorArray kernelStorage(numDevices);
    if (gwsStorage.data == nullptr || kernelStorage.data == nullptr) {
        return ihipLogStatus(hipErrorNotInitialized);
    }
    hipFunction_t* const gwsKds = gwsStorage.data;
    hipFunction_t* const kds = kernelStorage.data;

    // prepare all kernel descriptors for initializing the GWS and the main kernels per device
    for (int i = 0; i < numDevices; ++i) {
        const hipLaunchParams& lp = launchParamsList[i];

        gwsKds[i] = hip_impl::get_program_state().kernel_descriptor(
            reinterpret_cast<std::uintptr_t>(&init_gws), hip_impl::target_agent(lp.stream));
        if (gwsKds[i] == nullptr) {
            return ihipLogStatus(hipErrorInvalidValue);
        }
        // Copy the kernarg layout reported by the program state into the descriptor.
        hip_impl::kernargs_size_align gwsKargs =
            hip_impl::get_program_state().get_kernargs_size_align(
                reinterpret_cast<std::uintptr_t>(&init_gws));
        gwsKds[i]->_kernarg_layout =
            *reinterpret_cast<const std::vector<std::pair<std::size_t, std::size_t>>*>(
                gwsKargs.getHandle());

        kds[i] = hip_impl::get_program_state().kernel_descriptor(
            reinterpret_cast<std::uintptr_t>(lp.func), hip_impl::target_agent(lp.stream));
        if (kds[i] == nullptr) {
            return ihipLogStatus(hipErrorInvalidValue);
        }
        hip_impl::kernargs_size_align kargs =
            hip_impl::get_program_state().get_kernargs_size_align(
                reinterpret_cast<std::uintptr_t>(lp.func));
        kds[i]->_kernarg_layout =
            *reinterpret_cast<const std::vector<std::pair<std::size_t, std::size_t>>*>(
                kargs.getHandle());
    }

    // lock all streams before launching the blit kernels for initializing the GWS and main kernels to each device
    for (int i = 0; i < numDevices; ++i) {
        // Constructed with `false`; the explicit unlock loop below releases the
        // lock, so presumably the accessor does not unlock on destruction.
        LockedAccessor_StreamCrit_t streamCrit(launchParamsList[i].stream->criticalData(), false);
#if (__hcc_workweek__ >= 19213)
        streamCrit->_av.acquire_locked_hsa_queue();
#endif
    }

    // launch the init_gws kernel to initialize the GWS followed by launching the main kernels for each device
    hipError_t result = hipSuccess;
    for (int i = 0; i < numDevices; ++i) {
        const hipLaunchParams& lp = launchParamsList[i];

        // init_gws receives the number of work-groups minus one.
        unsigned int nwm1 = (lp.gridDim.x * lp.gridDim.y * lp.gridDim.z) - 1;
        void* gwsKernelParam[1] = {&nwm1};

        // NOTE(review): the trailing `true` presumably tells the launcher that
        // the stream is already locked by this function - confirm against
        // ihipModuleLaunchKernel's declaration.
        result = ihipModuleLaunchKernel(tls, gwsKds[i], 1, 1, 1, 1, 1, 1,
                                        0, lp.stream, gwsKernelParam, nullptr, nullptr, nullptr,
                                        0, true);
        if (result != hipSuccess) {
            result = hipErrorLaunchFailure;
            break;
        }

        result = ihipModuleLaunchKernel(tls, kds[i],
                                        lp.gridDim.x * lp.blockDim.x,
                                        lp.gridDim.y * lp.blockDim.y,
                                        lp.gridDim.z * lp.blockDim.z,
                                        lp.blockDim.x, lp.blockDim.y,
                                        lp.blockDim.z, lp.sharedMem,
                                        lp.stream, lp.args, nullptr, nullptr, nullptr, 0,
                                        true);
        if (result != hipSuccess) {
            // Do not keep launching on the remaining devices and do not let a
            // later successful launch hide this failure.
            break;
        }
    }

    // unlock all streams (every stream was locked above, on success and on failure alike)
    for (int i = 0; i < numDevices; ++i) {
        launchParamsList[i].stream->criticalData().unlock();
#if (__hcc_workweek__ >= 19213)
        launchParamsList[i].stream->criticalData()._av.release_locked_hsa_queue();
#endif
    }

    return ihipLogStatus(result);
}
|
||||
|
||||
namespace hip_impl {
|
||||
hsa_executable_t executable_for(hipModule_t hmod) {
|
||||
return hmod->executable;
|
||||
|
||||
Ссылка в новой задаче
Block a user