adding queue destroy callback

This commit is contained in:
Evgeny
2018-02-21 19:59:47 -06:00
parent a9a5119399
commit 33b3546fe4
6 ha cambiato i file con 92 aggiunte e 49 eliminazioni
+16 -7
Vedi File
@@ -207,13 +207,15 @@ hsa_status_t rocprofiler_reset(rocprofiler_t* context, // [in] profiling contex
uint32_t group_index); // group index
////////////////////////////////////////////////////////////////////////////////
// Runtime API observer
// Queue callbacks
//
// Runtime API observer is called on enter and exit for the API
// Queue callbacks for initiating profiling per kernel dispatch and for waiting
// for the profiling data on queue destroy.
// Profiling callback data
typedef struct {
hsa_agent_t agent;
const hsa_queue_t* queue;
uint64_t queue_index;
uint64_t kernel_object;
const char* kernel_name;
@@ -226,12 +228,19 @@ typedef hsa_status_t (*rocprofiler_callback_t)(
void* user_data, // [in/out] user data passed to the callback
rocprofiler_group_t* group); // [out] profiling group
// Set/remove kernel dispatch observer
hsa_status_t rocprofiler_set_dispatch_callback(
rocprofiler_callback_t callback, // observer callback
void* data); // [in/out] passed callback data
// Queue callbacks
// Pair of callbacks that are installed together with
// rocprofiler_set_queue_callbacks(). The 'data' pointer given to that call is
// the value handed back to both callbacks as their user-data argument.
// A NULL entry means the corresponding callback is not invoked.
typedef struct {
// Called for kernel dispatch packets seen on an intercepted queue; receives
// the dispatch description, the user data and an output profiling group.
rocprofiler_callback_t dispatch; // dispatch callback
// Called when a queue is being destroyed, with the queue handle and the user
// data. A status other than HSA_STATUS_SUCCESS is returned from the queue
// destroy path as-is.
hsa_status_t (*destroy)(hsa_queue_t* queue, void* data); // destroy callback
} rocprofiler_queue_callbacks_t;
hsa_status_t rocprofiler_remove_dispatch_callback();
// Set queue callbacks
hsa_status_t rocprofiler_set_queue_callbacks(
rocprofiler_queue_callbacks_t callbacks, // callbacks
void* data); // [in/out] passed callbacks data
// Remove queue callbacks
hsa_status_t rocprofiler_remove_queue_callbacks();
////////////////////////////////////////////////////////////////////////////////
// Start/stop profiling
+8 -1
Vedi File
@@ -329,7 +329,7 @@ class Context {
const profile_vector_t profile_vector = GetProfiles(group_index);
for (auto& tuple : profile_vector) {
// Wait for stop packet to complete
const uint64_t timeout = UINT64_MAX;
const uint64_t timeout = timeout_;
bool complete = false;
while (!complete) {
const hsa_signal_value_t signal_value = hsa_signal_wait_scacquire(tuple.completion_signal, HSA_SIGNAL_CONDITION_LT, 1, timeout,
@@ -372,6 +372,10 @@ class Context {
}
}
// Sets the class-wide timeout used while waiting for profiling data.
// The value is stored in the static timeout_ member, which the completion
// wait loop passes as the timeout argument of hsa_signal_wait_scacquire().
// NOTE(review): the unit is whatever hsa_signal_wait_scacquire() expects for
// its timeout hint - confirm against the HSA runtime documentation.
// This is a plain assignment; no locking is performed here.
static void SetTimeout(uint64_t timeout) {
timeout_ = timeout;
}
private:
// Getting profiling packets
profile_vector_t GetProfiles(const uint32_t& index) {
@@ -469,6 +473,9 @@ class Context {
return info;
}
// Profiling data waiting timeout
static uint64_t timeout_;
// GPU handle
const hsa_agent_t agent_;
const util::AgentInfo* agent_info_;
+3 -2
Vedi File
@@ -7,8 +7,9 @@ void InterceptQueue::HsaIntercept(HsaApiTable* table) {
}
InterceptQueue::mutex_t InterceptQueue::mutex_;
rocprofiler_callback_t InterceptQueue::on_dispatch_cb_ = NULL;
void* InterceptQueue::on_dispatch_cb_data_ = NULL;
rocprofiler_callback_t InterceptQueue::dispatch_callback_ = NULL;
InterceptQueue::queue_callback_t InterceptQueue::destroy_callback_ = NULL;
void* InterceptQueue::callback_data_ = NULL;
InterceptQueue::obj_map_t* InterceptQueue::obj_map_ = NULL;
const char* InterceptQueue::kernel_none_ = "";
} // namespace rocprofiler
+27 -17
Vedi File
@@ -23,6 +23,7 @@ class InterceptQueue {
public:
typedef std::recursive_mutex mutex_t;
typedef std::map<uint64_t, InterceptQueue*> obj_map_t;
typedef hsa_status_t (*queue_callback_t)(hsa_queue_t*, void* data);
static void HsaIntercept(HsaApiTable* table);
@@ -39,7 +40,7 @@ class InterceptQueue {
ProxyQueue* proxy = ProxyQueue::Create(agent, size, type, callback, data, private_segment_size,
group_segment_size, queue, &status);
if (status == HSA_STATUS_SUCCESS) {
InterceptQueue* obj = new InterceptQueue(agent, proxy);
InterceptQueue* obj = new InterceptQueue(agent, *queue, proxy);
(*obj_map_)[(uint64_t)(*queue)] = obj;
status = proxy->SetInterceptCB(OnSubmitCB, obj);
}
@@ -53,9 +54,15 @@ class InterceptQueue {
std::lock_guard<mutex_t> lck(mutex_);
hsa_status_t status = HSA_STATUS_ERROR;
if (destroy_callback_ != NULL) {
status = destroy_callback_(queue, callback_data_);
if (status != HSA_STATUS_SUCCESS) return status;
}
obj_map_t::iterator it = obj_map_->find((uint64_t)queue);
if (it != obj_map_->end()) {
const InterceptQueue* obj = it->second;
assert(queue == obj->queue_);
delete obj;
obj_map_->erase(it);
status = HSA_STATUS_SUCCESS;
@@ -74,14 +81,17 @@ class InterceptQueue {
bool to_submit = true;
const packet_t* packet = &packets_arr[j];
if ((GetHeaderType(packet) == HSA_PACKET_TYPE_KERNEL_DISPATCH) && (on_dispatch_cb_ != NULL)) {
if ((GetHeaderType(packet) == HSA_PACKET_TYPE_KERNEL_DISPATCH) && (dispatch_callback_ != NULL)) {
rocprofiler_group_t group = {};
const hsa_kernel_dispatch_packet_t* dispatch_packet =
reinterpret_cast<const hsa_kernel_dispatch_packet_t*>(packet);
const char* kernel_name = GetKernelName(dispatch_packet);
rocprofiler_callback_data_t data = {obj->agent_info_->dev_id, user_que_idx,
dispatch_packet->kernel_object, kernel_name};
hsa_status_t status = on_dispatch_cb_(&data, on_dispatch_cb_data_, &group);
rocprofiler_callback_data_t data = {obj->agent_info_->dev_id,
obj->queue_,
user_que_idx,
dispatch_packet->kernel_object,
kernel_name};
hsa_status_t status = dispatch_callback_(&data, callback_data_, &group);
free(const_cast<char*>(kernel_name));
if ((status == HSA_STATUS_SUCCESS) && (group.context != NULL)) {
Context* context = reinterpret_cast<Context*>(group.context);
@@ -112,20 +122,18 @@ class InterceptQueue {
}
}
static void SetDispatchCB(rocprofiler_callback_t on_dispatch_cb, void* data) {
static void SetCallbacks(rocprofiler_callback_t dispatch_callback, queue_callback_t destroy_callback, void* data) {
std::lock_guard<mutex_t> lck(mutex_);
on_dispatch_cb_ = on_dispatch_cb;
on_dispatch_cb_data_ = data;
}
static void UnsetDispatchCB() {
std::lock_guard<mutex_t> lck(mutex_);
on_dispatch_cb_ = NULL;
on_dispatch_cb_data_ = NULL;
callback_data_ = data;
dispatch_callback_ = dispatch_callback;
destroy_callback_ = destroy_callback;
}
private:
InterceptQueue(const hsa_agent_t& agent, ProxyQueue* proxy) : proxy_(proxy) {
// Builds the interception object for one queue.
//   agent - agent the queue was created on; used only to look up its AgentInfo
//   queue - queue handle stored in queue_ and later reported to the callbacks
//   proxy - proxy queue stored in proxy_
// All three members are set in the initializer list, in declaration order.
InterceptQueue(const hsa_agent_t& agent, hsa_queue_t* const queue, ProxyQueue* proxy)
    : queue_(queue),
      proxy_(proxy),
      agent_info_(util::HsaRsrcFactory::Instance().GetAgentInfo(agent)) {}
// Releases the proxy queue held in proxy_ by passing it to ProxyQueue::Destroy().
~InterceptQueue() { ProxyQueue::Destroy(proxy_); }
@@ -164,11 +172,13 @@ class InterceptQueue {
static mutex_t mutex_;
static const packet_word_t header_type_mask = (1ul << HSA_PACKET_HEADER_WIDTH_TYPE) - 1;
static rocprofiler_callback_t on_dispatch_cb_;
static void* on_dispatch_cb_data_;
static rocprofiler_callback_t dispatch_callback_;
static queue_callback_t destroy_callback_;
static void* callback_data_;
static obj_map_t* obj_map_;
static const char* kernel_none_;
hsa_queue_t* const queue_;
ProxyQueue* const proxy_;
const util::AgentInfo* agent_info_;
};
+13 -7
Vedi File
@@ -141,6 +141,12 @@ void UnloadTool() {
// Library initializer: creates the logger and applies the optional
// ROCP_DATA_TIMEOUT override of the profiling-data wait timeout.
// NOTE(review): CONSTRUCTOR_API is defined outside this view; presumably it
// marks the function to run at library load - confirm.
CONSTRUCTOR_API void constructor() {
  util::Logger::Create();

  // ROCP_DATA_TIMEOUT is parsed with base 0, so decimal, 0x-prefixed hex and
  // 0-prefixed octal values are accepted.
  const char* timeout_str = getenv("ROCP_DATA_TIMEOUT");
  if (timeout_str != NULL) {
    char* parse_end = NULL;
    const uint64_t timeout_val = strtoull(timeout_str, &parse_end, 0);
    // When no digits could be converted strtoull() returns 0 and leaves the
    // end pointer at the start of the string. Installing that 0 would
    // silently replace the default timeout, so such values are ignored and
    // the default stays in effect. An explicit "0" is still honoured.
    if (parse_end != timeout_str) {
      Context::SetTimeout(timeout_val);
    }
  }
}
DESTRUCTOR_API void destructor() {
@@ -168,6 +174,7 @@ const MetricsDict* GetMetrics(const hsa_agent_t& agent) {
util::Logger::mutex_t util::Logger::mutex_;
util::Logger* util::Logger::instance_ = NULL;
uint64_t Context::timeout_ = 1000;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -341,18 +348,17 @@ PUBLIC_API hsa_status_t rocprofiler_get_metrics(const rocprofiler_t* handle) {
API_METHOD_SUFFIX
}
// Set kernel dispatch observer
PUBLIC_API hsa_status_t rocprofiler_set_dispatch_callback(rocprofiler_callback_t callback,
void* data) {
// Set/remove queue callbacks
PUBLIC_API hsa_status_t rocprofiler_set_queue_callbacks(rocprofiler_queue_callbacks_t callbacks, void* data) {
API_METHOD_PREFIX
rocprofiler::InterceptQueue::SetDispatchCB(callback, data);
rocprofiler::InterceptQueue::SetCallbacks(callbacks.dispatch, callbacks.destroy, data);
API_METHOD_SUFFIX
}
// Set kernel dispatch observer
PUBLIC_API hsa_status_t rocprofiler_remove_dispatch_callback() {
// Remove queue callbacks
PUBLIC_API hsa_status_t rocprofiler_remove_queue_callbacks() {
API_METHOD_PREFIX
rocprofiler::InterceptQueue::UnsetDispatchCB();
rocprofiler::InterceptQueue::SetCallbacks(NULL, NULL, NULL);
API_METHOD_SUFFIX
}
+25 -15
Vedi File
@@ -27,7 +27,7 @@
#define KERNEL_NAME_LEN_MAX 128
// Dispatch callback data type
struct dispatch_data_t {
struct callbacks_data_t {
rocprofiler_feature_t* features;
unsigned feature_count;
unsigned group_index;
@@ -48,7 +48,7 @@ struct context_entry_t {
// Dispatch callbacks and context handlers synchronization
pthread_mutex_t mutex = PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP;
// Dispatch callback data
dispatch_data_t* dispatch_data = NULL;
callbacks_data_t* callbacks_data = NULL;
// Stored contexts array
typedef std::map<uint32_t, context_entry_t> context_array_t;
context_array_t* context_array = NULL;
@@ -227,8 +227,7 @@ void dump_context(context_entry_t* entry) {
const unsigned feature_count = entry->feature_count;
fprintf(file_handle,
"dispatch[%u], queue_index(%lu), kernel_object(0x%lx), kernel_name(\"%s\"):\n", index,
entry->data.queue_index, entry->data.kernel_object, entry->data.kernel_name);
"dispatch[%u], queue_index(%lu), kernel_name(\"%s\"):\n", index, entry->data.queue_index, entry->data.kernel_name);
rocprofiler_group_t group = entry->group;
status = rocprofiler_group_get_data(&group);
@@ -289,7 +288,7 @@ hsa_status_t dispatch_callback(const rocprofiler_callback_data_t* callback_data,
// HSA status
hsa_status_t status = HSA_STATUS_ERROR;
// Passed tool data
dispatch_data_t* tool_data = reinterpret_cast<dispatch_data_t*>(user_data);
callbacks_data_t* tool_data = reinterpret_cast<callbacks_data_t*>(user_data);
// Profiling context
rocprofiler_t* context = NULL;
// Context entry
@@ -326,6 +325,11 @@ hsa_status_t dispatch_callback(const rocprofiler_callback_data_t* callback_data,
return status;
}
// Queue destroy callback registered by the tool.
// Calls dump_context_array() and always reports HSA_STATUS_SUCCESS; neither the
// queue handle nor the user-data pointer is consulted.
// NOTE(review): dump_context_array() is defined outside this view - presumably
// it dumps the contexts collected so far; confirm.
hsa_status_t destroy_callback(hsa_queue_t* queue, void*) {
dump_context_array();
return HSA_STATUS_SUCCESS;
}
static hsa_status_t info_callback(const rocprofiler_info_data_t info, void * arg) {
const char symb = *reinterpret_cast<const char*>(arg);
if (((symb == 'b') && (info.metric.expr == NULL)) ||
@@ -473,14 +477,20 @@ extern "C" PUBLIC_API void OnLoadTool()
}
fflush(stdout);
// Adding dispatch observer
if (feature_count) {
dispatch_data = new dispatch_data_t{};
dispatch_data->features = features;
dispatch_data->feature_count = feature_count;
dispatch_data->group_index = 0;
dispatch_data->file_handle = result_file_handle;
rocprofiler_set_dispatch_callback(dispatch_callback, dispatch_data);
rocprofiler_queue_callbacks_t callbacks_ptrs{0};
callbacks_ptrs.dispatch = dispatch_callback;
callbacks_ptrs.destroy = destroy_callback;
callbacks_data = new callbacks_data_t{};
callbacks_data->features = features;
callbacks_data->feature_count = feature_count;
callbacks_data->group_index = 0;
callbacks_data->file_handle = result_file_handle;
rocprofiler_set_queue_callbacks(callbacks_ptrs, callbacks_data);
}
xml::Xml::Destroy(xml);
@@ -489,7 +499,7 @@ extern "C" PUBLIC_API void OnLoadTool()
// Tool destructor
extern "C" PUBLIC_API void OnUnloadTool() {
// Unregister dispatch callback
rocprofiler_remove_dispatch_callback();
rocprofiler_remove_queue_callbacks();
// Dump stored profiling output data
const bool result_file_opened = (result_prefix != NULL) && (result_file_handle != NULL);
@@ -500,8 +510,8 @@ extern "C" PUBLIC_API void OnUnloadTool() {
if (result_file_opened) fclose(result_file_handle);
// Cleanup
if (dispatch_data != NULL) {
delete[] dispatch_data->features;
delete dispatch_data;
if (callbacks_data != NULL) {
delete[] callbacks_data->features;
delete callbacks_data;
}
}