diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index b0d28b7769..32b21582f6 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -143,6 +143,9 @@ hipError_t capturehipLaunchKernel(hipStream_t& stream, const void*& hostFunction dim3& blockDim, void**& args, size_t& sharedMemBytes) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node kernel launch on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hip::Stream* s = reinterpret_cast(stream); hipKernelNodeParams nodeParams; nodeParams.func = const_cast(hostFunction); @@ -166,6 +169,9 @@ hipError_t capturehipLaunchKernel(hipStream_t& stream, const void*& hostFunction hipError_t capturehipMemcpy3DAsync(hipStream_t& stream, const hipMemcpy3DParms*& p) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy3D on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hip::Stream* s = reinterpret_cast(stream); hipGraphNode_t pGraphNode; hipError_t status = @@ -183,7 +189,7 @@ hipError_t capturehipMemcpy2DAsync(hipStream_t& stream, void*& dst, size_t& dpit hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy2D on stream : %p", stream); - if (dst == nullptr || src == nullptr) { + if (dst == nullptr || src == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -218,7 +224,7 @@ hipError_t capturehipMemcpy2DFromArrayAsync(hipStream_t& stream, void*& dst, siz hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy2DFromArray on stream : %p", stream); - if (src == nullptr || dst == nullptr) { + if (src == nullptr || dst == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -250,7 +256,7 @@ hipError_t capturehipMemcpyFromArrayAsync(hipStream_t& stream, void*& dst, hipAr hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy2DFromArray on stream : %p", stream); - if (src == nullptr || dst == nullptr) { + if (src == nullptr || dst == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -285,7 +291,7 @@ hipError_t capturehipMemcpy2DToArrayAsync(hipStream_t& stream, hipArray*& dst, s size_t& width, size_t& height, hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy2DFromArray on stream : %p", stream); - if (src == nullptr || dst == nullptr) { + if (src == nullptr || dst == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -317,7 +323,7 @@ hipError_t capturehipMemcpyToArrayAsync(hipStream_t& stream, hipArray_t& dst, si hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy2DFromArray on stream : %p", stream); - if (src == nullptr || dst == nullptr) { + if (src == nullptr || dst == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -350,6 +356,9 @@ hipError_t capturehipMemcpyToArrayAsync(hipStream_t& stream, hipArray_t& dst, si hipError_t capturehipMemcpyParam2DAsync(hipStream_t& stream, const hip_Memcpy2D*& pCopy) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node MemcpyParam2D on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hip::Stream* s = reinterpret_cast(stream); hipGraphNode_t pGraphNode; hipMemcpy3DParms p = {}; @@ -396,7 +405,7 @@ hipError_t capturehipMemcpyAtoHAsync(hipStream_t& stream, void*& dstHost, hipArr size_t& srcOffset, size_t& ByteCount) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node MemcpyParam2D on stream : %p", stream); - if (srcArray == nullptr || dstHost == nullptr) { + if (srcArray == nullptr || dstHost == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -421,7 +430,7 @@ hipError_t capturehipMemcpyHtoAAsync(hipStream_t& stream, hipArray*& dstArray, s const void*& srcHost, size_t& ByteCount) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node MemcpyParam2D on stream : %p", stream); - if (dstArray == nullptr || srcHost == nullptr) { + if (dstArray == nullptr || srcHost == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hip::Stream* s = reinterpret_cast(stream); @@ -444,6 +453,9 @@ hipError_t capturehipMemcpyHtoAAsync(hipStream_t& stream, hipArray*& dstArray, s hipError_t capturehipMemcpy(hipStream_t stream, void* dst, const void* src, size_t sizeBytes, hipMemcpyKind kind) { + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hip::Stream* s = reinterpret_cast(stream); hipGraph_t graph = nullptr; std::vector pDependencies = s->GetLastCapturedNodes(); @@ -463,6 +475,9 @@ hipError_t capturehipMemcpyAsync(hipStream_t& stream, void*& dst, const void*& s size_t& sizeBytes, hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memcpy1D on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } return capturehipMemcpy(stream, dst, src, sizeBytes, kind); } @@ -470,6 +485,9 @@ hipError_t capturehipMemcpyHtoDAsync(hipStream_t& stream, hipDeviceptr_t& dstDev size_t& ByteCount, hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node MemcpyHtoD on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } return capturehipMemcpy(stream, dstDevice, srcHost, ByteCount, kind); } @@ -478,6 +496,9 @@ hipError_t capturehipMemcpyDtoDAsync(hipStream_t& stream, hipDeviceptr_t& dstDev hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node hipMemcpyDtoD on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } return capturehipMemcpy(stream, dstDevice, srcDevice, ByteCount, kind); } @@ -485,6 +506,9 @@ hipError_t capturehipMemcpyDtoHAsync(hipStream_t& stream, void*& dstHost, hipDev size_t& ByteCount, hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node hipMemcpyDtoH on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } return capturehipMemcpy(stream, dstHost, srcDevice, ByteCount, kind); } @@ -492,6 +516,9 @@ hipError_t capturehipMemcpyFromSymbolAsync(hipStream_t& stream, void*& dst, cons size_t& sizeBytes, size_t& offset, hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node MemcpyFromSymbolNode on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } size_t sym_size = 0; hipDeviceptr_t device_ptr = nullptr; @@ -512,6 +539,9 @@ hipError_t capturehipMemcpyToSymbolAsync(hipStream_t& stream, const void*& symbo size_t& sizeBytes, size_t& offset, hipMemcpyKind& kind) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node MemcpyToSymbolNode on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } size_t sym_size = 0; hipDeviceptr_t device_ptr = nullptr; hipError_t status = ihipMemcpySymbol_validate(symbol, sizeBytes, offset, sym_size, device_ptr); @@ -530,7 +560,9 @@ hipError_t capturehipMemsetAsync(hipStream_t& stream, void*& dst, int& value, si size_t& sizeBytes) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memset1D on stream : %p", stream); - + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hipMemsetParams memsetParams = {0}; memsetParams.dst = dst; memsetParams.value = value; @@ -555,7 +587,9 @@ hipError_t capturehipMemset2DAsync(hipStream_t& stream, void*& dst, size_t& pitc ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memset2D on stream : %p", stream); hipMemsetParams memsetParams = {0}; - + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } memsetParams.dst = dst; memsetParams.value = value; memsetParams.width = width; @@ -577,6 +611,9 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe hipExtent& extent) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memset3D on stream : %p", stream); + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } return hipSuccess; } @@ -586,6 +623,9 @@ hipError_t capturehipEventRecord(hipStream_t& stream, hipEvent_t& event) { if (event == nullptr) { return hipErrorInvalidHandle; } + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hip::Event* e = reinterpret_cast(event); e->StartCapture(stream); hip::Stream* s = reinterpret_cast(stream); @@ -600,7 +640,9 @@ hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, uns ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node StreamWaitEvent on stream : %p, Event %p", stream, event); - + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } hip::Stream* s = reinterpret_cast(stream); hip::Event* e = reinterpret_cast(event); @@ -620,7 +662,7 @@ hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, uns hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*& userData) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node Memset2D on stream : %p", stream); - if (fn == nullptr || userData == nullptr) { + if (fn == nullptr || userData == nullptr || !hip::isValid(stream)) { return hipErrorInvalidValue; } hipHostNodeParams hostParams = {0}; @@ -636,7 +678,7 @@ hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*& hipError_t hipStreamIsCapturing(hipStream_t stream, hipStreamCaptureStatus* pCaptureStatus) { HIP_INIT_API(hipStreamIsCapturing, stream, pCaptureStatus); - if (stream == nullptr) { + if (stream == nullptr || !hip::isValid(stream)) { HIP_RETURN(hipErrorInvalidValue); } *pCaptureStatus = reinterpret_cast(stream)->GetCaptureStatus(); @@ -645,6 +687,9 @@ hipError_t hipStreamIsCapturing(hipStream_t stream, hipStreamCaptureStatus* pCap hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode) { HIP_INIT_API(hipStreamBeginCapture, stream, mode); + if (!hip::isValid(stream)) { + HIP_RETURN(hipErrorInvalidValue); + } hip::Stream* s = reinterpret_cast(stream); // capture cannot be initiated on legacy stream // It can be initiated if the stream is not already in capture mode @@ -660,6 +705,9 @@ hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode) hipError_t hipStreamEndCapture(hipStream_t stream, hipGraph_t* pGraph) { HIP_INIT_API(hipStreamEndCapture, stream, pGraph); + if (!hip::isValid(stream)) { + HIP_RETURN(hipErrorInvalidValue); + } hip::Stream* s = reinterpret_cast(stream); // Capture must be ended on the same stream in which it was initiated if (!s->IsOriginStream()) { @@ -832,16 +880,19 @@ hipError_t hipGraphExecDestroy(hipGraphExec_t pGraphExec) { HIP_RETURN(hipSuccess); } -hipError_t ihipGraphlaunch(hipGraphExec_t graphExec, hipStream_t stream) { +hipError_t ihipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream) { + if (!hip::isValid(stream)) { + return hipErrorInvalidValue; + } return graphExec->Run(stream); } hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream) { HIP_INIT_API(hipGraphLaunch, graphExec, stream); - if (graphExec == nullptr) { + if (graphExec == nullptr || !hip::isValid(stream)) { HIP_RETURN(hipErrorInvalidValue); } - HIP_RETURN_DURATION(ihipGraphlaunch(graphExec, stream)); + HIP_RETURN_DURATION(ihipGraphLaunch(graphExec, stream)); } hipError_t hipGraphGetNodes(hipGraph_t graph, hipGraphNode_t* nodes, size_t* numNodes) { @@ -1003,7 +1054,7 @@ hipError_t hipGraphExecChildGraphNodeSetParams(hipGraphExec_t hGraphExec, hipGra hipError_t hipStreamGetCaptureInfo(hipStream_t stream, hipStreamCaptureStatus* pCaptureStatus, unsigned long long* pId) { HIP_INIT_API(hipStreamGetCaptureInfo, stream, pCaptureStatus, pId); - if (pCaptureStatus == nullptr || pId == nullptr) { + if (pCaptureStatus == nullptr || pId == nullptr || !hip::isValid(stream)) { HIP_RETURN(hipErrorInvalidValue); } if (stream == nullptr) { @@ -1024,6 +1075,9 @@ hipError_t hipStreamGetCaptureInfo_v2(hipStream_t stream, hipStreamCaptureStatus if (stream == nullptr) { HIP_RETURN(hipErrorStreamCaptureImplicit); } + if (!hip::isValid(stream)) { + HIP_RETURN(hipErrorInvalidValue); + } hip::Stream* s = reinterpret_cast(stream); *captureStatus_out = s->GetCaptureStatus(); if (*captureStatus_out == hipStreamCaptureStatusActive) { @@ -1044,6 +1098,9 @@ hipError_t hipStreamGetCaptureInfo_v2(hipStream_t stream, hipStreamCaptureStatus hipError_t hipStreamUpdateCaptureDependencies(hipStream_t stream, hipGraphNode_t* dependencies, size_t numDependencies, unsigned int flags) { HIP_INIT_API(hipStreamUpdateCaptureDependencies, stream, dependencies, numDependencies, flags); + if (!hip::isValid(stream)) { + HIP_RETURN(hipErrorInvalidValue); + } hip::Stream* s = reinterpret_cast(stream); if (s->GetCaptureStatus() == hipStreamCaptureStatusActive) { HIP_RETURN(hipErrorIllegalState);