SWDEV-421020 - Adds hipGraphAddBatchMemOp, SetGetParams and execSetParams APIs

Change-Id: Ieccecfe6173cc68fd3c01f86c99f7cc09fe194a3
Cette révision appartient à :
Sourabh Betigeri
2024-09-12 15:08:43 -07:00
Parent 93f1e8ff60
révision f1c05e9026
8 fichiers modifiés avec 294 ajouts et 8 suppressions
+20 -2
Voir le fichier
@@ -61,7 +61,7 @@
// - Reset any of the *_STEP_VERSION defines to zero if the corresponding *_MAJOR_VERSION increases
#define HIP_API_TABLE_STEP_VERSION 0
#define HIP_COMPILER_API_TABLE_STEP_VERSION 0
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 7
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 8
// HIP API interface
typedef hipError_t (*t___hipPopCallConfiguration)(dim3* gridDim, dim3* blockDim, size_t* sharedMem,
@@ -723,7 +723,8 @@ typedef hipError_t (*t_hipStreamWriteValue32)(hipStream_t stream, void* ptr, uin
typedef hipError_t (*t_hipStreamWriteValue64)(hipStream_t stream, void* ptr, uint64_t value,
unsigned int flags);
typedef hipError_t (*t_hipStreamBatchMemOp)(hipStream_t stream, unsigned int count,
hipStreamBatchMemOpParams* paramArray, unsigned int flags);
hipStreamBatchMemOpParams* paramArray,
unsigned int flags);
typedef hipError_t (*t_hipTexObjectCreate)(hipTextureObject_t* pTexObject,
const HIP_RESOURCE_DESC* pResDesc,
const HIP_TEXTURE_DESC* pTexDesc,
@@ -1006,6 +1007,17 @@ typedef hipError_t (*t_hipExtHostAlloc)(void **ptr, size_t size,
typedef hipError_t (*t_hipDeviceGetTexture1DLinearMaxWidth)(size_t *maxWidthInElements,
const hipChannelFormatDesc *fmtDesc,
int device);
typedef hipError_t (*t_hipGraphAddBatchMemOpNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies,
size_t numDependencies,
const hipBatchMemOpNodeParams* nodeParams);
typedef hipError_t (*t_hipGraphBatchMemOpNodeGetParams)(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams_out);
typedef hipError_t (*t_hipGraphBatchMemOpNodeSetParams)(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams);
typedef hipError_t (*t_hipGraphExecBatchMemOpNodeSetParams)(
hipGraphExec_t hGraphExec, hipGraphNode_t hNode, const hipBatchMemOpNodeParams* nodeParams);
// HIP Compiler dispatch table
struct HipCompilerDispatchTable {
// HIP_COMPILER_API_TABLE_STEP_VERSION == 0
@@ -1524,6 +1536,12 @@ struct HipDispatchTable {
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
t_hipStreamBatchMemOp hipStreamBatchMemOp_fn;
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 8
t_hipGraphAddBatchMemOpNode hipGraphAddBatchMemOpNode_fn;
t_hipGraphBatchMemOpNodeGetParams hipGraphBatchMemOpNodeGetParams_fn;
t_hipGraphBatchMemOpNodeSetParams hipGraphBatchMemOpNodeSetParams_fn;
t_hipGraphExecBatchMemOpNodeSetParams hipGraphExecBatchMemOpNodeSetParams_fn;
// DO NOT EDIT ABOVE!
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
+120 -1
Voir le fichier
@@ -426,7 +426,11 @@ enum hip_api_id_t {
HIP_API_ID_hipSetValidDevices = 406,
HIP_API_ID_hipExtHostAlloc = 407,
HIP_API_ID_hipStreamBatchMemOp = 408,
HIP_API_ID_LAST = 408,
HIP_API_ID_hipGraphAddBatchMemOpNode = 409,
HIP_API_ID_hipGraphBatchMemOpNodeGetParams = 410,
HIP_API_ID_hipGraphBatchMemOpNodeSetParams = 411,
HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams = 412,
HIP_API_ID_LAST = 412,
HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice),
HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties),
@@ -861,6 +865,10 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipWaitExternalSemaphoresAsync: return "hipWaitExternalSemaphoresAsync";
case HIP_API_ID_hipExtGetLastError: return "hipExtGetLastError";
case HIP_API_ID_hipStreamBatchMemOp: return "hipStreamBatchMemOp";
case HIP_API_ID_hipGraphAddBatchMemOpNode: return "hipGraphAddBatchMemOpNode";
case HIP_API_ID_hipGraphBatchMemOpNodeGetParams: return "hipGraphBatchMemOpNodeGetParams";
case HIP_API_ID_hipGraphBatchMemOpNodeSetParams: return "hipGraphBatchMemOpNodeSetParams";
case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams: return "hipGraphExecBatchMemOpNodeSetParams";
};
return "unknown";
};
@@ -1265,6 +1273,10 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipUserObjectRetain", name) == 0) return HIP_API_ID_hipUserObjectRetain;
if (strcmp("hipWaitExternalSemaphoresAsync", name) == 0) return HIP_API_ID_hipWaitExternalSemaphoresAsync;
if (strcmp("hipStreamBatchMemOp", name) == 0) return HIP_API_ID_hipStreamBatchMemOp;
if (strcmp("hipGraphAddBatchMemOpNode", name) == 0) return HIP_API_ID_hipGraphAddBatchMemOpNode;
if (strcmp("hipGraphBatchMemOpNodeGetParams", name) == 0) return HIP_API_ID_hipGraphBatchMemOpNodeGetParams;
if (strcmp("hipGraphBatchMemOpNodeSetParams", name) == 0) return HIP_API_ID_hipGraphBatchMemOpNodeSetParams;
if (strcmp("hipGraphExecBatchMemOpNodeSetParams", name) == 0) return HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams;
return HIP_API_ID_NONE;
}
@@ -3633,6 +3645,32 @@ typedef struct hip_api_data_s {
hipStreamBatchMemOpParams paramArray__val;
unsigned int flags;
} hipStreamBatchMemOp;
struct {
hipGraphNode_t* phGraphNode;
hipGraphNode_t phGraphNode__val;
hipGraph_t hGraph;
const hipGraphNode_t* dependencies;
hipGraphNode_t dependencies__val;
size_t numDependencies;
const hipBatchMemOpNodeParams* nodeParams;
hipBatchMemOpNodeParams nodeParams__val;
} hipGraphAddBatchMemOpNode;
struct {
hipGraphNode_t hNode;
hipBatchMemOpNodeParams* nodeParams_out;
hipBatchMemOpNodeParams nodeParams_out__val;
} hipGraphBatchMemOpNodeGetParams;
struct {
hipGraphNode_t hNode;
hipBatchMemOpNodeParams* nodeParams;
hipBatchMemOpNodeParams nodeParams__val;
} hipGraphBatchMemOpNodeSetParams;
struct {
hipGraphExec_t hGraphExec;
hipGraphNode_t hNode;
const hipBatchMemOpNodeParams* nodeParams;
hipBatchMemOpNodeParams nodeParams__val;
} hipGraphExecBatchMemOpNodeSetParams;
} args;
uint64_t *phase_data;
} hip_api_data_t;
@@ -6045,6 +6083,36 @@ typedef struct hip_api_data_s {
cb_data.args.hipWaitExternalSemaphoresAsync.numExtSems = (unsigned int)numExtSems; \
cb_data.args.hipWaitExternalSemaphoresAsync.stream = (hipStream_t)stream; \
};
// hipGraphAddBatchMemOpNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'),
// ('hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'),
// ('hipBatchMemOpNodeParams*'), 'nodeParams')]
#define INIT_hipGraphAddBatchMemOpNode_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipGraphAddBatchMemOpNode.phGraphNode = (hipGraphNode_t*)phGraphNode; \
cb_data.args.hipGraphAddBatchMemOpNode.hGraph = (hipGraph_t)hGraph; \
cb_data.args.hipGraphAddBatchMemOpNode.dependencies= (hipGraphNode_t*)dependencies; \
cb_data.args.hipGraphAddBatchMemOpNode.numDependencies = (size_t)numDependencies; \
cb_data.args.hipGraphAddBatchMemOpNode.nodeParams = (hipBatchMemOpNodeParams*)nodeParams; \
};
// hipGraphBatchMemOpNodeGetParams[('hipGraphNode_t', hNode),
// ('hipBatchMemOpNodeParams*', 'nodeParams_out')]
#define INIT_hipGraphBatchMemOpNodeGetParams_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipGraphBatchMemOpNodeGetParams.hNode = (hipGraphNode_t)hNode; \
cb_data.args.hipGraphBatchMemOpNodeGetParams.nodeParams_out = (hipBatchMemOpNodeParams*)nodeParams_out; \
};
// hipGraphBatchMemOpNodeSetParams[('hipGraphNode_t', hNode),
// ('hipBatchMemOpNodeParams*', 'nodeParams')]
#define INIT_hipGraphBatchMemOpNodeSetParams_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipGraphBatchMemOpNodeSetParams.hNode = (hipGraphNode_t)hNode; \
cb_data.args.hipGraphBatchMemOpNodeSetParams.nodeParams = (hipBatchMemOpNodeParams*)nodeParams; \
};
// hipGraphExecBatchMemOpNodeSetParams[('hipGraphExec_t'. hGraphExec),
// ('hipGraphNode_t'. hNode), ('hipBatchMemOpNodeParams*', 'nodeParams')]
#define INIT_hipGraphExecBatchMemOpNodeSetParams_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipGraphExecBatchMemOpNodeSetParams.hGraphExec = (hipGraphExec_t)hGraphExec; \
cb_data.args.hipGraphExecBatchMemOpNodeSetParams.hNode = (hipGraphNode_t)hNode; \
cb_data.args.hipGraphExecBatchMemOpNodeSetParams.nodeParams= (hipBatchMemOpNodeParams*)nodeParams; \
};
#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data)
// Macros for non-public API primitives
@@ -7551,6 +7619,29 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
case HIP_API_ID_hipStreamBatchMemOp:
if (data->args.hipStreamBatchMemOp.paramArray) data->args.hipStreamBatchMemOp.paramArray__val = *(data->args.hipStreamBatchMemOp.paramArray);
break;
// hipGraphAddBatchMemOpNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'),
// ('hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'),
// ('hipBatchMemOpNodeParams*'), 'nodeParams')]
case HIP_API_ID_hipGraphAddBatchMemOpNode:
if (data->args.hipGraphAddBatchMemOpNode.phGraphNode) data->args.hipGraphAddBatchMemOpNode.phGraphNode__val = *(data->args.hipGraphAddBatchMemOpNode.phGraphNode);
if (data->args.hipGraphAddBatchMemOpNode.dependencies) data->args.hipGraphAddBatchMemOpNode.dependencies__val = *(data->args.hipGraphAddBatchMemOpNode.dependencies);
if (data->args.hipGraphAddBatchMemOpNode.nodeParams) data->args.hipGraphAddBatchMemOpNode.nodeParams__val = *(data->args.hipGraphAddBatchMemOpNode.nodeParams);
break;
// hipGraphBatchMemOpNodeGetParams[('hipGraphNode_t', hNode),
// ('hipBatchMemOpNodeParams*', 'nodeParams_out')]
case HIP_API_ID_hipGraphBatchMemOpNodeGetParams:
if (data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out) data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out__val = *(data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out);
break;
// hipGraphBatchMemOpNodeSetParams[('hipGraphNode_t', hNode),
// ('hipBatchMemOpNodeParams*', 'nodeParams')]
case HIP_API_ID_hipGraphBatchMemOpNodeSetParams:
if (data->args.hipGraphBatchMemOpNodeSetParams.nodeParams) data->args.hipGraphBatchMemOpNodeSetParams.nodeParams__val = *(data->args.hipGraphBatchMemOpNodeSetParams.nodeParams);
break;
// hipGraphExecBatchMemOpNodeSetParams[('hipGraphExec_t'. hGraphExec),
// ('hipGraphNode_t'. hNode), ('hipBatchMemOpNodeParams*', 'nodeParams')]
case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams:
if (data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams) data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams__val = *(data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams);
break;
// hipTexRefGetAddress[('hipDeviceptr_t*', 'dev_ptr'), ('const textureReference*', 'texRef')]
case HIP_API_ID_hipTexRefGetAddress:
if (data->args.hipTexRefGetAddress.dev_ptr) data->args.hipTexRefGetAddress.dev_ptr__val = *(data->args.hipTexRefGetAddress.dev_ptr);
@@ -10843,6 +10934,34 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipWaitExternalSemaphoresAsync.stream);
oss << ")";
break;
case HIP_API_ID_hipGraphAddBatchMemOpNode:
oss << "hipGraphAddBatchMemOpNode(";
oss << "phGraphNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.phGraphNode);
oss << ", hGraph="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.hGraph);
oss << ", dependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.dependencies);
oss << ", numDependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.numDependencies);
oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.nodeParams);
oss << ")";
break;
case HIP_API_ID_hipGraphBatchMemOpNodeGetParams:
oss << "hipGraphBatchMemOpNodeGetParams(";
oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeGetParams.hNode);
oss << ", nodeParams_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out);
oss << ")";
break;
case HIP_API_ID_hipGraphBatchMemOpNodeSetParams:
oss << "hipGraphBatchMemOpNodeSetParams(";
oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeSetParams.hNode);
oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeSetParams.nodeParams);
oss << ")";
break;
case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams:
oss << "hipGraphExecBatchMemOpNodeSetParams(";
oss << "hGraphExec="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.hGraphExec);
oss << ", hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.hNode);
oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams);
oss << ")";
break;
default: oss << "unknown";
};
return strdup(oss.str().c_str());
+4
Voir le fichier
@@ -481,3 +481,7 @@ hipDrvGraphMemcpyNodeSetParams
hipDrvGraphMemcpyNodeGetParams
hipExtHostAlloc
hipStreamBatchMemOp
hipGraphAddBatchMemOpNode
hipGraphBatchMemOpNodeGetParams
hipGraphBatchMemOpNodeSetParams
hipGraphExecBatchMemOpNodeSetParams
+21 -4
Voir le fichier
@@ -804,7 +804,15 @@ hipError_t hipExternalMemoryGetMappedMipmappedArray(
const hipExternalMemoryMipmappedArrayDesc* mipmapDesc);
hipError_t hipDrvGraphMemcpyNodeGetParams(hipGraphNode_t hNode, HIP_MEMCPY3D* nodeParams);
hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY3D* nodeParams);
hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies, size_t numDependencies,
const hipBatchMemOpNodeParams* nodeParams);
hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams_out);
hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams);
hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const hipBatchMemOpNodeParams* nodeParams);
} // namespace hip
namespace hip {
@@ -1301,6 +1309,11 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
hip::hipExternalMemoryGetMappedMipmappedArray;
ptrDispatchTable->hipDrvGraphMemcpyNodeGetParams_fn = hip::hipDrvGraphMemcpyNodeGetParams;
ptrDispatchTable->hipDrvGraphMemcpyNodeSetParams_fn = hip::hipDrvGraphMemcpyNodeSetParams;
ptrDispatchTable->hipGraphAddBatchMemOpNode_fn = hip::hipGraphAddBatchMemOpNode;
ptrDispatchTable->hipGraphBatchMemOpNodeGetParams_fn = hip::hipGraphBatchMemOpNodeGetParams;
ptrDispatchTable->hipGraphBatchMemOpNodeSetParams_fn = hip::hipGraphBatchMemOpNodeSetParams;
ptrDispatchTable->hipGraphExecBatchMemOpNodeSetParams_fn =
hip::hipGraphExecBatchMemOpNodeSetParams;
}
#if HIP_ROCPROFILER_REGISTER > 0
@@ -1892,7 +1905,11 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipExtHostAlloc_fn, 461)
HIP_ENFORCE_ABI(HipDispatchTable, hipDeviceGetTexture1DLinearMaxWidth_fn, 462)
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBatchMemOp_fn, 463);
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 8
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphAddBatchMemOpNode_fn, 464);
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphBatchMemOpNodeGetParams_fn, 465);
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphBatchMemOpNodeSetParams_fn, 466);
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecBatchMemOpNodeSetParams_fn, 467);
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
@@ -1900,9 +1917,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBatchMemOp_fn, 463);
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
//
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 464)
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 468)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 7,
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 8,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
"pointers and then update this check so it is true");
#endif
+54
Voir le fichier
@@ -3427,4 +3427,58 @@ hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY
HIP_RETURN(reinterpret_cast<hip::GraphDrvMemcpyNode*>(hNode)->SetParams(nodeParams));
}
hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies, size_t numDependencies,
const hipBatchMemOpNodeParams* nodeParams) {
HIP_INIT_API(hipGraphAddBatchMemOpNode, phGraphNode, hGraph, dependencies, numDependencies,
nodeParams);
if (phGraphNode == nullptr || hGraph == nullptr ||
(numDependencies > 0 && dependencies == nullptr) || nodeParams == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* node = new hip::hipGraphBatchMemOpNode(nodeParams);
hipError_t status =
ihipGraphAddNode(node, reinterpret_cast<hip::Graph*>(hGraph),
reinterpret_cast<hip::GraphNode* const*>(dependencies), numDependencies);
*phGraphNode = reinterpret_cast<hipGraphNode_t>(node);
HIP_RETURN(status);
}
hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams_out) {
HIP_INIT_API(hipGraphBatchMemOpNodeGetParams, hNode, nodeParams_out);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
if (!hip::GraphNode::isNodeValid(n) || nodeParams_out == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hip::hipGraphBatchMemOpNode*>(n)->GetParams(nodeParams_out);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams) {
HIP_INIT_API(hipGraphBatchMemOpNodeSetParams, hNode, nodeParams);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
if (!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hip::hipGraphBatchMemOpNode*>(n)->SetParams(nodeParams));
}
hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec,
hipGraphNode_t hNode,
const hipBatchMemOpNodeParams* nodeParams) {
HIP_INIT_API(hipGraphExecBatchMemOpNodeSetParams, hGraphExec, hNode, nodeParams);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
hip::GraphExec* graphExec = reinterpret_cast<hip::GraphExec*>(hGraphExec);
if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) ||
!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(graphExec)->GetClonedNode(n);
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hip::hipGraphBatchMemOpNode*>(clonedNode)->SetParams(nodeParams));
}
} // namespace hip
+45
Voir le fichier
@@ -2736,4 +2736,49 @@ class hipGraphExternalSemWaitNode : public GraphNode {
}
};
class hipGraphBatchMemOpNode : public GraphNode {
hipBatchMemOpNodeParams batchMemOpNodeParam_;
public:
hipGraphBatchMemOpNode(const hipBatchMemOpNodeParams* pNodeParams)
: GraphNode(hipGraphNodeTypeBatchMemOp, "solid", "rectangle", "BATCH_MEM_OP_NODE") {
batchMemOpNodeParam_ = *pNodeParams;
}
hipGraphBatchMemOpNode(const hipGraphBatchMemOpNode& rhs) : GraphNode(rhs) {
batchMemOpNodeParam_ = rhs.batchMemOpNodeParam_;
}
~hipGraphBatchMemOpNode() {}
GraphNode* clone() const {
return new hipGraphBatchMemOpNode(static_cast<hipGraphBatchMemOpNode const&>(*this));
}
hipError_t CreateCommand(hip::Stream* stream) {
hipError_t status = GraphNode::CreateCommand(stream);
if (status != hipSuccess) {
return status;
}
amd::Command::EventWaitList waitList;
amd::BatchMemoryOperationCommand* command = new amd::BatchMemoryOperationCommand(
*stream, ROCCLR_COMMAND_BATCH_STREAM, batchMemOpNodeParam_.count,
batchMemOpNodeParam_.flags, waitList, batchMemOpNodeParam_.paramArray,
sizeof(hipStreamBatchMemOpParams));
if (command == nullptr) {
return hipErrorOutOfMemory;
}
commands_.emplace_back(command);
return hipSuccess;
}
void GetParams(hipBatchMemOpNodeParams* pNodeParams) const {
std::memcpy(pNodeParams, &batchMemOpNodeParam_, sizeof(hipBatchMemOpNodeParams));
}
hipError_t SetParams(const hipBatchMemOpNodeParams* pNodeParams) {
std::memcpy(&batchMemOpNodeParam_, pNodeParams, sizeof(hipBatchMemOpNodeParams));
return hipSuccess;
}
};
} // namespace hip
+10 -1
Voir le fichier
@@ -583,7 +583,16 @@ local:
hip_6.3 {
global:
hipExtHostAlloc;
hipStreamBatchMemOp;
local:
*;
} hip_6.2;
hip_6.4 {
global:
hipGraphAddBatchMemOpNode;
hipGraphBatchMemOpNodeGetParams;
hipGraphBatchMemOpNodeSetParams;
hipGraphExecBatchMemOpNodeSetParams;
local:
*;
} hip_6.3;
+20
Voir le fichier
@@ -1823,3 +1823,23 @@ hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodePa
hipError_t hipExtHostAlloc(void** ptr, size_t size, unsigned int flags) {
return hip::GetHipDispatchTable()->hipExtHostAlloc_fn(ptr, size, flags);
}
hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
const hipGraphNode_t* dependencies, size_t numDependencies,
const hipBatchMemOpNodeParams* nodeParams) {
return hip::GetHipDispatchTable()->hipGraphAddBatchMemOpNode_fn(pGraphNode, graph, dependencies,
numDependencies, nodeParams);
}
hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams_out) {
return hip::GetHipDispatchTable()->hipGraphBatchMemOpNodeGetParams_fn(hNode, nodeParams_out);
}
hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
hipBatchMemOpNodeParams* nodeParams) {
return hip::GetHipDispatchTable()->hipGraphBatchMemOpNodeSetParams_fn(hNode, nodeParams);
}
hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec,
hipGraphNode_t hNode,
const hipBatchMemOpNodeParams* nodeParams) {
return hip::GetHipDispatchTable()->hipGraphExecBatchMemOpNodeSetParams_fn(hGraphExec, hNode,
nodeParams);
}