SWDEV-524745 - Part-I Add multi device support for hip graph. Update nodes with DevId. (#812)
- The graph nodes have been updated to capture the device ID from the capture stream, or from the current device when a node is added explicitly.
- Update the device ID for the memcpy node, ensuring that the device where the memory is allocated is taken into account for H2D and D2H pinned operations.

Co-authored-by: Anusha GodavarthySurya <Anusha.GodavarthySurya@amd.com>
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
75602772aa
Коммит
1be5c9870a
@@ -51,7 +51,7 @@ inline hipError_t ihipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream)
|
||||
|
||||
inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
bool capture = true) {
|
||||
bool capture, int devId) {
|
||||
graph->AddNode(graphNode);
|
||||
std::unordered_set<hip::GraphNode*> DuplicateDep;
|
||||
for (size_t i = 0; i < numDependencies; i++) {
|
||||
@@ -76,6 +76,9 @@ inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
|
||||
}
|
||||
}
|
||||
}
|
||||
if (devId != 0) {
|
||||
graphNode->SetDeviceId(devId);
|
||||
}
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
@@ -125,16 +128,14 @@ hipError_t ihipGraphAddKernelNode(hip::GraphNode** pGraphNode, hip::Graph* graph
|
||||
*pGraphNode =
|
||||
new hip::GraphKernelNode(pNodeParams, pNodeEvents, coopKernel, globalWorkSizeX_remainder,
|
||||
globalWorkSizeY_remainder, globalWorkSizeZ_remainder);
|
||||
if (devId != 0) {
|
||||
(*pGraphNode)->SetDeviceId(devId);
|
||||
}
|
||||
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
|
||||
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId);
|
||||
return status;
|
||||
}
|
||||
|
||||
hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
const hipMemcpy3DParms* pCopyParams, bool capture = true) {
|
||||
const hipMemcpy3DParms* pCopyParams, bool capture = true,
|
||||
int devId = 0) {
|
||||
if (pGraphNode == nullptr || graph == nullptr ||
|
||||
(numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) {
|
||||
return hipErrorInvalidValue;
|
||||
@@ -151,7 +152,7 @@ hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph
|
||||
hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
const HIP_MEMCPY3D* pCopyParams, hipCtx_t ctx,
|
||||
bool capture = true) {
|
||||
bool capture = true, int devId = 0) {
|
||||
if (pGraphNode == nullptr || graph == nullptr ||
|
||||
(numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) {
|
||||
return hipErrorInvalidValue;
|
||||
@@ -168,7 +169,7 @@ hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* gr
|
||||
hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
void* dst, const void* src, size_t count, hipMemcpyKind kind,
|
||||
bool capture = true) {
|
||||
bool capture = true, int devId = 0) {
|
||||
if (pGraphNode == nullptr || graph == nullptr ||
|
||||
(numDependencies > 0 && pDependencies == nullptr) || count == 0) {
|
||||
return hipErrorInvalidValue;
|
||||
@@ -185,7 +186,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra
|
||||
hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
const hipMemsetParams* pMemsetParams, bool capture = true,
|
||||
size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1) {
|
||||
size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1,
|
||||
int devId = 0) {
|
||||
if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr ||
|
||||
(numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) {
|
||||
return hipErrorInvalidValue;
|
||||
@@ -222,7 +224,7 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
|
||||
return status;
|
||||
}
|
||||
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth, arrWidth, arrHeight);
|
||||
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
|
||||
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId);
|
||||
return status;
|
||||
}
|
||||
|
||||
@@ -477,7 +479,7 @@ hipError_t capturehipMemcpy3DAsync(hipStream_t& stream, const hipMemcpy3DParms*&
|
||||
hip::GraphNode* pGraphNode;
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), p);
|
||||
s->GetLastCapturedNodes().size(), p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -524,7 +526,7 @@ hipError_t capturehipMemcpy2DAsync(hipStream_t& stream, void*& dst, size_t& dpit
|
||||
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &p);
|
||||
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -566,7 +568,7 @@ hipError_t capturehipMemcpy2DFromArrayAsync(hipStream_t& stream, void*& dst, siz
|
||||
p.extent = {width / hip::getElementSize(p.srcArray), height, 1};
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &p);
|
||||
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -607,7 +609,7 @@ hipError_t capturehipMemcpy2DToArrayAsync(hipStream_t& stream, hipArray_t& dst,
|
||||
p.extent = {width / hip::getElementSize(p.dstArray), height, 1};
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &p);
|
||||
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -681,7 +683,7 @@ hipError_t capturehipMemcpyParam2DAsync(hipStream_t& stream, const hip_Memcpy2D*
|
||||
}
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &p);
|
||||
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -710,7 +712,7 @@ hipError_t capturehipMemcpyAtoHAsync(hipStream_t& stream, void*& dstHost, hipArr
|
||||
p.kind = hipMemcpyDeviceToHost;
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &p);
|
||||
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -738,7 +740,7 @@ hipError_t capturehipMemcpyHtoAAsync(hipStream_t& stream, hipArray_t& dstArray,
|
||||
p.extent = {ByteCount / hip::getElementSize(p.dstArray), 1, 1};
|
||||
hipError_t status =
|
||||
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &p);
|
||||
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -757,12 +759,9 @@ hipError_t capturehipMemcpy(hipStream_t stream, void* dst, const void* src, size
|
||||
std::vector<hip::GraphNode*> pDependencies = s->GetLastCapturedNodes();
|
||||
size_t numDependencies = s->GetLastCapturedNodes().size();
|
||||
hip::Graph* graph = s->GetCaptureGraph();
|
||||
hipError_t status = ihipMemcpy_validate(dst, src, sizeBytes, kind);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
hip::GraphNode* node = new hip::GraphMemcpyNode1D(dst, src, sizeBytes, kind);
|
||||
status = ihipGraphAddNode(node, graph, pDependencies.data(), numDependencies);
|
||||
hipError_t status = ihipGraphAddMemcpyNode1D(&node, graph, pDependencies.data(), numDependencies,
|
||||
dst, src, sizeBytes, kind, true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -842,7 +841,7 @@ hipError_t capturehipMemcpyFromSymbolAsync(hipStream_t& stream, void*& dst, cons
|
||||
hip::GraphNode* pGraphNode =
|
||||
new hip::GraphMemcpyNodeFromSymbol(dst, symbol, sizeBytes, offset, kind);
|
||||
status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size());
|
||||
s->GetLastCapturedNodes().size(), true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -877,7 +876,7 @@ hipError_t capturehipMemcpyToSymbolAsync(hipStream_t& stream, const void*& symbo
|
||||
hip::GraphNode* pGraphNode =
|
||||
new hip::GraphMemcpyNodeToSymbol(symbol, src, sizeBytes, offset, kind);
|
||||
status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size());
|
||||
s->GetLastCapturedNodes().size(), true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -901,9 +900,9 @@ hipError_t capturehipMemsetAsync(hipStream_t& stream, void*& dst, int& value, si
|
||||
|
||||
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
|
||||
hip::GraphNode* pGraphNode;
|
||||
hipError_t status =
|
||||
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &memsetParams);
|
||||
hipError_t status = ihipGraphAddMemsetNode(
|
||||
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -927,9 +926,9 @@ hipError_t capturehipMemset2DAsync(hipStream_t& stream, void*& dst, size_t& pitc
|
||||
memsetParams.elementSize = 1;
|
||||
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
|
||||
hip::GraphNode* pGraphNode;
|
||||
hipError_t status =
|
||||
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &memsetParams);
|
||||
hipError_t status = ihipGraphAddMemsetNode(
|
||||
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -962,7 +961,7 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe
|
||||
hipError_t status =
|
||||
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth,
|
||||
pitchedDevPtr.xsize, pitchedDevPtr.ysize);
|
||||
pitchedDevPtr.xsize, pitchedDevPtr.ysize, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -986,7 +985,7 @@ hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*&
|
||||
hip::GraphNode* pGraphNode = new hip::GraphHostNode(&hostParams);
|
||||
hipError_t status =
|
||||
ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size());
|
||||
s->GetLastCapturedNodes().size(), true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -1023,7 +1022,7 @@ hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool, size
|
||||
auto mem_alloc_node = new hip::GraphMemAllocNode(&node_params);
|
||||
auto status =
|
||||
ihipGraphAddNode(mem_alloc_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size());
|
||||
s->GetLastCapturedNodes().size(), true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -1042,7 +1041,7 @@ hipError_t capturehipFreeAsync(hipStream_t stream, void* dev_ptr) {
|
||||
auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr);
|
||||
auto status =
|
||||
ihipGraphAddNode(mem_free_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
|
||||
s->GetLastCapturedNodes().size());
|
||||
s->GetLastCapturedNodes().size(), true, s->DeviceId());
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
@@ -1627,6 +1626,9 @@ hipError_t hipGraphLaunch_common(hip::GraphExec* graphExec, hipStream_t stream)
|
||||
if (graphExec == nullptr || !hip::GraphExec::isGraphExecValid(graphExec)) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
if (!hip::isValid(stream)) {
|
||||
return hipErrorContextIsDestroyed;
|
||||
}
|
||||
if (graphExec->GetNodeCount() == 0) {
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ class GraphKernelNode;
|
||||
typedef GraphNode* Node;
|
||||
hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
bool capture);
|
||||
bool capture = true, int devId = 0);
|
||||
|
||||
class UserObject : public amd::ReferenceCountedObject {
|
||||
typedef void (*UserCallbackDestructor)(void* data);
|
||||
@@ -1319,6 +1319,7 @@ class GraphKernelNode : public GraphNode {
|
||||
}
|
||||
|
||||
// Copies the kernel parameters of another kernel node into this one.
//
// node is downcast with static_cast to GraphKernelNode without any runtime check, so the
// caller must pass a kernel node. Before the parameters are copied, dev_id_ is overwritten
// with the value returned by ihipGetDevice(). The result is whatever the pointer-taking
// SetParams overload returns for the other node's kernelParams_.
hipError_t SetParams(GraphNode* node) override {
  dev_id_ = ihipGetDevice();
  const GraphKernelNode* const other = static_cast<const GraphKernelNode*>(node);
  return SetParams(&other->kernelParams_);
}
|
||||
@@ -1526,6 +1527,32 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
|
||||
hipMemcpyKind kind_;
|
||||
|
||||
public:
|
||||
// When device memory is on dev1 and graph node is added from different device update the device
|
||||
// id accordingly so that node can be executed on dev1.
|
||||
void UpdateDevId() {
|
||||
size_t sOffset = 0;
|
||||
amd::Memory* srcMemory = getMemoryObject(src_, sOffset);
|
||||
size_t dOffset = 0;
|
||||
amd::Memory* dstMemory = getMemoryObject(dst_, dOffset);
|
||||
hip::MemcpyType memType = ihipGetMemcpyType(src_, dst_, kind_);
|
||||
switch (memType) {
|
||||
case hipCopyBuffer:
|
||||
// D2H/H2D source/dst is pinned memory
|
||||
// Override the device id when node is created
|
||||
if (!((srcMemory->GetDeviceById() != dstMemory->GetDeviceById()) &&
|
||||
srcMemory->getContext().devices().size() == 1 &&
|
||||
dstMemory->getContext().devices().size() == 1)) {
|
||||
if (srcMemory->getContext().devices().size() == 1) {
|
||||
dev_id_ = srcMemory->GetDeviceById()->index();
|
||||
} else {
|
||||
dev_id_ = dstMemory->GetDeviceById()->index();
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind,
|
||||
hipGraphNodeType type = hipGraphNodeTypeMemcpy)
|
||||
: GraphMemcpyNode(nullptr), dst_(dst), src_(src), count_(count), kind_(kind) {
|
||||
@@ -1535,6 +1562,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
|
||||
copyParams_.extent.height = 1;
|
||||
copyParams_.extent.depth = 1;
|
||||
copyParams_.kind = kind;
|
||||
UpdateDevId();
|
||||
}
|
||||
|
||||
~GraphMemcpyNode1D() {}
|
||||
@@ -1544,6 +1572,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
|
||||
src_ = rhs.src_;
|
||||
count_ = rhs.count_;
|
||||
kind_ = rhs.kind_;
|
||||
UpdateDevId();
|
||||
}
|
||||
|
||||
GraphNode* clone() const override { return new GraphMemcpyNode1D(*this); }
|
||||
@@ -1646,6 +1675,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
|
||||
src_ = src;
|
||||
count_ = count;
|
||||
kind_ = kind;
|
||||
UpdateDevId();
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user