diff --git a/hipamd/src/hip_mempool.cpp b/hipamd/src/hip_mempool.cpp index b3a3cfec06..bb63edc9f7 100644 --- a/hipamd/src/hip_mempool.cpp +++ b/hipamd/src/hip_mempool.cpp @@ -92,11 +92,19 @@ hipError_t hipMallocAsync(void** dev_ptr, size_t size, hipStream_t stream) { *dev_ptr = nullptr; HIP_RETURN(hipSuccess); } + hip::Stream* s = reinterpret_cast(stream); auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ? - hip::getCurrentDevice()->NullStream() : reinterpret_cast(stream); + hip::getCurrentDevice()->NullStream() : s; auto device = hip_stream->GetDevice(); auto mem_pool = device->GetCurrentMemoryPool(); + // Return error if any stream other than the current stream is in capture mode + if (device->StreamCaptureBlocking()) { + if (s->GetCaptureStatus() != hipStreamCaptureStatusActive) { + return hipErrorStreamCaptureUnsupported; + } + } + STREAM_CAPTURE(hipMallocAsync, stream, reinterpret_cast(mem_pool), size, dev_ptr); *dev_ptr = mem_pool->AllocateMemory(size, hip_stream); @@ -138,17 +146,28 @@ class FreeAsyncCommand : public amd::Command { // ================================================================================================ hipError_t hipFreeAsync(void* dev_ptr, hipStream_t stream) { HIP_INIT_API(hipFreeAsync, dev_ptr, stream); - if (dev_ptr == nullptr) { - HIP_RETURN(hipErrorInvalidValue); - } + if (!hip::isValid(stream)) { HIP_RETURN(hipErrorInvalidHandle); } - STREAM_CAPTURE(hipFreeAsync, stream, dev_ptr); - + hip::Stream* s = reinterpret_cast(stream); auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ? - hip::getCurrentDevice()->NullStream(): reinterpret_cast(stream); + hip::getCurrentDevice()->NullStream(): s; + + auto device = hip_stream->GetDevice(); + // Return error if any stream other than the current stream is in capture mode + if (device->StreamCaptureBlocking()) { + if (s->GetCaptureStatus() != hipStreamCaptureStatusActive) { + return hipErrorStreamCaptureUnsupported; + } + } + + if (dev_ptr == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + + STREAM_CAPTURE(hipFreeAsync, stream, dev_ptr); hip::Event* event = nullptr; bool graph_in_use = false;