diff --git a/inc/roctracer_roctx.h b/inc/roctracer_roctx.h index 2ce749c2bc..c3acec9e92 100644 --- a/inc/roctracer_roctx.h +++ b/inc/roctracer_roctx.h @@ -32,12 +32,15 @@ THE SOFTWARE. #ifndef INC_ROCTRACER_ROCTX_H_ #define INC_ROCTRACER_ROCTX_H_ +#include // ROC-TX API ID enumeration enum roctx_api_id_t { ROCTX_API_ID_roctxMarkA = 0, ROCTX_API_ID_roctxRangePushA = 1, ROCTX_API_ID_roctxRangePop = 2, + ROCTX_API_ID_roctxRangeStartA = 3, + ROCTX_API_ID_roctxRangeStop = 4, ROCTX_API_ID_NUMBER, }; @@ -45,7 +48,10 @@ enum roctx_api_id_t { // ROCTX callbacks data type typedef struct roctx_api_data_s { union { - const char* message; + struct { + const char* message; + roctx_range_id_t id; + }; struct { const char* message; } roctxMarkA; @@ -55,6 +61,14 @@ typedef struct roctx_api_data_s { struct { const char* message; } roctxRangePop; + struct { + const char* message; + roctx_range_id_t id; + } roctxRangeStartA; + struct { + const char* message; + roctx_range_id_t id; + } roctxRangeStop; } args; } roctx_api_data_t; diff --git a/inc/roctx.h b/inc/roctx.h index 9831c4bdee..fa390bc92a 100644 --- a/inc/roctx.h +++ b/inc/roctx.h @@ -70,6 +70,16 @@ int roctxRangePushA(const char* message); // A negative value is returned on the error. int roctxRangePop(); +// ROCTX range id type +typedef uint64_t roctx_range_id_t; + +// Starts a process range +roctx_range_id_t roctxRangeStartA(const char* message); +#define roctxRangeStart(message) roctxRangeStartA(message) + +// Stop a process range +void roctxRangeStop(roctx_range_id_t id); + #ifdef __cplusplus } // extern "C" block #endif // __cplusplus diff --git a/src/roctx/roctx.cpp b/src/roctx/roctx.cpp index fcf379f0fe..e9a0815f16 100644 --- a/src/roctx/roctx.cpp +++ b/src/roctx/roctx.cpp @@ -106,6 +106,7 @@ extern cb_table_t cb_table; // Logger instantiation roctracer::util::Logger::mutex_t roctracer::util::Logger::mutex_; std::atomic roctracer::util::Logger::instance_{}; +std::atomic roctx_range_counter(0); /////////////////////////////////////////////////////////////////////////////////////////////////// // Public library methods @@ -165,6 +166,33 @@ PUBLIC_API int roctxRangePop() { API_METHOD_CATCH(-1) } +PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { + API_METHOD_PREFIX + roctx_range_counter++; + + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStartA.message = strdup(message); + api_data.args.roctxRangeStartA.id = roctx_range_counter; + activity_rtapi_callback_t api_callback_fun = NULL; + void* api_callback_arg = NULL; + roctx::cb_table.get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg); + if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data, api_callback_arg); + + return roctx_range_counter; + API_METHOD_CATCH(-1); +} + +PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) { + API_METHOD_PREFIX + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStop.id = rangeId; + activity_rtapi_callback_t api_callback_fun = NULL; + void* api_callback_arg = NULL; + roctx::cb_table.get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg); + if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data, api_callback_arg); + API_METHOD_SUFFIX_NRET +} + PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) { for (const auto& entry : *roctx::thread_map) { const auto tid = entry.first; diff --git a/test/MatrixTranspose/MatrixTranspose.cpp b/test/MatrixTranspose/MatrixTranspose.cpp index cb7430b8b2..404e25e505 100644 --- a/test/MatrixTranspose/MatrixTranspose.cpp +++ b/test/MatrixTranspose/MatrixTranspose.cpp @@ -97,6 +97,7 @@ int main() { roctracer_mark("before HIP LaunchKernel"); roctxMark("before hipLaunchKernel"); + int rangeId = roctxRangeStart("hipLaunchKernel range"); roctxRangePush("hipLaunchKernel"); // Lauching kernel from host hipLaunchKernelGGL(matrixTranspose, dim3(WIDTH / THREADS_PER_BLOCK_X, WIDTH / THREADS_PER_BLOCK_Y), @@ -112,6 +113,7 @@ int main() { roctxRangePop(); // for "hipMemcpy" roctxRangePop(); // for "hipLaunchKernel" + roctxRangeStop(rangeId); // CPU MatrixTranspose computation matrixTransposeCPUReference(cpuTransposeMatrix, Matrix, WIDTH); diff --git a/test/tool/tracer_tool.cpp b/test/tool/tracer_tool.cpp index ec12e75866..df2530a510 100644 --- a/test/tool/tracer_tool.cpp +++ b/test/tool/tracer_tool.cpp @@ -203,6 +203,7 @@ struct roctx_trace_entry_t { timestamp_t time; uint32_t pid; uint32_t tid; + roctx_range_id_t rid; const char* message; }; @@ -215,6 +216,7 @@ static inline void roctx_callback_fun( uint32_t domain, uint32_t cid, uint32_t tid, + roctx_range_id_t rid, const char* message) { #if ROCTX_CLOCK_TIME @@ -229,6 +231,7 @@ static inline void roctx_callback_fun( entry->time = time; entry->pid = GetPid(); entry->tid = tid; + entry->rid = rid; entry->message = (message != NULL) ? strdup(message) : NULL; } @@ -240,15 +243,15 @@ void roctx_api_callback( { (void)arg; const roctx_api_data_t* data = reinterpret_cast(callback_data); - roctx_callback_fun(domain, cid, GetTid(), data->args.message); + roctx_callback_fun(domain, cid, GetTid(), data->args.id, data->args.message); } // rocTX Start/Stop callbacks void roctx_range_start_callback(const roctx_range_data_t* data, void* arg) { - roctx_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, data->tid, data->message); + roctx_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, data->tid, 0, data->message); } void roctx_range_stop_callback(const roctx_range_data_t* data, void* arg) { - roctx_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, data->tid, NULL); + roctx_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, data->tid, 0, NULL); } void start_callback() { roctracer::RocTxLoader::Instance().RangeStackIterate(roctx_range_start_callback, NULL); } void stop_callback() { roctracer::RocTxLoader::Instance().RangeStackIterate(roctx_range_stop_callback, NULL); } @@ -262,7 +265,7 @@ void roctx_flush_cb(roctx_trace_entry_t* entry) { const timestamp_t timestamp = entry->time; #endif std::ostringstream os; - os << timestamp << " " << entry->pid << ":" << entry->tid << " " << entry->cid; + os << timestamp << " " << entry->pid << ":" << entry->tid << " " << entry->cid << ":" << entry->rid; if (entry->message != NULL) os << ":\"" << entry->message << "\""; else os << ":\"\""; fprintf(roctx_file_handle, "%s\n", os.str().c_str()); fflush(roctx_file_handle);