SWDEV-453535 - Capture hipMemset3DAsync during stream capture by adding a depth-aware memset graph node.
Change-Id: I517c2557573db258b3e3e353f02f6a56652b0fde
[ROCm/clr commit: 12e0bdcd32]
This commit is contained in:
کامیت شده توسط
Jaydeepkumar Patel
والد
24bb38acb8
کامیت
3dae883e73
@@ -165,7 +165,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra
|
||||
|
||||
hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
const hipMemsetParams* pMemsetParams, bool capture = true) {
|
||||
const hipMemsetParams* pMemsetParams,
|
||||
bool capture = true, size_t depth = 1) {
|
||||
if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr ||
|
||||
(numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) {
|
||||
return hipErrorInvalidValue;
|
||||
@@ -181,6 +182,9 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
if (depth == 0) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
if (pMemsetParams->height == 1) {
|
||||
status =
|
||||
ihipMemset_validate(pMemsetParams->dst, pMemsetParams->value, pMemsetParams->elementSize,
|
||||
@@ -189,15 +193,16 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
|
||||
if (pMemsetParams->pitch < (pMemsetParams->width * pMemsetParams->elementSize)) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
auto sizeBytes = pMemsetParams->width * pMemsetParams->height * pMemsetParams->elementSize * 1;
|
||||
auto sizeBytes = pMemsetParams->width * pMemsetParams->height *
|
||||
depth * pMemsetParams->elementSize;
|
||||
status = ihipMemset3D_validate(
|
||||
{pMemsetParams->dst, pMemsetParams->pitch, pMemsetParams->width, pMemsetParams->height},
|
||||
pMemsetParams->value, {pMemsetParams->width, pMemsetParams->height, 1}, sizeBytes);
|
||||
pMemsetParams->value, {pMemsetParams->width, pMemsetParams->height, depth}, sizeBytes);
|
||||
}
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams);
|
||||
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth);
|
||||
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
|
||||
return status;
|
||||
}
|
||||
@@ -736,9 +741,25 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe
|
||||
hipExtent& extent) {
|
||||
ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] Current capture node Memset3D on stream : %p",
|
||||
stream);
|
||||
hipMemsetParams memsetParams = {0};
|
||||
if (!hip::isValid(stream)) {
|
||||
return hipErrorContextIsDestroyed;
|
||||
}
|
||||
memsetParams.dst = pitchedDevPtr.ptr;
|
||||
memsetParams.value = value;
|
||||
memsetParams.width = extent.width;
|
||||
memsetParams.height = extent.height;
|
||||
memsetParams.pitch = pitchedDevPtr.pitch;
|
||||
memsetParams.elementSize = 1;
|
||||
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
|
||||
hip::GraphNode* pGraphNode;
|
||||
hipError_t status =
|
||||
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
s->SetLastCapturedNode(pGraphNode);
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
|
||||
@@ -1704,16 +1704,17 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D {
|
||||
};
|
||||
class GraphMemsetNode : public GraphNode {
|
||||
hipMemsetParams memsetParams_;
|
||||
|
||||
size_t depth_ = 1;
|
||||
public:
|
||||
GraphMemsetNode(const hipMemsetParams* pMemsetParams)
|
||||
GraphMemsetNode(const hipMemsetParams* pMemsetParams, size_t depth = 1)
|
||||
: GraphNode(hipGraphNodeTypeMemset, "solid", "invtrapezium", "MEMSET") {
|
||||
memsetParams_ = *pMemsetParams;
|
||||
depth_ = depth;
|
||||
size_t sizeBytes = 0;
|
||||
if (memsetParams_.height == 1) {
|
||||
sizeBytes = memsetParams_.width * memsetParams_.elementSize;
|
||||
} else {
|
||||
sizeBytes = memsetParams_.width * memsetParams_.height * memsetParams_.elementSize;
|
||||
sizeBytes = memsetParams_.width * memsetParams_.height * depth_ * memsetParams_.elementSize;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1721,6 +1722,7 @@ class GraphMemsetNode : public GraphNode {
|
||||
// Copy constructor
|
||||
GraphMemsetNode(const GraphMemsetNode& memsetNode) : GraphNode(memsetNode) {
|
||||
memsetParams_ = memsetNode.memsetParams_;
|
||||
depth_ = memsetNode.depth_;
|
||||
}
|
||||
|
||||
GraphNode* clone() const override {
|
||||
@@ -1733,17 +1735,17 @@ class GraphMemsetNode : public GraphNode {
|
||||
char buffer[500];
|
||||
sprintf(buffer,
|
||||
"{\n%s\n| {{ID | node handle | dptr | pitch | value | elementSize | width | "
|
||||
"height} | {%u | %p | %p | %zu | %u | %u | %zu | %zu}}}",
|
||||
"height | depth} | {%u | %p | %p | %zu | %u | %u | %zu | %zu | %zu}}}",
|
||||
label_.c_str(), GetID(), this, memsetParams_.dst, memsetParams_.pitch,
|
||||
memsetParams_.value, memsetParams_.elementSize, memsetParams_.width,
|
||||
memsetParams_.height);
|
||||
memsetParams_.height, depth_);
|
||||
label = buffer;
|
||||
} else {
|
||||
size_t sizeBytes;
|
||||
if (memsetParams_.height == 1) {
|
||||
sizeBytes = memsetParams_.width * memsetParams_.elementSize;
|
||||
} else {
|
||||
sizeBytes = memsetParams_.width * memsetParams_.height * memsetParams_.elementSize;
|
||||
sizeBytes = memsetParams_.width * memsetParams_.height * depth_ * memsetParams_.elementSize;
|
||||
}
|
||||
label = std::to_string(GetID()) + "\n" + label_ + "\n(" +
|
||||
std::to_string(memsetParams_.value) + "," + std::to_string(sizeBytes) + ")";
|
||||
@@ -1774,7 +1776,7 @@ class GraphMemsetNode : public GraphNode {
|
||||
{memsetParams_.dst, memsetParams_.pitch, memsetParams_.width * memsetParams_.elementSize,
|
||||
memsetParams_.height},
|
||||
memsetParams_.value,
|
||||
{memsetParams_.width * memsetParams_.elementSize, memsetParams_.height, 1}, stream,
|
||||
{memsetParams_.width * memsetParams_.elementSize, memsetParams_.height, depth_}, stream,
|
||||
memsetParams_.elementSize);
|
||||
}
|
||||
return status;
|
||||
@@ -1793,12 +1795,15 @@ class GraphMemsetNode : public GraphNode {
|
||||
params->width = memsetParams_.width;
|
||||
}
|
||||
|
||||
hipError_t SetParamsInternal(const hipMemsetParams* params, bool isExec) {
|
||||
hipError_t SetParamsInternal(const hipMemsetParams* params, bool isExec, size_t depth = 1) {
|
||||
hipError_t hip_error = hipSuccess;
|
||||
hip_error = ihipGraphMemsetParams_validate(params);
|
||||
if (hip_error != hipSuccess) {
|
||||
return hip_error;
|
||||
}
|
||||
if (depth == 0) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
if (isExec) {
|
||||
size_t discardOffset = 0;
|
||||
amd::Memory *memObj = getMemoryObject(params->dst, discardOffset);
|
||||
@@ -1829,7 +1834,7 @@ class GraphMemsetNode : public GraphNode {
|
||||
// 2D - hipGraphExecMemsetNodeSetParams returns invalid value if new width or new height is
|
||||
// not same as what memset node is added with.
|
||||
if (memsetParams_.width * memsetParams_.elementSize != params->width * params->elementSize
|
||||
|| memsetParams_.height != params->height) {
|
||||
|| memsetParams_.height != params->height || depth != depth_) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
} else {
|
||||
@@ -1839,26 +1844,30 @@ class GraphMemsetNode : public GraphNode {
|
||||
amd::Memory *memObj = getMemoryObject(params->dst, discardOffset);
|
||||
if (memObj != nullptr) {
|
||||
if (params->width * params->elementSize > memObj->getUserData().width_
|
||||
|| params->height > memObj->getUserData().height_) {
|
||||
|| params->height > memObj->getUserData().height_
|
||||
|| depth > memObj->getUserData().depth_) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
sizeBytes = params->width * params->elementSize * params->height * 1;
|
||||
sizeBytes = params->width * params->elementSize * params->height * depth;
|
||||
hip_error = ihipMemset3D_validate(
|
||||
{params->dst, params->pitch, params->width * params->elementSize, params->height},
|
||||
params->value, {params->width * params->elementSize, params->height, 1}, sizeBytes);
|
||||
params->value, {params->width * params->elementSize, params->height, depth}, sizeBytes);
|
||||
}
|
||||
if (hip_error != hipSuccess) {
|
||||
return hip_error;
|
||||
}
|
||||
std::memcpy(&memsetParams_, params, sizeof(hipMemsetParams));
|
||||
depth_ = depth;
|
||||
return hipSuccess;
|
||||
}
|
||||
hipError_t SetParams(const hipMemsetParams* params, bool isExec = false) {
|
||||
return SetParamsInternal(params, isExec);
|
||||
|
||||
hipError_t SetParams(const hipMemsetParams* params, bool isExec = false, size_t depth = 1) {
|
||||
return SetParamsInternal(params, isExec, depth);
|
||||
}
|
||||
hipError_t SetParams(const HIP_MEMSET_NODE_PARAMS* params, bool isExec = false) {
|
||||
|
||||
hipError_t SetParams(const HIP_MEMSET_NODE_PARAMS* params, bool isExec = false, size_t depth = 1) {
|
||||
hipMemsetParams pmemsetParams;
|
||||
pmemsetParams.dst = params->dst;
|
||||
pmemsetParams.elementSize = params->elementSize;
|
||||
@@ -1866,11 +1875,11 @@ class GraphMemsetNode : public GraphNode {
|
||||
pmemsetParams.pitch = params->pitch;
|
||||
pmemsetParams.value = params->value;
|
||||
pmemsetParams.width = params->width;
|
||||
return SetParamsInternal(&pmemsetParams, isExec);
|
||||
return SetParamsInternal(&pmemsetParams, isExec, depth);
|
||||
}
|
||||
hipError_t SetParams(GraphNode* node) override {
|
||||
const GraphMemsetNode* memsetNode = static_cast<GraphMemsetNode const*>(node);
|
||||
return SetParams(&memsetNode->memsetParams_);
|
||||
return SetParams(&memsetNode->memsetParams_, false, memsetNode->depth_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
مرجع در شماره جدید
Block a user