[SWDEV-454661][SWDEV-454653] - GraphExecMemcpyNodeSetParam to return error on memcpy direction change

Change-Id: I2c8f5ea394caeaaa6895003e63cd62a052c491f8


[ROCm/clr commit: 880963346d]
Этот коммит содержится в:
Rahul Manocha
2024-04-09 16:06:56 -07:00
коммит произвёл Aakash Sudhanwa
родитель 99e538f29e
Коммит 85a372e4eb
2 изменённых файлов: 20 добавлений и 1 удалений
+19
Просмотреть файл
@@ -1190,6 +1190,10 @@ hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, hipGraph
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNode1D*>(clonedNode)->GetMemcpyKind();
if (oldkind != kind) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNode1D*>(clonedNode)->SetParams(dst, src,
count, kind));
}
@@ -1579,6 +1583,12 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNode*>(clonedNode)->GetMemcpyKind();
hipMemcpyKind newkind = pNodeParams->kind;
if (oldkind != newkind) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNode*>(clonedNode)->SetParams(pNodeParams));
}
@@ -2112,6 +2122,11 @@ hipError_t hipGraphExecMemcpyNodeSetParamsFromSymbol(hipGraphExec_t hGraphExec,
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNodeFromSymbol*>(clonedNode)->GetMemcpyKind();
if (oldkind != kind) {
HIP_RETURN(hipErrorInvalidValue);
}
constexpr bool kCheckDeviceIsSame = true;
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNodeFromSymbol*>(clonedNode)
->SetParams(dst, symbol, count, offset, kind, kCheckDeviceIsSame));
@@ -2182,6 +2197,10 @@ hipError_t hipGraphExecMemcpyNodeSetParamsToSymbol(hipGraphExec_t hGraphExec, hi
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hipMemcpyKind oldkind = reinterpret_cast<hip::GraphMemcpyNodeToSymbol*>(clonedNode)->GetMemcpyKind();
if (oldkind != kind) {
HIP_RETURN(hipErrorInvalidValue);
}
constexpr bool kCheckDeviceIsSame = true;
HIP_RETURN(reinterpret_cast<hip::GraphMemcpyNodeToSymbol*>(clonedNode)
->SetParams(symbol, src, count, offset, kind, kCheckDeviceIsSame));
+1 -1
Просмотреть файл
@@ -1244,7 +1244,7 @@ class GraphMemcpyNode : public GraphNode {
std::memcpy(params, &copyParams_, sizeof(hipMemcpy3DParms));
}
virtual hipMemcpyKind GetMemcpyKind() const { return hipMemcpyDefault; };
virtual hipMemcpyKind GetMemcpyKind() const { return copyParams_.kind; };
hipError_t SetParams(const hipMemcpy3DParms* params) {
hipError_t status = ValidateParams(params);