SWDEV-524745 - Part-I Add multi device support for hip graph. Update nodes with DevId. (#812)

- The graph nodes have been updated to capture the device ID from the capture stream or the current device when explicitly added.
- Update the device ID for the memcpy node, ensuring that the device where the memory is allocated is taken into account for H2D and D2H pinned operations.

Co-authored-by: Anusha GodavarthySurya <Anusha.GodavarthySurya@amd.com>
Этот коммит содержится в:
Godavarthy Surya, Anusha
2025-09-10 11:35:25 +05:30
коммит произвёл GitHub
родитель 75602772aa
Коммит 1be5c9870a
2 изменённых файлов: 67 добавлений и 35 удалений
+36 -34
Просмотреть файл
@@ -51,7 +51,7 @@ inline hipError_t ihipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream)
inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
bool capture = true) {
bool capture, int devId) {
graph->AddNode(graphNode);
std::unordered_set<hip::GraphNode*> DuplicateDep;
for (size_t i = 0; i < numDependencies; i++) {
@@ -76,6 +76,9 @@ inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
}
}
}
if (devId != 0) {
graphNode->SetDeviceId(devId);
}
return hipSuccess;
}
@@ -125,16 +128,14 @@ hipError_t ihipGraphAddKernelNode(hip::GraphNode** pGraphNode, hip::Graph* graph
*pGraphNode =
new hip::GraphKernelNode(pNodeParams, pNodeEvents, coopKernel, globalWorkSizeX_remainder,
globalWorkSizeY_remainder, globalWorkSizeZ_remainder);
if (devId != 0) {
(*pGraphNode)->SetDeviceId(devId);
}
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId);
return status;
}
hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const hipMemcpy3DParms* pCopyParams, bool capture = true) {
const hipMemcpy3DParms* pCopyParams, bool capture = true,
int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) {
return hipErrorInvalidValue;
@@ -151,7 +152,7 @@ hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph
hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const HIP_MEMCPY3D* pCopyParams, hipCtx_t ctx,
bool capture = true) {
bool capture = true, int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) {
return hipErrorInvalidValue;
@@ -168,7 +169,7 @@ hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* gr
hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
void* dst, const void* src, size_t count, hipMemcpyKind kind,
bool capture = true) {
bool capture = true, int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || count == 0) {
return hipErrorInvalidValue;
@@ -185,7 +186,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra
hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const hipMemsetParams* pMemsetParams, bool capture = true,
size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1) {
size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1,
int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) {
return hipErrorInvalidValue;
@@ -222,7 +224,7 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
return status;
}
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth, arrWidth, arrHeight);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId);
return status;
}
@@ -477,7 +479,7 @@ hipError_t capturehipMemcpy3DAsync(hipStream_t& stream, const hipMemcpy3DParms*&
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), p);
s->GetLastCapturedNodes().size(), p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -524,7 +526,7 @@ hipError_t capturehipMemcpy2DAsync(hipStream_t& stream, void*& dst, size_t& dpit
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -566,7 +568,7 @@ hipError_t capturehipMemcpy2DFromArrayAsync(hipStream_t& stream, void*& dst, siz
p.extent = {width / hip::getElementSize(p.srcArray), height, 1};
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -607,7 +609,7 @@ hipError_t capturehipMemcpy2DToArrayAsync(hipStream_t& stream, hipArray_t& dst,
p.extent = {width / hip::getElementSize(p.dstArray), height, 1};
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -681,7 +683,7 @@ hipError_t capturehipMemcpyParam2DAsync(hipStream_t& stream, const hip_Memcpy2D*
}
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -710,7 +712,7 @@ hipError_t capturehipMemcpyAtoHAsync(hipStream_t& stream, void*& dstHost, hipArr
p.kind = hipMemcpyDeviceToHost;
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -738,7 +740,7 @@ hipError_t capturehipMemcpyHtoAAsync(hipStream_t& stream, hipArray_t& dstArray,
p.extent = {ByteCount / hip::getElementSize(p.dstArray), 1, 1};
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -757,12 +759,9 @@ hipError_t capturehipMemcpy(hipStream_t stream, void* dst, const void* src, size
std::vector<hip::GraphNode*> pDependencies = s->GetLastCapturedNodes();
size_t numDependencies = s->GetLastCapturedNodes().size();
hip::Graph* graph = s->GetCaptureGraph();
hipError_t status = ihipMemcpy_validate(dst, src, sizeBytes, kind);
if (status != hipSuccess) {
return status;
}
hip::GraphNode* node = new hip::GraphMemcpyNode1D(dst, src, sizeBytes, kind);
status = ihipGraphAddNode(node, graph, pDependencies.data(), numDependencies);
hipError_t status = ihipGraphAddMemcpyNode1D(&node, graph, pDependencies.data(), numDependencies,
dst, src, sizeBytes, kind, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -842,7 +841,7 @@ hipError_t capturehipMemcpyFromSymbolAsync(hipStream_t& stream, void*& dst, cons
hip::GraphNode* pGraphNode =
new hip::GraphMemcpyNodeFromSymbol(dst, symbol, sizeBytes, offset, kind);
status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -877,7 +876,7 @@ hipError_t capturehipMemcpyToSymbolAsync(hipStream_t& stream, const void*& symbo
hip::GraphNode* pGraphNode =
new hip::GraphMemcpyNodeToSymbol(symbol, src, sizeBytes, offset, kind);
status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -901,9 +900,9 @@ hipError_t capturehipMemsetAsync(hipStream_t& stream, void*& dst, int& value, si
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams);
hipError_t status = ihipGraphAddMemsetNode(
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -927,9 +926,9 @@ hipError_t capturehipMemset2DAsync(hipStream_t& stream, void*& dst, size_t& pitc
memsetParams.elementSize = 1;
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams);
hipError_t status = ihipGraphAddMemsetNode(
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -962,7 +961,7 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth,
pitchedDevPtr.xsize, pitchedDevPtr.ysize);
pitchedDevPtr.xsize, pitchedDevPtr.ysize, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -986,7 +985,7 @@ hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*&
hip::GraphNode* pGraphNode = new hip::GraphHostNode(&hostParams);
hipError_t status =
ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -1023,7 +1022,7 @@ hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool, size
auto mem_alloc_node = new hip::GraphMemAllocNode(&node_params);
auto status =
ihipGraphAddNode(mem_alloc_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -1042,7 +1041,7 @@ hipError_t capturehipFreeAsync(hipStream_t stream, void* dev_ptr) {
auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr);
auto status =
ihipGraphAddNode(mem_free_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
@@ -1627,6 +1626,9 @@ hipError_t hipGraphLaunch_common(hip::GraphExec* graphExec, hipStream_t stream)
if (graphExec == nullptr || !hip::GraphExec::isGraphExecValid(graphExec)) {
return hipErrorInvalidValue;
}
if (!hip::isValid(stream)) {
return hipErrorContextIsDestroyed;
}
if (graphExec->GetNodeCount() == 0) {
return hipSuccess;
}
+31 -1
Просмотреть файл
@@ -49,7 +49,7 @@ class GraphKernelNode;
typedef GraphNode* Node;
hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
bool capture);
bool capture = true, int devId = 0);
class UserObject : public amd::ReferenceCountedObject {
typedef void (*UserCallbackDestructor)(void* data);
@@ -1319,6 +1319,7 @@ class GraphKernelNode : public GraphNode {
}
hipError_t SetParams(GraphNode* node) override {
dev_id_ = ihipGetDevice();
const GraphKernelNode* kernelNode = static_cast<GraphKernelNode const*>(node);
return SetParams(&kernelNode->kernelParams_);
}
@@ -1526,6 +1527,32 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
hipMemcpyKind kind_;
public:
// When device memory is on dev1 and graph node is added from different device update the device
// id accordingly so that node can be executed on dev1.
void UpdateDevId() {
size_t sOffset = 0;
amd::Memory* srcMemory = getMemoryObject(src_, sOffset);
size_t dOffset = 0;
amd::Memory* dstMemory = getMemoryObject(dst_, dOffset);
hip::MemcpyType memType = ihipGetMemcpyType(src_, dst_, kind_);
switch (memType) {
case hipCopyBuffer:
// D2H/H2D source/dst is pinned memory
// Override the device id when node is created
if (!((srcMemory->GetDeviceById() != dstMemory->GetDeviceById()) &&
srcMemory->getContext().devices().size() == 1 &&
dstMemory->getContext().devices().size() == 1)) {
if (srcMemory->getContext().devices().size() == 1) {
dev_id_ = srcMemory->GetDeviceById()->index();
} else {
dev_id_ = dstMemory->GetDeviceById()->index();
}
}
break;
default:
break;
}
}
GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind,
hipGraphNodeType type = hipGraphNodeTypeMemcpy)
: GraphMemcpyNode(nullptr), dst_(dst), src_(src), count_(count), kind_(kind) {
@@ -1535,6 +1562,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
copyParams_.extent.height = 1;
copyParams_.extent.depth = 1;
copyParams_.kind = kind;
UpdateDevId();
}
~GraphMemcpyNode1D() {}
@@ -1544,6 +1572,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
src_ = rhs.src_;
count_ = rhs.count_;
kind_ = rhs.kind_;
UpdateDevId();
}
GraphNode* clone() const override { return new GraphMemcpyNode1D(*this); }
@@ -1646,6 +1675,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
src_ = src;
count_ = count;
kind_ = kind;
UpdateDevId();
return hipSuccess;
}