diff --git a/catch/unit/graph/hipStreamBeginCapture.cc b/catch/unit/graph/hipStreamBeginCapture.cc index ad15f2f451..bb781f7739 100644 --- a/catch/unit/graph/hipStreamBeginCapture.cc +++ b/catch/unit/graph/hipStreamBeginCapture.cc @@ -55,6 +55,33 @@ static void hostNodeCallback(void* data) { gCbackIter++; } +// Functors for testing concurrent stream capture codes +class streamSync { + public: + void operator()(hipStream_t stream) { result_status = hipStreamSynchronize(stream); } + hipError_t result_status = hipSuccess; +}; +class streamQuery { + public: + void operator()(hipStream_t stream) { result_status = hipStreamQuery(stream); } + hipError_t result_status = hipSuccess; +}; +class deviceSync { + public: + void operator()() { result_status = hipDeviceSynchronize(); } + hipError_t result_status = hipSuccess; +}; +class eventSync { + public: + void operator()(hipEvent_t event) { auto status = hipEventSynchronize(event); } + hipError_t result_status = hipSuccess; +}; +class eventQuery { + public: + void operator()(hipEvent_t event) { auto status = hipEventQuery(event); } + hipError_t result_status = hipSuccess; +}; + template void captureStreamAndLaunchGraph(F graphFunc, hipStreamCaptureMode mode, hipStream_t stream) { constexpr size_t N = 1000000; @@ -799,15 +826,16 @@ TEST_CASE("Unit_hipStreamBeginCapture_Positive_CapturingMultGraphsFrom1Strm") { * - HIP_VERSION >= 5.2 */ TEST_CASE("Unit_hipStreamBeginCapture_Negative_CheckingSyncDuringCapture") { - StreamGuard stream_guard(Streams::created); + const hipStreamCaptureMode captureMode = GENERATE( + hipStreamCaptureModeGlobal, hipStreamCaptureModeThreadLocal, hipStreamCaptureModeRelaxed); + const unsigned int stream_flag = GENERATE(hipStreamDefault, hipStreamNonBlocking); + + StreamGuard stream_guard(Streams::created, stream_flag); hipStream_t stream = stream_guard.stream(); EventsGuard events_guard(1); hipEvent_t e = events_guard[0]; - const hipStreamCaptureMode captureMode = GENERATE( - hipStreamCaptureModeGlobal, hipStreamCaptureModeThreadLocal, hipStreamCaptureModeRelaxed); - HIP_CHECK(hipStreamBeginCapture(stream, captureMode)); SECTION("Synchronize stream during capture") { HIP_CHECK_ERROR(hipStreamSynchronize(stream), hipErrorStreamCaptureUnsupported); @@ -828,6 +856,80 @@ TEST_CASE("Unit_hipStreamBeginCapture_Negative_CheckingSyncDuringCapture") { } } +/** + * Test Description + * ------------------------ + * - Test to verify synchronization during stream capture returns an error: + * -# Synchronize stream during capture + * -# Synchronize device during capture + * -# Synchronize event during capture + * -# Query stream during capture + * -# Query for an event during capture + * Test source + * ------------------------ + * - catch\unit\graph\hipStreamBeginCapture.cc + * Test requirements + * ------------------------ + * - HIP_VERSION >= 5.2 + */ +TEST_CASE("Unit_hipStreamBeginCapture_Negative_Concurrent_CheckingSyncDuringCapture") { + const hipStreamCaptureMode captureMode = GENERATE( + hipStreamCaptureModeGlobal, hipStreamCaptureModeThreadLocal, hipStreamCaptureModeRelaxed); + const unsigned int stream_flag = GENERATE(hipStreamDefault, hipStreamNonBlocking); + + StreamGuard stream_guard(Streams::created, stream_flag); + StreamGuard concurrent_stream_guard(Streams::created); + hipStream_t stream = stream_guard.stream(); + hipStream_t concurrent_stream = concurrent_stream_guard.stream(); + + EventsGuard events_guard(1); + hipEvent_t e = events_guard[0]; + + HIP_CHECK(hipStreamBeginCapture(stream, captureMode)); + SECTION("Synchronize stream during capture") { + streamSync func; + hipError_t expected = hipSuccess; + if (captureMode == hipStreamCaptureModeGlobal) expected = hipErrorStreamCaptureUnsupported; + + std::thread t(std::ref(func), concurrent_stream); + t.join(); + REQUIRE(func.result_status == expected); + } + SECTION("Query stream during capture") { + streamQuery func; + hipError_t expected = hipSuccess; + if (captureMode == hipStreamCaptureModeGlobal) expected = hipErrorStreamCaptureUnsupported; + + std::thread t(std::ref(func), concurrent_stream); + t.join(); + REQUIRE(func.result_status == expected); + } + SECTION("Synchronize device during capture") { + deviceSync func; + hipError_t expected = hipErrorStreamCaptureUnsupported; + + std::thread t(std::ref(func)); + t.join(); + REQUIRE(func.result_status == expected); + } + SECTION("Synchronize event during capture") { + eventSync func; + hipError_t expected = hipSuccess; + + std::thread t(std::ref(func), e); + t.join(); + REQUIRE(func.result_status == expected); + } + SECTION("Query for an event during capture") { + eventQuery func; + hipError_t expected = hipSuccess; + + std::thread t(std::ref(func), e); + t.join(); + REQUIRE(func.result_status == expected); + } +} + /** * Test Description * ------------------------