From 3dae883e73857916ae075a17c47738963d7268fc Mon Sep 17 00:00:00 2001 From: Jaydeep Patel Date: Thu, 28 Mar 2024 05:39:53 +0000 Subject: [PATCH] SWDEV-453535 - Capture hipMemset3DAsync. Change-Id: I517c2557573db258b3e3e353f02f6a56652b0fde [ROCm/clr commit: 12e0bdcd3219052ecc85956bbad60b76781d44ad] --- projects/clr/hipamd/src/hip_graph.cpp | 29 +++++++++++-- .../clr/hipamd/src/hip_graph_internal.hpp | 43 +++++++++++-------- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index edaec142c0..28335c6d51 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -165,7 +165,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph, hip::GraphNode* const* pDependencies, size_t numDependencies, - const hipMemsetParams* pMemsetParams, bool capture = true) { + const hipMemsetParams* pMemsetParams, + bool capture = true, size_t depth = 1) { if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr || (numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) { return hipErrorInvalidValue; @@ -181,6 +182,9 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph if (status != hipSuccess) { return status; } + if (depth == 0) { + return hipErrorInvalidValue; + } if (pMemsetParams->height == 1) { status = ihipMemset_validate(pMemsetParams->dst, pMemsetParams->value, pMemsetParams->elementSize, @@ -189,15 +193,16 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph if (pMemsetParams->pitch < (pMemsetParams->width * pMemsetParams->elementSize)) { return hipErrorInvalidValue; } - auto sizeBytes = pMemsetParams->width * pMemsetParams->height * pMemsetParams->elementSize * 1; + auto sizeBytes = pMemsetParams->width * pMemsetParams->height * + depth * 
pMemsetParams->elementSize; status = ihipMemset3D_validate( {pMemsetParams->dst, pMemsetParams->pitch, pMemsetParams->width, pMemsetParams->height}, - pMemsetParams->value, {pMemsetParams->width, pMemsetParams->height, 1}, sizeBytes); + pMemsetParams->value, {pMemsetParams->width, pMemsetParams->height, depth}, sizeBytes); } if (status != hipSuccess) { return status; } - *pGraphNode = new hip::GraphMemsetNode(pMemsetParams); + *pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth); status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture); return status; } @@ -736,9 +741,25 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe hipExtent& extent) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] Current capture node Memset3D on stream : %p", stream); + hipMemsetParams memsetParams = {0}; if (!hip::isValid(stream)) { return hipErrorContextIsDestroyed; } + memsetParams.dst = pitchedDevPtr.ptr; + memsetParams.value = value; + memsetParams.width = extent.width; + memsetParams.height = extent.height; + memsetParams.pitch = pitchedDevPtr.pitch; + memsetParams.elementSize = 1; + hip::Stream* s = reinterpret_cast<hip::Stream*>(stream); + hip::GraphNode* pGraphNode; + hipError_t status = + ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), + s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth); + if (status != hipSuccess) { + return status; + } + s->SetLastCapturedNode(pGraphNode); return hipSuccess; } diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index 7b13bfbfe3..9195af39e6 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -1704,16 +1704,17 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D { }; class GraphMemsetNode : public GraphNode { hipMemsetParams memsetParams_; - + size_t depth_ = 1; public: - GraphMemsetNode(const 
hipMemsetParams* pMemsetParams) + GraphMemsetNode(const hipMemsetParams* pMemsetParams, size_t depth = 1) : GraphNode(hipGraphNodeTypeMemset, "solid", "invtrapezium", "MEMSET") { memsetParams_ = *pMemsetParams; + depth_ = depth; size_t sizeBytes = 0; if (memsetParams_.height == 1) { sizeBytes = memsetParams_.width * memsetParams_.elementSize; } else { - sizeBytes = memsetParams_.width * memsetParams_.height * memsetParams_.elementSize; + sizeBytes = memsetParams_.width * memsetParams_.height * depth_ * memsetParams_.elementSize; } } @@ -1721,6 +1722,7 @@ class GraphMemsetNode : public GraphNode { // Copy constructor GraphMemsetNode(const GraphMemsetNode& memsetNode) : GraphNode(memsetNode) { memsetParams_ = memsetNode.memsetParams_; + depth_ = memsetNode.depth_; } GraphNode* clone() const override { @@ -1733,17 +1735,17 @@ class GraphMemsetNode : public GraphNode { char buffer[500]; sprintf(buffer, "{\n%s\n| {{ID | node handle | dptr | pitch | value | elementSize | width | " - "height} | {%u | %p | %p | %zu | %u | %u | %zu | %zu}}}", + "height | depth} | {%u | %p | %p | %zu | %u | %u | %zu | %zu | %zu}}}", label_.c_str(), GetID(), this, memsetParams_.dst, memsetParams_.pitch, memsetParams_.value, memsetParams_.elementSize, memsetParams_.width, - memsetParams_.height); + memsetParams_.height, depth_); label = buffer; } else { size_t sizeBytes; if (memsetParams_.height == 1) { sizeBytes = memsetParams_.width * memsetParams_.elementSize; } else { - sizeBytes = memsetParams_.width * memsetParams_.height * memsetParams_.elementSize; + sizeBytes = memsetParams_.width * memsetParams_.height * depth_ * memsetParams_.elementSize; } label = std::to_string(GetID()) + "\n" + label_ + "\n(" + std::to_string(memsetParams_.value) + "," + std::to_string(sizeBytes) + ")"; @@ -1774,7 +1776,7 @@ class GraphMemsetNode : public GraphNode { {memsetParams_.dst, memsetParams_.pitch, memsetParams_.width * memsetParams_.elementSize, memsetParams_.height}, memsetParams_.value, - 
{memsetParams_.width * memsetParams_.elementSize, memsetParams_.height, 1}, stream, + {memsetParams_.width * memsetParams_.elementSize, memsetParams_.height, depth_}, stream, memsetParams_.elementSize); } return status; @@ -1793,12 +1795,15 @@ class GraphMemsetNode : public GraphNode { params->width = memsetParams_.width; } - hipError_t SetParamsInternal(const hipMemsetParams* params, bool isExec) { + hipError_t SetParamsInternal(const hipMemsetParams* params, bool isExec, size_t depth = 1) { hipError_t hip_error = hipSuccess; hip_error = ihipGraphMemsetParams_validate(params); if (hip_error != hipSuccess) { return hip_error; } + if (depth == 0) { + return hipErrorInvalidValue; + } if (isExec) { size_t discardOffset = 0; amd::Memory *memObj = getMemoryObject(params->dst, discardOffset); @@ -1829,7 +1834,7 @@ class GraphMemsetNode : public GraphNode { // 2D - hipGraphExecMemsetNodeSetParams returns invalid value if new width or new height is // not same as what memset node is added with. 
if (memsetParams_.width * memsetParams_.elementSize != params->width * params->elementSize - || memsetParams_.height != params->height) { + || memsetParams_.height != params->height || depth != depth_) { return hipErrorInvalidValue; } } else { @@ -1839,26 +1844,30 @@ class GraphMemsetNode : public GraphNode { amd::Memory *memObj = getMemoryObject(params->dst, discardOffset); if (memObj != nullptr) { if (params->width * params->elementSize > memObj->getUserData().width_ - || params->height > memObj->getUserData().height_) { + || params->height > memObj->getUserData().height_ + || depth > memObj->getUserData().depth_) { return hipErrorInvalidValue; } } } - sizeBytes = params->width * params->elementSize * params->height * 1; + sizeBytes = params->width * params->elementSize * params->height * depth; hip_error = ihipMemset3D_validate( {params->dst, params->pitch, params->width * params->elementSize, params->height}, - params->value, {params->width * params->elementSize, params->height, 1}, sizeBytes); + params->value, {params->width * params->elementSize, params->height, depth}, sizeBytes); } if (hip_error != hipSuccess) { return hip_error; } std::memcpy(&memsetParams_, params, sizeof(hipMemsetParams)); + depth_ = depth; return hipSuccess; } - hipError_t SetParams(const hipMemsetParams* params, bool isExec = false) { - return SetParamsInternal(params, isExec); + + hipError_t SetParams(const hipMemsetParams* params, bool isExec = false, size_t depth = 1) { + return SetParamsInternal(params, isExec, depth); } - hipError_t SetParams(const HIP_MEMSET_NODE_PARAMS* params, bool isExec = false) { + + hipError_t SetParams(const HIP_MEMSET_NODE_PARAMS* params, bool isExec = false, size_t depth = 1) { hipMemsetParams pmemsetParams; pmemsetParams.dst = params->dst; pmemsetParams.elementSize = params->elementSize; @@ -1866,11 +1875,11 @@ class GraphMemsetNode : public GraphNode { pmemsetParams.pitch = params->pitch; pmemsetParams.value = params->value; pmemsetParams.width = 
params->width; - return SetParamsInternal(&pmemsetParams, isExec); + return SetParamsInternal(&pmemsetParams, isExec, depth); } hipError_t SetParams(GraphNode* node) override { const GraphMemsetNode* memsetNode = static_cast<GraphMemsetNode const*>(node); - return SetParams(&memsetNode->memsetParams_); + return SetParams(&memsetNode->memsetParams_, false, memsetNode->depth_); } };