SWDEV-453535 - Capture hipMemset3DAsync.

Change-Id: I517c2557573db258b3e3e353f02f6a56652b0fde


[ROCm/clr commit: 12e0bdcd32]
This commit is contained in:
Jaydeep Patel
2024-03-28 05:39:53 +00:00
کامیت شده توسط Jaydeepkumar Patel
والد 24bb38acb8
کامیت 3dae883e73
2فایلهای تغییر یافته به همراه51 افزوده شده و 21 حذف شده
@@ -165,7 +165,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra
hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const hipMemsetParams* pMemsetParams, bool capture = true) {
const hipMemsetParams* pMemsetParams,
bool capture = true, size_t depth = 1) {
if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) {
return hipErrorInvalidValue;
@@ -181,6 +182,9 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
if (status != hipSuccess) {
return status;
}
if (depth == 0) {
return hipErrorInvalidValue;
}
if (pMemsetParams->height == 1) {
status =
ihipMemset_validate(pMemsetParams->dst, pMemsetParams->value, pMemsetParams->elementSize,
@@ -189,15 +193,16 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
if (pMemsetParams->pitch < (pMemsetParams->width * pMemsetParams->elementSize)) {
return hipErrorInvalidValue;
}
auto sizeBytes = pMemsetParams->width * pMemsetParams->height * pMemsetParams->elementSize * 1;
auto sizeBytes = pMemsetParams->width * pMemsetParams->height *
depth * pMemsetParams->elementSize;
status = ihipMemset3D_validate(
{pMemsetParams->dst, pMemsetParams->pitch, pMemsetParams->width, pMemsetParams->height},
pMemsetParams->value, {pMemsetParams->width, pMemsetParams->height, 1}, sizeBytes);
pMemsetParams->value, {pMemsetParams->width, pMemsetParams->height, depth}, sizeBytes);
}
if (status != hipSuccess) {
return status;
}
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams);
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
return status;
}
@@ -736,9 +741,25 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe
hipExtent& extent) {
ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] Current capture node Memset3D on stream : %p",
stream);
hipMemsetParams memsetParams = {0};
if (!hip::isValid(stream)) {
return hipErrorContextIsDestroyed;
}
memsetParams.dst = pitchedDevPtr.ptr;
memsetParams.value = value;
memsetParams.width = extent.width;
memsetParams.height = extent.height;
memsetParams.pitch = pitchedDevPtr.pitch;
memsetParams.elementSize = 1;
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth);
if (status != hipSuccess) {
return status;
}
s->SetLastCapturedNode(pGraphNode);
return hipSuccess;
}
@@ -1704,16 +1704,17 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D {
};
class GraphMemsetNode : public GraphNode {
hipMemsetParams memsetParams_;
size_t depth_ = 1;
public:
GraphMemsetNode(const hipMemsetParams* pMemsetParams)
GraphMemsetNode(const hipMemsetParams* pMemsetParams, size_t depth = 1)
: GraphNode(hipGraphNodeTypeMemset, "solid", "invtrapezium", "MEMSET") {
memsetParams_ = *pMemsetParams;
depth_ = depth;
size_t sizeBytes = 0;
if (memsetParams_.height == 1) {
sizeBytes = memsetParams_.width * memsetParams_.elementSize;
} else {
sizeBytes = memsetParams_.width * memsetParams_.height * memsetParams_.elementSize;
sizeBytes = memsetParams_.width * memsetParams_.height * depth_ * memsetParams_.elementSize;
}
}
@@ -1721,6 +1722,7 @@ class GraphMemsetNode : public GraphNode {
// Copy constructor
GraphMemsetNode(const GraphMemsetNode& memsetNode) : GraphNode(memsetNode) {
memsetParams_ = memsetNode.memsetParams_;
depth_ = memsetNode.depth_;
}
GraphNode* clone() const override {
@@ -1733,17 +1735,17 @@ class GraphMemsetNode : public GraphNode {
char buffer[500];
sprintf(buffer,
"{\n%s\n| {{ID | node handle | dptr | pitch | value | elementSize | width | "
"height} | {%u | %p | %p | %zu | %u | %u | %zu | %zu}}}",
"height | depth} | {%u | %p | %p | %zu | %u | %u | %zu | %zu | %zu}}}",
label_.c_str(), GetID(), this, memsetParams_.dst, memsetParams_.pitch,
memsetParams_.value, memsetParams_.elementSize, memsetParams_.width,
memsetParams_.height);
memsetParams_.height, depth_);
label = buffer;
} else {
size_t sizeBytes;
if (memsetParams_.height == 1) {
sizeBytes = memsetParams_.width * memsetParams_.elementSize;
} else {
sizeBytes = memsetParams_.width * memsetParams_.height * memsetParams_.elementSize;
sizeBytes = memsetParams_.width * memsetParams_.height * depth_ * memsetParams_.elementSize;
}
label = std::to_string(GetID()) + "\n" + label_ + "\n(" +
std::to_string(memsetParams_.value) + "," + std::to_string(sizeBytes) + ")";
@@ -1774,7 +1776,7 @@ class GraphMemsetNode : public GraphNode {
{memsetParams_.dst, memsetParams_.pitch, memsetParams_.width * memsetParams_.elementSize,
memsetParams_.height},
memsetParams_.value,
{memsetParams_.width * memsetParams_.elementSize, memsetParams_.height, 1}, stream,
{memsetParams_.width * memsetParams_.elementSize, memsetParams_.height, depth_}, stream,
memsetParams_.elementSize);
}
return status;
@@ -1793,12 +1795,15 @@ class GraphMemsetNode : public GraphNode {
params->width = memsetParams_.width;
}
hipError_t SetParamsInternal(const hipMemsetParams* params, bool isExec) {
hipError_t SetParamsInternal(const hipMemsetParams* params, bool isExec, size_t depth = 1) {
hipError_t hip_error = hipSuccess;
hip_error = ihipGraphMemsetParams_validate(params);
if (hip_error != hipSuccess) {
return hip_error;
}
if (depth == 0) {
return hipErrorInvalidValue;
}
if (isExec) {
size_t discardOffset = 0;
amd::Memory *memObj = getMemoryObject(params->dst, discardOffset);
@@ -1829,7 +1834,7 @@ class GraphMemsetNode : public GraphNode {
// 2D - hipGraphExecMemsetNodeSetParams returns invalid value if new width or new height is
// not same as what memset node is added with.
if (memsetParams_.width * memsetParams_.elementSize != params->width * params->elementSize
|| memsetParams_.height != params->height) {
|| memsetParams_.height != params->height || depth != depth_) {
return hipErrorInvalidValue;
}
} else {
@@ -1839,26 +1844,30 @@ class GraphMemsetNode : public GraphNode {
amd::Memory *memObj = getMemoryObject(params->dst, discardOffset);
if (memObj != nullptr) {
if (params->width * params->elementSize > memObj->getUserData().width_
|| params->height > memObj->getUserData().height_) {
|| params->height > memObj->getUserData().height_
|| depth > memObj->getUserData().depth_) {
return hipErrorInvalidValue;
}
}
}
sizeBytes = params->width * params->elementSize * params->height * 1;
sizeBytes = params->width * params->elementSize * params->height * depth;
hip_error = ihipMemset3D_validate(
{params->dst, params->pitch, params->width * params->elementSize, params->height},
params->value, {params->width * params->elementSize, params->height, 1}, sizeBytes);
params->value, {params->width * params->elementSize, params->height, depth}, sizeBytes);
}
if (hip_error != hipSuccess) {
return hip_error;
}
std::memcpy(&memsetParams_, params, sizeof(hipMemsetParams));
depth_ = depth;
return hipSuccess;
}
hipError_t SetParams(const hipMemsetParams* params, bool isExec = false) {
return SetParamsInternal(params, isExec);
hipError_t SetParams(const hipMemsetParams* params, bool isExec = false, size_t depth = 1) {
return SetParamsInternal(params, isExec, depth);
}
hipError_t SetParams(const HIP_MEMSET_NODE_PARAMS* params, bool isExec = false) {
hipError_t SetParams(const HIP_MEMSET_NODE_PARAMS* params, bool isExec = false, size_t depth = 1) {
hipMemsetParams pmemsetParams;
pmemsetParams.dst = params->dst;
pmemsetParams.elementSize = params->elementSize;
@@ -1866,11 +1875,11 @@ class GraphMemsetNode : public GraphNode {
pmemsetParams.pitch = params->pitch;
pmemsetParams.value = params->value;
pmemsetParams.width = params->width;
return SetParamsInternal(&pmemsetParams, isExec);
return SetParamsInternal(&pmemsetParams, isExec, depth);
}
hipError_t SetParams(GraphNode* node) override {
const GraphMemsetNode* memsetNode = static_cast<GraphMemsetNode const*>(node);
return SetParams(&memsetNode->memsetParams_);
return SetParams(&memsetNode->memsetParams_, false, memsetNode->depth_);
}
};