SWDEV-491303 - Add_concurrent_stream_error_checking

Change-Id: I70268bfa2b97994e0906ded6bae9885ec540292e
Этот коммит содержится в:
Aidan Belton-Schure
2024-12-06 11:30:38 +00:00
коммит произвёл Aidan Belton-Schure
родитель d8f5c2560f
Коммит 5007b3f5e4
+106 -4
Просмотреть файл
@@ -55,6 +55,33 @@ static void hostNodeCallback(void* data) {
gCbackIter++;
}
// Functors for testing concurrent stream capture codes
class streamSync {
public:
void operator()(hipStream_t stream) { result_status = hipStreamSynchronize(stream); }
hipError_t result_status = hipSuccess;
};
class streamQuery {
public:
void operator()(hipStream_t stream) { result_status = hipStreamQuery(stream); }
hipError_t result_status = hipSuccess;
};
class deviceSync {
public:
void operator()() { result_status = hipDeviceSynchronize(); }
hipError_t result_status = hipSuccess;
};
class eventSync {
public:
void operator()(hipEvent_t event) { auto status = hipEventSynchronize(event); }
hipError_t result_status = hipSuccess;
};
class eventQuery {
public:
void operator()(hipEvent_t event) { auto status = hipEventQuery(event); }
hipError_t result_status = hipSuccess;
};
template <typename T, typename F>
void captureStreamAndLaunchGraph(F graphFunc, hipStreamCaptureMode mode, hipStream_t stream) {
constexpr size_t N = 1000000;
@@ -799,15 +826,16 @@ TEST_CASE("Unit_hipStreamBeginCapture_Positive_CapturingMultGraphsFrom1Strm") {
* - HIP_VERSION >= 5.2
*/
TEST_CASE("Unit_hipStreamBeginCapture_Negative_CheckingSyncDuringCapture") {
StreamGuard stream_guard(Streams::created);
const hipStreamCaptureMode captureMode = GENERATE(
hipStreamCaptureModeGlobal, hipStreamCaptureModeThreadLocal, hipStreamCaptureModeRelaxed);
const unsigned int stream_flag = GENERATE(hipStreamDefault, hipStreamNonBlocking);
StreamGuard stream_guard(Streams::created, stream_flag);
hipStream_t stream = stream_guard.stream();
EventsGuard events_guard(1);
hipEvent_t e = events_guard[0];
const hipStreamCaptureMode captureMode = GENERATE(
hipStreamCaptureModeGlobal, hipStreamCaptureModeThreadLocal, hipStreamCaptureModeRelaxed);
HIP_CHECK(hipStreamBeginCapture(stream, captureMode));
SECTION("Synchronize stream during capture") {
HIP_CHECK_ERROR(hipStreamSynchronize(stream), hipErrorStreamCaptureUnsupported);
@@ -828,6 +856,80 @@ TEST_CASE("Unit_hipStreamBeginCapture_Negative_CheckingSyncDuringCapture") {
}
}
/**
* Test Description
* ------------------------
* - Test to verify synchronization during stream capture returns an error:
* -# Synchronize stream during capture
* -# Synchronize device during capture
* -# Synchronize event during capture
* -# Query stream during capture
* -# Query for an event during capture
* Test source
* ------------------------
* - catch\unit\graph\hipStreamBeginCapture.cc
* Test requirements
* ------------------------
* - HIP_VERSION >= 5.2
*/
TEST_CASE("Unit_hipStreamBeginCapture_Negative_Concurrent_CheckingSyncDuringCapture") {
const hipStreamCaptureMode captureMode = GENERATE(
hipStreamCaptureModeGlobal, hipStreamCaptureModeThreadLocal, hipStreamCaptureModeRelaxed);
const unsigned int stream_flag = GENERATE(hipStreamDefault, hipStreamNonBlocking);
StreamGuard stream_guard(Streams::created, stream_flag);
StreamGuard concurrent_stream_guard(Streams::created);
hipStream_t stream = stream_guard.stream();
hipStream_t concurrent_stream = concurrent_stream_guard.stream();
EventsGuard events_guard(1);
hipEvent_t e = events_guard[0];
HIP_CHECK(hipStreamBeginCapture(stream, captureMode));
SECTION("Synchronize stream during capture") {
streamSync func;
hipError_t expected = hipSuccess;
if (captureMode == hipStreamCaptureModeGlobal) expected = hipErrorStreamCaptureUnsupported;
std::thread t(std::ref(func), concurrent_stream);
t.join();
REQUIRE(func.result_status == expected);
}
SECTION("Query stream during capture") {
streamQuery func;
hipError_t expected = hipSuccess;
if (captureMode == hipStreamCaptureModeGlobal) expected = hipErrorStreamCaptureUnsupported;
std::thread t(std::ref(func), concurrent_stream);
t.join();
REQUIRE(func.result_status == expected);
}
SECTION("Synchronize device during capture") {
deviceSync func;
hipError_t expected = hipErrorStreamCaptureUnsupported;
std::thread t(std::ref(func));
t.join();
REQUIRE(func.result_status == expected);
}
SECTION("Synchronize event during capture") {
eventSync func;
hipError_t expected = hipSuccess;
std::thread t(std::ref(func), e);
t.join();
REQUIRE(func.result_status == expected);
}
SECTION("Query for an event during capture") {
eventQuery func;
hipError_t expected = hipSuccess;
std::thread t(std::ref(func), e);
t.join();
REQUIRE(func.result_status == expected);
}
}
/**
* Test Description
* ------------------------