[SWDEV-454661][SWDEV-454653] - GraphExecMemcpyNodeSetParam to return error on memcpy direction change
Change-Id: I2c8f5ea394caeaaa6895003e63cd62a052c491f8
[ROCm/clr commit: 880963346d]
Этот коммит содержится в:
коммит произвёл
Aakash Sudhanwa
родитель
99e538f29e
Коммит
85a372e4eb
@@ -1190,6 +1190,10 @@ hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, hipGraph
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNode1D*>(clonedNode)->GetMemcpyKind();
|
||||
if (oldkind != kind) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNode1D*>(clonedNode)->SetParams(dst, src,
|
||||
count, kind));
|
||||
}
|
||||
@@ -1579,6 +1583,12 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNode*>(clonedNode)->GetMemcpyKind();
|
||||
hipMemcpyKind newkind = pNodeParams->kind;
|
||||
if (oldkind != newkind) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNode*>(clonedNode)->SetParams(pNodeParams));
|
||||
}
|
||||
|
||||
@@ -2112,6 +2122,11 @@ hipError_t hipGraphExecMemcpyNodeSetParamsFromSymbol(hipGraphExec_t hGraphExec,
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNodeFromSymbol*>(clonedNode)->GetMemcpyKind();
|
||||
if (oldkind != kind) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
constexpr bool kCheckDeviceIsSame = true;
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNodeFromSymbol*>(clonedNode)
|
||||
->SetParams(dst, symbol, count, offset, kind, kCheckDeviceIsSame));
|
||||
@@ -2182,6 +2197,10 @@ hipError_t hipGraphExecMemcpyNodeSetParamsToSymbol(hipGraphExec_t hGraphExec, hi
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNodeToSymbol*>(clonedNode)->GetMemcpyKind();
|
||||
if (oldkind != kind) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
constexpr bool kCheckDeviceIsSame = true;
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNodeToSymbol*>(clonedNode)
|
||||
->SetParams(symbol, src, count, offset, kind, kCheckDeviceIsSame));
|
||||
|
||||
@@ -1244,7 +1244,7 @@ class GraphMemcpyNode : public GraphNode {
|
||||
std::memcpy(params, ©Params_, sizeof(hipMemcpy3DParms));
|
||||
}
|
||||
|
||||
virtual hipMemcpyKind GetMemcpyKind() const { return hipMemcpyDefault; };
|
||||
virtual hipMemcpyKind GetMemcpyKind() const { return copyParams_.kind; };
|
||||
|
||||
hipError_t SetParams(const hipMemcpy3DParms* params) {
|
||||
hipError_t status = ValidateParams(params);
|
||||
|
||||
Ссылка в новой задаче
Block a user