diff --git a/hipnv/include/hip/nvidia_detail/nvidia_hip_runtime_api.h b/hipnv/include/hip/nvidia_detail/nvidia_hip_runtime_api.h index 65d8dbf900..06ba5cf7f9 100644 --- a/hipnv/include/hip/nvidia_detail/nvidia_hip_runtime_api.h +++ b/hipnv/include/hip/nvidia_detail/nvidia_hip_runtime_api.h @@ -4724,12 +4724,19 @@ inline static void hipMemsetParamsToCUDAMemsetNodeParams(CUDA_MEMSET_NODE_PARAMS } inline static hipError_t hipDrvGraphAddMemsetNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, - const hipGraphNode_t* dependencies, size_t numDependencies, - const hipMemsetParams* memsetParams, hipCtx_t ctx) { - CUDA_MEMSET_NODE_PARAMS cuMemsetParams; - hipMemsetParamsToCUDAMemsetNodeParams(&cuMemsetParams, memsetParams); - return hipCUResultTohipError(cuGraphAddMemsetNode(phGraphNode, hGraph, dependencies, numDependencies, - &cuMemsetParams, ctx)); + const hipGraphNode_t* dependencies, + size_t numDependencies, + const hipMemsetParams* memsetParams, hipCtx_t ctx) +{ + if (memsetParams == nullptr) { + return hipCUResultTohipError( + cuGraphAddMemsetNode(phGraphNode, hGraph, dependencies, numDependencies, nullptr, ctx)); + } else { + CUDA_MEMSET_NODE_PARAMS cuMemsetParams; + hipMemsetParamsToCUDAMemsetNodeParams(&cuMemsetParams, memsetParams); + return hipCUResultTohipError(cuGraphAddMemsetNode(phGraphNode, hGraph, dependencies, + numDependencies, &cuMemsetParams, ctx)); + } } inline static hipError_t hipDrvGraphAddMemcpyNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, @@ -4787,14 +4794,21 @@ inline static hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGrap } } -inline static hipError_t hipDrvGraphExecMemsetNodeSetParams( - hipGraphExec_t hGraphExec, hipGraphNode_t hNode, const hipMemsetParams* memsetParams, - hipCtx_t ctx) { - CUDA_MEMSET_NODE_PARAMS cuMemsetParams; - hipMemsetParamsToCUDAMemsetNodeParams(&cuMemsetParams, memsetParams); - return hipCUResultTohipError( - cuGraphExecMemsetNodeSetParams(hGraphExec, hNode, &cuMemsetParams, ctx)); - } +inline static hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + const hipMemsetParams* memsetParams, + hipCtx_t ctx) +{ + if (memsetParams == nullptr) { + return hipCUResultTohipError( + cuGraphExecMemsetNodeSetParams(hGraphExec, hNode, nullptr, ctx)); + } else { + CUDA_MEMSET_NODE_PARAMS cuMemsetParams; + hipMemsetParamsToCUDAMemsetNodeParams(&cuMemsetParams, memsetParams); + return hipCUResultTohipError( + cuGraphExecMemsetNodeSetParams(hGraphExec, hNode, &cuMemsetParams, ctx)); + } +} #endif #if CUDA_VERSION >= CUDA_11040