SWDEV-498061 - Add capture support for hipModuleLaunchCooperativeKernel
Change-Id: I5ed188e046c680c2785b3952391f59ed1d0c21b8
This commit is contained in:
committad av
Vladana Stojiljkovic
förälder
db8527f655
incheckning
30cb2d0e67
@@ -348,6 +348,40 @@ hipError_t capturehipModuleLaunchKernel(hipStream_t& stream, hipFunction_t& f, u
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t capturehipModuleLaunchCooperativeKernel(hipStream_t& stream, hipFunction_t& f,
|
||||
uint32_t& gridDimX, uint32_t& gridDimY,
|
||||
uint32_t& gridDimZ, uint32_t& blockDimX,
|
||||
uint32_t& blockDimY, uint32_t& blockDimZ,
|
||||
uint32_t& sharedMemBytes, void**& kernelParams) {
|
||||
ClPrint(amd::LOG_INFO, amd::LOG_API,
|
||||
"[hipGraph] Current capture node ModuleLaunchCooperativeKernel on stream : %p", stream);
|
||||
|
||||
if (!hip::isValid(stream)) {
|
||||
return hipErrorContextIsDestroyed;
|
||||
}
|
||||
|
||||
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
|
||||
hipKernelNodeParams nodeParams;
|
||||
nodeParams.func = f;
|
||||
nodeParams.blockDim = {blockDimX, blockDimY, blockDimZ};
|
||||
nodeParams.gridDim = {gridDimX, gridDimY, gridDimZ};
|
||||
nodeParams.kernelParams = kernelParams;
|
||||
nodeParams.sharedMemBytes = sharedMemBytes;
|
||||
nodeParams.extra = nullptr;
|
||||
|
||||
hip::GraphNode* pGraphNode;
|
||||
hipError_t status =
|
||||
ihipGraphAddKernelNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &nodeParams, nullptr, true,
|
||||
amd::NDRangeKernelCommand::CooperativeGroups);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
s->SetLastCapturedNode(pGraphNode);
|
||||
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t capturehipLaunchByPtr(hipStream_t& stream, hipFunction_t func, dim3 blockDim,
|
||||
dim3 gridDim, unsigned int sharedMemBytes, void** extra) {
|
||||
ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] Current capture node LaunchByPtr on stream : %p",
|
||||
|
||||
@@ -42,6 +42,12 @@ hipError_t capturehipModuleLaunchKernel(hipStream_t& stream, hipFunction_t& f, u
|
||||
uint32_t& sharedMemBytes, void**& kernelParams,
|
||||
void**& extra);
|
||||
|
||||
hipError_t capturehipModuleLaunchCooperativeKernel(hipStream_t& stream, hipFunction_t& f,
|
||||
uint32_t& gridDimX, uint32_t& gridDimY,
|
||||
uint32_t& gridDimZ, uint32_t& blockDimX,
|
||||
uint32_t& blockDimY, uint32_t& blockDimZ,
|
||||
uint32_t& sharedMemBytes, void**& kernelParams);
|
||||
|
||||
hipError_t capturehipLaunchByPtr(hipStream_t& stream, hipFunction_t func, dim3 blockDim,
|
||||
dim3 gridDim, unsigned int sharedMemBytes, void** extra);
|
||||
|
||||
|
||||
@@ -439,7 +439,6 @@ struct GraphNode : public hipGraphNodeDOTAttribute {
|
||||
bool isGraphCapture = false;
|
||||
if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) {
|
||||
switch (GetType()) {
|
||||
case hipGraphNodeTypeKernel:
|
||||
case hipGraphNodeTypeMemset:
|
||||
isGraphCapture = true;
|
||||
break;
|
||||
@@ -1325,6 +1324,17 @@ class GraphKernelNode : public GraphNode {
|
||||
}
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
virtual bool GraphCaptureEnabled() override {
|
||||
bool isGraphCapture = false;
|
||||
if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) {
|
||||
// Disable capture for cooperative kernels
|
||||
if (!coopKernel_) {
|
||||
isGraphCapture = true;
|
||||
}
|
||||
}
|
||||
return isGraphCapture;
|
||||
}
|
||||
};
|
||||
|
||||
class GraphMemcpyNode : public GraphNode {
|
||||
|
||||
@@ -572,6 +572,9 @@ hipError_t hipModuleLaunchCooperativeKernel(hipFunction_t f, unsigned int gridDi
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
STREAM_CAPTURE(hipModuleLaunchCooperativeKernel, stream, f, gridDimX, gridDimY, gridDimZ,
|
||||
blockDimX, blockDimY, blockDimZ, sharedMemBytes, kernelParams);
|
||||
|
||||
size_t globalWorkSizeX = static_cast<size_t>(gridDimX) * blockDimX;
|
||||
size_t globalWorkSizeY = static_cast<size_t>(gridDimY) * blockDimY;
|
||||
size_t globalWorkSizeZ = static_cast<size_t>(gridDimZ) * blockDimZ;
|
||||
|
||||
Referens i nytt ärende
Block a user