update
This commit is contained in:
+162
-97
@@ -1,11 +1,11 @@
|
||||
#!/usr/bin/python
|
||||
import os, sys, re
|
||||
|
||||
HEADER = "hip_cbstr.h"
|
||||
HEADER = "hip_prof_str.h"
|
||||
REC_MAX_LEN = 1024
|
||||
|
||||
#############################################################
|
||||
# Filling API map with API call name and args
|
||||
# Normalizing API arguments
|
||||
def filtr_api_args(args_str):
|
||||
args_str = re.sub(r'^\s*', r'', args_str);
|
||||
args_str = re.sub(r'\s*$', r'', args_str);
|
||||
@@ -15,6 +15,7 @@ def filtr_api_args(args_str):
|
||||
args_str = re.sub(r'(enum|struct) ', '', args_str);
|
||||
return args_str
|
||||
|
||||
# Creating a list of arguments [(type, name), ...]
|
||||
def list_api_args(args_str):
|
||||
args_list = []
|
||||
for arg_pair in args_str.split(','):
|
||||
@@ -30,6 +31,7 @@ def list_api_args(args_str):
|
||||
args_list.append((arg_type, arg_name))
|
||||
return args_list;
|
||||
|
||||
# Creating arguments string "type0, type1, ..."
|
||||
def filtr_api_types(args_str):
|
||||
args_str = filtr_api_args(args_str)
|
||||
args_list = list_api_args(args_str)
|
||||
@@ -38,11 +40,13 @@ def filtr_api_types(args_str):
|
||||
types_str += arg_tuple[0] + ', '
|
||||
return types_str
|
||||
|
||||
# Normalizing types
|
||||
def norm_api_types(types_str):
|
||||
types_str = re.sub(r'uint32_t,', r'unsigned int,', types_str)
|
||||
types_str = re.sub(r'unsigned,', r'unsigned int,', types_str)
|
||||
return types_str
|
||||
|
||||
# Creating options list [name0, name1, ...]
|
||||
def filtr_api_opts(args_str):
|
||||
args_str = filtr_api_args(args_str)
|
||||
args_list = list_api_args(args_str)
|
||||
@@ -51,18 +55,19 @@ def filtr_api_opts(args_str):
|
||||
opts_list.append(arg_tuple[1])
|
||||
return opts_list
|
||||
|
||||
# Filling API map with API call name and args
|
||||
def fill_api_map(out, api_name, args_str):
|
||||
args_str = filtr_api_args(args_str)
|
||||
out[api_name + '.a'] = args_str
|
||||
out[api_name] = list_api_args(args_str)
|
||||
|
||||
def patch_args(api_opts, eta_opts, content):
|
||||
api_opts_list = api_opts.split(',');
|
||||
eta_opts_list = eta_opts.split(',');
|
||||
length = len(api_opts_list)
|
||||
for index in range(0, length):
|
||||
content = re.sub(' ' + api_opts_list[index], ' ' + eta_opts_list[index], content)
|
||||
return content
|
||||
#def patch_args(api_opts, eta_opts, content):
|
||||
# api_opts_list = api_opts.split(',');
|
||||
# eta_opts_list = eta_opts.split(',');
|
||||
# length = len(api_opts_list)
|
||||
# for index in range(0, length):
|
||||
# content = re.sub(' ' + api_opts_list[index], ' ' + eta_opts_list[index], content)
|
||||
# return content
|
||||
#############################################################
|
||||
# Parsing API header
|
||||
# hipError_t hipSetupArgument(const void* arg, size_t size, size_t offset);
|
||||
@@ -103,24 +108,40 @@ def parse_api(inp_file, out):
|
||||
# Patching API implementation
|
||||
# hipError_t hipSetupArgument(const void* arg, size_t size, size_t offset) {
|
||||
# HIP_INIT_CB(hipSetupArgument, arg, size, offset);
|
||||
# inp_file - input implementation source file
|
||||
# api_map - input public API map [<api name>] => <api args>
|
||||
# out - output map [<api name>] => <api args>
|
||||
def patch_content(inp_file, api_map, out):
|
||||
beg_pattern = re.compile("^hipError_t");
|
||||
api_pattern = re.compile("^hipError_t\s+([^\(]+)\(([^\)]*)\)\s*{");
|
||||
# API definition begin pattern
|
||||
beg_pattern = re.compile("^(hipError_t|const char\s*\*\s+[_\w]+\()");
|
||||
# API definition complete pattern
|
||||
api_pattern = re.compile("^(hipError_t|const char\s*\*)\s+([^\(]+)\(([^\)]*)\)\s*{");
|
||||
# API init macro pattern
|
||||
init_pattern = re.compile("^\s*HIP_INIT[_\w]*_API\(([^,]+)(,|\))");
|
||||
target_pattern = re.compile("^(\s*HIP_INIT[^\(]*)(_API\()(.*)\);\s*$");
|
||||
|
||||
# Open input file
|
||||
inp = open(inp_file, 'r')
|
||||
|
||||
# API name
|
||||
api_name = ""
|
||||
# Valid public API found flag
|
||||
api_valid = 0
|
||||
api_valid_always = 1
|
||||
|
||||
# Input file patched content
|
||||
content = ''
|
||||
# Sub content for found API defiition
|
||||
sub_content = ''
|
||||
# Current record, accumulating several API definition related lines
|
||||
record = ''
|
||||
# Current input file line number
|
||||
line_num = -1
|
||||
# API beginning found flag
|
||||
found = 0
|
||||
|
||||
# Reading input file
|
||||
for line in inp.readlines():
|
||||
# Accumulating record
|
||||
record += re.sub(r'^\s+', r' ', line[:-1])
|
||||
line_num += 1
|
||||
|
||||
@@ -128,50 +149,91 @@ def patch_content(inp_file, api_map, out):
|
||||
print "Error: bad record \"" + record + "\"\nfile '" + inp_file + ", line (" + str(line_num) + ")"
|
||||
break;
|
||||
|
||||
# Looking for API begin
|
||||
if beg_pattern.match(record): found = 1
|
||||
|
||||
if found != 0:
|
||||
# Matching complete API definition
|
||||
if found == 1:
|
||||
record = re.sub("\s__dparm\([^\)]*\)", '', record);
|
||||
m = api_pattern.match(record)
|
||||
# Checking if complete API matched
|
||||
if m:
|
||||
found = 0
|
||||
api_name = m.group(1);
|
||||
found = 2
|
||||
api_name = m.group(2);
|
||||
# Checking if API name is in the API map
|
||||
if api_name in api_map:
|
||||
api_args = filtr_api_args(m.group(2))
|
||||
# Getting API arguments
|
||||
api_args = m.group(3)
|
||||
# Getting etalon arguments from the API map
|
||||
eta_args = api_map[api_name]
|
||||
# Normalizing API arguments
|
||||
api_types = filtr_api_types(api_args)
|
||||
# Normalizing etalon arguments
|
||||
eta_types = filtr_api_types(eta_args)
|
||||
# Comparing API and etalon arguments
|
||||
# Normalizing types if not matching
|
||||
api_types_n = api_types
|
||||
eta_types_n = eta_types
|
||||
if api_types != eta_types:
|
||||
api_types = norm_api_types(api_types)
|
||||
eta_types = norm_api_types(eta_types)
|
||||
if api_types == eta_types:
|
||||
api_types_n = norm_api_types(api_types)
|
||||
eta_types_n = norm_api_types(eta_types)
|
||||
# Comparing API and etalon arguments
|
||||
if api_types_n == eta_types_n:
|
||||
# API is already found
|
||||
if api_name in out:
|
||||
print "Error: API redefined \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + ", line (" + str(line_num) + ")"
|
||||
print "Error: API redefined \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + "', line (" + str(line_num) + ")"
|
||||
sys.exit(1)
|
||||
# Set valid public API found flag
|
||||
api_valid = 1
|
||||
# Set output API map with API arguments
|
||||
out[api_name] = filtr_api_opts(api_args)
|
||||
elif not api_name in out:
|
||||
api_diff = '\t\t' + inp_file + " line(" + str(line_num) + ")\n\t\tapi: " + api_types + "\n\t\teta: " + eta_types
|
||||
else:
|
||||
# Warning about mismatched API, possible non public overloaded version
|
||||
api_diff = '\t\t' + inp_file + " line(" + str(line_num) + ")\n\t\tapi: " + api_types_n + "\n\t\teta: " + eta_types_n
|
||||
print "\t" + api_name + ':\n' + api_diff + '\n'
|
||||
|
||||
content += sub_content
|
||||
sub_content = ''
|
||||
else:
|
||||
sub_content += line
|
||||
continue
|
||||
|
||||
if (api_valid_always == 1) || (api_valid == 1):
|
||||
m = target_pattern.match(line)
|
||||
# API found action
|
||||
if found == 2:
|
||||
# Looking for INIT macro
|
||||
m = init_pattern.match(line)
|
||||
if m:
|
||||
api_valid = 0
|
||||
if not re.search("_CB_API\(", line):
|
||||
found = 0
|
||||
if api_valid == 1:
|
||||
api_valid = 0
|
||||
print (api_name);
|
||||
api_label = api_name
|
||||
if m.group(3) != "": api_label += ', '
|
||||
line = m.group(1) + '_CB' + m.group(2) + api_label + m.group(3) + ");\n"
|
||||
else:
|
||||
# Registering dummy API for non public API if the name in INIT is not NONE
|
||||
dummy_name = m.group(1)
|
||||
if (not dummy_name in api_map) and (dummy_name != 'NONE'):
|
||||
if dummy_name in out:
|
||||
print "Error: API reinit \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + "', line (" + str(line_num) + ")"
|
||||
sys.exit(1)
|
||||
out[dummy_name] = []
|
||||
elif re.search('}', line):
|
||||
found = 0
|
||||
# Expect INIT macro for valid public API
|
||||
if api_valid == 1:
|
||||
api_valid = 0
|
||||
print "\tAPI init missing \"" + api_name + "\", record \"" + record + "\"\n\tfile '" + inp_file + "', line (" + str(line_num) + ")"
|
||||
if api_name in out:
|
||||
del out[api_name]
|
||||
else:
|
||||
print "Error: API is not in out \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + "', line (" + str(line_num) + ")"
|
||||
sys.exit(1)
|
||||
|
||||
# # Valid API found action
|
||||
# if api_valid == 1:
|
||||
# m = target_pattern.match(line)
|
||||
# if m:
|
||||
# api_valid = 0
|
||||
# if not re.search("_CB_API\(", line):
|
||||
# print (api_name);
|
||||
# api_label = api_name
|
||||
# if m.group(3) != "": api_label += ', '
|
||||
# line = m.group(1) + '_CB' + m.group(2) + api_label + m.group(3) + ");\n"
|
||||
|
||||
if found != 1: record = ""
|
||||
content += line
|
||||
record = ""
|
||||
|
||||
inp.close()
|
||||
|
||||
@@ -201,6 +263,7 @@ def patch_src(api_map, src_path, src_patt, out):
|
||||
# Usage
|
||||
if (len(sys.argv) < 2):
|
||||
print >>sys.stderr, "Usage:", sys.argv[0], " <input HIP API .h file> [patched srcs path]"
|
||||
print >>sys.stderr, " $ hipap.py hip/include/hip/hcc_detail/hip_runtime_api.h hip/src"
|
||||
sys.exit(1)
|
||||
|
||||
# API header file given as an argument
|
||||
@@ -213,6 +276,8 @@ if not os.path.isfile(api_hfile):
|
||||
api_map = {}
|
||||
# API options map
|
||||
opts_map = {}
|
||||
# Private API list
|
||||
priv_lst = []
|
||||
|
||||
# Parsing API header
|
||||
parse_api(api_hfile, api_map)
|
||||
@@ -236,41 +301,42 @@ if len(opts_map) != 0:
|
||||
if not name in opts_map:
|
||||
print "Not found: " + name
|
||||
#############################################################
|
||||
# Generating the header
|
||||
#api_map['hipLaunchKernel'] = [
|
||||
# ('void*', 'kernel'),
|
||||
# ('hipStream_t', 'stream')
|
||||
#]
|
||||
#api_map['hipKernel'] = [
|
||||
# ('const char*', 'name'),
|
||||
# ('uint64_t', 'start'),
|
||||
# ('uint64_t', 'end')
|
||||
#]
|
||||
|
||||
f = open(HEADER, 'w')
|
||||
f.write('// automatically generated sources\n')
|
||||
f.write('#ifndef _HIP_CBSTR_H\n');
|
||||
f.write('#define _HIP_CBSTR_H\n');
|
||||
f.write('#ifndef _HIP_PROF_STR_H\n');
|
||||
f.write('#define _HIP_PROF_STR_H\n');
|
||||
f.write('#include <sstream>\n');
|
||||
f.write('#include <string>\n');
|
||||
|
||||
# Generating the callbacks function type
|
||||
f.write('\n// HIP API callbacks function type\n\
|
||||
struct hip_cb_data_t;\n\
|
||||
struct hip_act_record_t;\n\
|
||||
typedef void (*hip_cb_fun_t)(uint32_t domain, uint32_t cid, const void* data, void* arg);\n\
|
||||
typedef void (*hip_cb_act_t)(uint32_t cid, hip_act_record_t** record, const void* data, void* arg);\n\
|
||||
typedef void (*hip_cb_async_t)(uint32_t op_id, void* record, void* arg);\n\
|
||||
')
|
||||
# Generating dummy macro for non-public API
|
||||
f.write('\n// Dummy API primitives\n')
|
||||
f.write('#define INIT_NONE_CB_ARGS_DATA(cb_data) {};\n')
|
||||
for name in opts_map:
|
||||
if not name in api_map:
|
||||
opts_lst = opts_map[name]
|
||||
if len(opts_lst) != 0:
|
||||
print ("Error: bad dummy API \"" + name + "\", args: ", opts_lst)
|
||||
sys.exit(1)
|
||||
f.write('#define INIT_'+ name + '_CB_ARGS_DATA(cb_data) {};\n')
|
||||
priv_lst.append(name)
|
||||
|
||||
for name in priv_lst:
|
||||
print "Private: ", name
|
||||
|
||||
# Generating the callbacks ID enumaration
|
||||
f.write('\n// HIP API callbacks ID enumaration\n')
|
||||
f.write('enum hip_cb_id_t {\n')
|
||||
f.write('enum hip_api_id_t {\n')
|
||||
cb_id = 0
|
||||
for name in api_map.keys():
|
||||
f.write(' HIP_API_ID_' + name + ' = ' + str(cb_id) + ',\n')
|
||||
cb_id += 1
|
||||
f.write(' HIP_API_ID_NUMBER = ' + str(cb_id) + ',\n')
|
||||
f.write(' HIP_API_ID_ANY = ' + str(cb_id + 1) + ',\n')
|
||||
f.write('\n')
|
||||
f.write(' HIP_API_ID_NONE = HIP_API_ID_NUMBER,\n')
|
||||
for name in priv_lst:
|
||||
f.write(' HIP_API_ID_' + name + ' = HIP_API_ID_NUMBER,\n')
|
||||
f.write('};\n')
|
||||
|
||||
# Generating the callbacks ID enumaration
|
||||
@@ -286,7 +352,7 @@ f.write('};\n')
|
||||
# Generating the callbacks data structure
|
||||
f.write('\n// HIP API callbacks data structure\n')
|
||||
f.write(
|
||||
'struct hip_cb_data_t {\n' +
|
||||
'struct hip_api_data_t {\n' +
|
||||
' uint64_t correlation_id;\n' +
|
||||
' uint32_t phase;\n' +
|
||||
' union {\n'
|
||||
@@ -306,17 +372,16 @@ f.write(
|
||||
f.write('\n// HIP API callbacks args data filling macros\n')
|
||||
for name, args in api_map.items():
|
||||
f.write('#define INIT_' + name + '_CB_ARGS_DATA(cb_data) { \\\n')
|
||||
opts_list = []
|
||||
if name in opts_map:
|
||||
opts_list = opts_map[name]
|
||||
for ind in range(0, len(args)):
|
||||
arg_tuple = args[ind]
|
||||
arg_type = arg_tuple[0]
|
||||
fld_name = arg_tuple[1]
|
||||
arg_name = arg_tuple[1]
|
||||
if len(opts_list) != 0:
|
||||
if len(args) != len(opts_list):
|
||||
print ("Error: \"" + name + "\" API args and opts mismatch, args: ", args, ", opts: ", opts_list)
|
||||
for ind in range(0, len(args)):
|
||||
arg_tuple = args[ind]
|
||||
arg_type = arg_tuple[0]
|
||||
fld_name = arg_tuple[1]
|
||||
arg_name = opts_list[ind]
|
||||
f.write(' cb_data.args.' + name + '.' + fld_name + ' = (' + arg_type + ')' + arg_name + '; \\\n')
|
||||
f.write(' cb_data.args.' + name + '.' + fld_name + ' = (' + arg_type + ')' + arg_name + '; \\\n')
|
||||
f.write('};\n')
|
||||
f.write('#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data)\n')
|
||||
|
||||
@@ -324,7 +389,7 @@ f.write('#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(c
|
||||
f.write('\n')
|
||||
f.write('#if 0\n')
|
||||
f.write('// HIP API string method, method name and parameters\n')
|
||||
f.write('const char* hipApiString(hip_cb_id_t id, const hip_cb_data_t* data) {\n')
|
||||
f.write('const char* hipApiString(hip_api_id_t id, const hip_api_data_t* data) {\n')
|
||||
f.write(' std::ostringstream oss;\n')
|
||||
f.write(' switch (id) {\n')
|
||||
for name, args in api_map.items():
|
||||
@@ -343,35 +408,35 @@ f.write(' return strdup(oss.str().c_str());\n')
|
||||
f.write('};\n')
|
||||
f.write('#endif\n')
|
||||
|
||||
# Generating the activity record type
|
||||
f.write('\
|
||||
\n\
|
||||
// HIP API activity record type\n\
|
||||
// Base record type\n\
|
||||
struct hip_act_record_t {\n\
|
||||
uint32_t domain; // activity domain id\n\
|
||||
uint32_t op_id; // operation id, dispatch/copy/barrier\n\
|
||||
uint32_t activity_kind; // activity kind\n\
|
||||
uint64_t correlation_id; // activity correlation ID\n\
|
||||
uint64_t begin_ns; // host begin timestamp, nano-seconds\n\
|
||||
uint64_t end_ns; // host end timestamp, nano-seconds\n\
|
||||
};\n\
|
||||
// Async record type\n\
|
||||
struct hip_async_record_t : hip_act_record_t {\n\
|
||||
int device_id;\n\
|
||||
uint64_t stream_id;\n\
|
||||
};\n\
|
||||
// Dispatch record type\n\
|
||||
struct hip_dispatch_record_t : hip_async_record_t {};\n\
|
||||
// Barrier record type\n\
|
||||
struct hip_barrier_record_t : hip_async_record_t {};\n\
|
||||
// Memcpy record type\n\
|
||||
struct hip_copy_record_t : hip_async_record_t {\n\
|
||||
size_t bytes;\n\
|
||||
};\n\
|
||||
// Generic async operation record\n\
|
||||
typedef hip_copy_record_t hip_ops_record_t;\n\
|
||||
')
|
||||
# # Generating the activity record type
|
||||
# f.write('\
|
||||
# \n\
|
||||
# // HIP API activity record type\n\
|
||||
# // Base record type\n\
|
||||
# struct hip_act_record_t {\n\
|
||||
# uint32_t domain; // activity domain id\n\
|
||||
# uint32_t op_id; // operation id, dispatch/copy/barrier\n\
|
||||
# uint32_t activity_kind; // activity kind\n\
|
||||
# uint64_t correlation_id; // activity correlation ID\n\
|
||||
# uint64_t begin_ns; // host begin timestamp, nano-seconds\n\
|
||||
# uint64_t end_ns; // host end timestamp, nano-seconds\n\
|
||||
# };\n\
|
||||
# // Async record type\n\
|
||||
# struct hip_async_record_t : hip_act_record_t {\n\
|
||||
# int device_id;\n\
|
||||
# uint64_t stream_id;\n\
|
||||
# };\n\
|
||||
# // Dispatch record type\n\
|
||||
# struct hip_dispatch_record_t : hip_async_record_t {};\n\
|
||||
# // Barrier record type\n\
|
||||
# struct hip_barrier_record_t : hip_async_record_t {};\n\
|
||||
# // Memcpy record type\n\
|
||||
# struct hip_copy_record_t : hip_async_record_t {\n\
|
||||
# size_t bytes;\n\
|
||||
# };\n\
|
||||
# // Generic async operation record\n\
|
||||
# typedef hip_copy_record_t hip_ops_record_t;\n\
|
||||
# ')
|
||||
|
||||
# # Generating the callbacks table
|
||||
# f.write('\n// HIP API callbacks table\n')
|
||||
@@ -437,7 +502,7 @@ typedef hip_copy_record_t hip_ops_record_t;\n\
|
||||
# api_callbacks_spawner_t api_callbacks_spawner(cb_data); \n\
|
||||
# ')
|
||||
|
||||
f.write('#endif // _HIP_CBSTR\n');
|
||||
f.write('#endif // _HIP_PROF_STR_H\n');
|
||||
|
||||
print "Header '" + HEADER + "' is generated"
|
||||
#############################################################
|
||||
|
||||
Reference in New Issue
Block a user