From 70b20857e90ffffd8455775d505aa161acdcf2eb Mon Sep 17 00:00:00 2001 From: Satyanvesh Dittakavi Date: Tue, 19 Nov 2024 14:34:08 +0000 Subject: [PATCH] SWDEV-494808 - Do not allow hipMallocAsync/hipFreeAsync when another stream is capturing hipMallocAsync/hipFreeAsync APIs should return error stating operation is not supported, if a stream is actively capturing and is different from the passed stream Change-Id: I2a1b8260c5eb22d99a936ac529d6788a83f81a17 --- hipamd/src/hip_mempool.cpp | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/hipamd/src/hip_mempool.cpp b/hipamd/src/hip_mempool.cpp index b3a3cfec06..bb63edc9f7 100644 --- a/hipamd/src/hip_mempool.cpp +++ b/hipamd/src/hip_mempool.cpp @@ -92,11 +92,19 @@ hipError_t hipMallocAsync(void** dev_ptr, size_t size, hipStream_t stream) { *dev_ptr = nullptr; HIP_RETURN(hipSuccess); } + hip::Stream* s = reinterpret_cast(stream); auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ? - hip::getCurrentDevice()->NullStream() : reinterpret_cast(stream); + hip::getCurrentDevice()->NullStream() : s; auto device = hip_stream->GetDevice(); auto mem_pool = device->GetCurrentMemoryPool(); + // Return error if any stream other than the current stream is in capture mode + if (device->StreamCaptureBlocking()) { + if (s->GetCaptureStatus() != hipStreamCaptureStatusActive) { + return hipErrorStreamCaptureUnsupported; + } + } + STREAM_CAPTURE(hipMallocAsync, stream, reinterpret_cast(mem_pool), size, dev_ptr); *dev_ptr = mem_pool->AllocateMemory(size, hip_stream); @@ -138,17 +146,28 @@ class FreeAsyncCommand : public amd::Command { // ================================================================================================ hipError_t hipFreeAsync(void* dev_ptr, hipStream_t stream) { HIP_INIT_API(hipFreeAsync, dev_ptr, stream); - if (dev_ptr == nullptr) { - HIP_RETURN(hipErrorInvalidValue); - } + if (!hip::isValid(stream)) { HIP_RETURN(hipErrorInvalidHandle); } - STREAM_CAPTURE(hipFreeAsync, stream, dev_ptr); - + hip::Stream* s = reinterpret_cast(stream); auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ? - hip::getCurrentDevice()->NullStream(): reinterpret_cast(stream); + hip::getCurrentDevice()->NullStream(): s; + + auto device = hip_stream->GetDevice(); + // Return error if any stream other than the current stream is in capture mode + if (device->StreamCaptureBlocking()) { + if (s->GetCaptureStatus() != hipStreamCaptureStatusActive) { + return hipErrorStreamCaptureUnsupported; + } + } + + if (dev_ptr == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + + STREAM_CAPTURE(hipFreeAsync, stream, dev_ptr); hip::Event* event = nullptr; bool graph_in_use = false;