Files
rocm-systems/script/hipap.py
T
2019-03-13 21:30:29 -05:00

509 строки
17 KiB
Python
Исполняемый файл

#!/usr/bin/python
import os, sys, re
HEADER = "hip_prof_str.h"
REC_MAX_LEN = 1024
#############################################################
# Normalizing API arguments
def filtr_api_args(args_str):
args_str = re.sub(r'^\s*', r'', args_str);
args_str = re.sub(r'\s*$', r'', args_str);
args_str = re.sub(r'\s*,\s*', r',', args_str);
args_str = re.sub(r'\s+', r' ', args_str);
args_str = re.sub(r'void \*', r'void* ', args_str);
args_str = re.sub(r'(enum|struct) ', '', args_str);
return args_str
# Creating a list of arguments [(type, name), ...]
def list_api_args(args_str):
args_list = []
for arg_pair in args_str.split(','):
arg_pair = re.sub(r'\s+=\s+\S+$','', arg_pair);
m = re.match("^(.*)\s(\S+)$", arg_pair);
if m:
arg_type = m.group(1)
arg_name = m.group(2)
# m = re.match("^(.*_t)\s(.*)$", arg_type)
# if m:
# arg_type = m.group(1)
# arg_name = m.group(2)
args_list.append((arg_type, arg_name))
return args_list;
# Creating arguments string "type0, type1, ..."
def filtr_api_types(args_str):
args_str = filtr_api_args(args_str)
args_list = list_api_args(args_str)
types_str = ''
for arg_tuple in args_list:
types_str += arg_tuple[0] + ', '
return types_str
# Normalizing types
def norm_api_types(types_str):
types_str = re.sub(r'uint32_t,', r'unsigned int,', types_str)
types_str = re.sub(r'unsigned,', r'unsigned int,', types_str)
return types_str
# Creating options list [name0, name1, ...]
def filtr_api_opts(args_str):
args_str = filtr_api_args(args_str)
args_list = list_api_args(args_str)
opts_list = []
for arg_tuple in args_list:
opts_list.append(arg_tuple[1])
return opts_list
# Filling API map with API call name and args
def fill_api_map(out, api_name, args_str):
args_str = filtr_api_args(args_str)
out[api_name + '.a'] = args_str
out[api_name] = list_api_args(args_str)
#def patch_args(api_opts, eta_opts, content):
# api_opts_list = api_opts.split(',');
# eta_opts_list = eta_opts.split(',');
# length = len(api_opts_list)
# for index in range(0, length):
# content = re.sub(' ' + api_opts_list[index], ' ' + eta_opts_list[index], content)
# return content
#############################################################
# Parsing API header
# hipError_t hipSetupArgument(const void* arg, size_t size, size_t offset);
def parse_api(inp_file, out):
beg_pattern = re.compile("^hipError_t");
api_pattern = re.compile("^hipError_t\s+([^\(]+)\(([^\)]*)\)");
end_pattern = re.compile("Texture");
inp = open(inp_file, 'r')
found = 0
record = ""
line_num = -1
for line in inp.readlines():
record += re.sub(r'^\s+', r' ', line[:-1])
line_num += 1
if len(record) > REC_MAX_LEN:
print "Error: bad record \"" + record + "\"\nfile '" + inp_file + ", line (" + str(line_num) + ")"
break;
if beg_pattern.match(record): found = 1
if found != 0:
record = re.sub("\s__dparm\([^\)]*\)", '', record);
m = api_pattern.match(record)
if m:
found = 0
if end_pattern.search(record): break
out[m.group(1)] = filtr_api_args(m.group(2))
else: continue
record = ""
inp.close()
#############################################################
# Patching API implementation
# hipError_t hipSetupArgument(const void* arg, size_t size, size_t offset) {
# HIP_INIT_CB(hipSetupArgument, arg, size, offset);
# inp_file - input implementation source file
# api_map - input public API map [<api name>] => <api args>
# out - output map [<api name>] => <api args>
def patch_content(inp_file, api_map, out):
# API definition begin pattern
beg_pattern = re.compile("^(hipError_t|const char\s*\*\s+[_\w]+\()");
# API definition complete pattern
api_pattern = re.compile("^(hipError_t|const char\s*\*)\s+([^\(]+)\(([^\)]*)\)\s*{");
# API init macro pattern
init_pattern = re.compile("^\s*HIP_INIT[_\w]*_API\(([^,]+)(,|\))");
target_pattern = re.compile("^(\s*HIP_INIT[^\(]*)(_API\()(.*)\);\s*$");
# Open input file
inp = open(inp_file, 'r')
# API name
api_name = ""
# Valid public API found flag
api_valid = 0
# Input file patched content
content = ''
# Sub content for found API defiition
sub_content = ''
# Current record, accumulating several API definition related lines
record = ''
# Current input file line number
line_num = -1
# API beginning found flag
found = 0
# Reading input file
for line in inp.readlines():
# Accumulating record
record += re.sub(r'^\s+', r' ', line[:-1])
line_num += 1
if len(record) > REC_MAX_LEN:
print "Error: bad record \"" + record + "\"\nfile '" + inp_file + ", line (" + str(line_num) + ")"
break;
# Looking for API begin
if beg_pattern.match(record): found = 1
# Matching complete API definition
if found == 1:
record = re.sub("\s__dparm\([^\)]*\)", '', record);
m = api_pattern.match(record)
# Checking if complete API matched
if m:
found = 2
api_name = m.group(2);
# Checking if API name is in the API map
if api_name in api_map:
# Getting API arguments
api_args = m.group(3)
# Getting etalon arguments from the API map
eta_args = api_map[api_name]
# Normalizing API arguments
api_types = filtr_api_types(api_args)
# Normalizing etalon arguments
eta_types = filtr_api_types(eta_args)
# Comparing API and etalon arguments
# Normalizing types if not matching
api_types_n = api_types
eta_types_n = eta_types
if api_types != eta_types:
api_types_n = norm_api_types(api_types)
eta_types_n = norm_api_types(eta_types)
# Comparing API and etalon arguments
if api_types_n == eta_types_n:
# API is already found
if api_name in out:
print "Error: API redefined \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + "', line (" + str(line_num) + ")"
sys.exit(1)
# Set valid public API found flag
api_valid = 1
# Set output API map with API arguments
out[api_name] = filtr_api_opts(api_args)
else:
# Warning about mismatched API, possible non public overloaded version
api_diff = '\t\t' + inp_file + " line(" + str(line_num) + ")\n\t\tapi: " + api_types_n + "\n\t\teta: " + eta_types_n
print "\t" + api_name + ':\n' + api_diff + '\n'
# API found action
if found == 2:
# Looking for INIT macro
m = init_pattern.match(line)
if m:
found = 0
if api_valid == 1:
api_valid = 0
print (api_name);
else:
# Registering dummy API for non public API if the name in INIT is not NONE
dummy_name = m.group(1)
if (not dummy_name in api_map) and (dummy_name != 'NONE'):
if dummy_name in out:
print "Error: API reinit \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + "', line (" + str(line_num) + ")"
sys.exit(1)
out[dummy_name] = []
elif re.search('}', line):
found = 0
# Expect INIT macro for valid public API
if api_valid == 1:
api_valid = 0
print "\tAPI init missing \"" + api_name + "\", record \"" + record + "\"\n\tfile '" + inp_file + "', line (" + str(line_num) + ")"
if api_name in out:
del out[api_name]
else:
print "Error: API is not in out \"" + api_name + "\", record \"" + record + "\"\nfile '" + inp_file + "', line (" + str(line_num) + ")"
sys.exit(1)
# # Valid API found action
# if api_valid == 1:
# m = target_pattern.match(line)
# if m:
# api_valid = 0
# if not re.search("_CB_API\(", line):
# print (api_name);
# api_label = api_name
# if m.group(3) != "": api_label += ', '
# line = m.group(1) + '_CB' + m.group(2) + api_label + m.group(3) + ");\n"
if found != 1: record = ""
content += line
inp.close()
if len(out) != 0:
return content
else:
return ''
# srcs path walk
def patch_src(api_map, src_path, src_patt, out):
pattern = re.compile(src_patt)
src_path = re.sub(r'\s', '', src_path)
for src_dir in src_path.split(':'):
print "Patching " + src_dir + " for '" + src_patt + "'"
for root, dirs, files in os.walk(src_dir):
for fnm in files:
if pattern.search(fnm):
file = root + '/' + fnm
print "\t" + file
content = patch_content(file, api_map, out);
if content != '':
f = open(file, 'w')
f.write(content)
f.close()
#############################################################
# main
# Usage
if (len(sys.argv) < 2):
print >>sys.stderr, "Usage:", sys.argv[0], " <input HIP API .h file> [patched srcs path]"
print >>sys.stderr, " $ hipap.py hip/include/hip/hcc_detail/hip_runtime_api.h hip/src"
sys.exit(1)
# API header file given as an argument
api_hfile = sys.argv[1]
if not os.path.isfile(api_hfile):
print >>sys.stderr, "Error: input file '" + api_hfile + "' not found"
sys.exit(1)
# API declaration map
api_map = {}
# API options map
opts_map = {}
# Private API list
priv_lst = []
# Parsing API header
parse_api(api_hfile, api_map)
# Patching API implementation sources
# Sources path is given as an argument
if len(sys.argv) == 3:
src_path = sys.argv[2]
src_patt = "\.cpp$"
patch_src(api_map, src_path, src_patt, opts_map)
# Converting api map to map of lists
for name in api_map.keys():
args_str = api_map[name];
# Printing not found APIs
if len(opts_map) != 0:
for name in api_map.keys():
args_str = api_map[name];
api_map[name] = list_api_args(args_str)
if not name in opts_map:
print "Not found: " + name
#############################################################
f = open(HEADER, 'w')
f.write('// automatically generated sources\n')
f.write('#ifndef _HIP_PROF_STR_H\n');
f.write('#define _HIP_PROF_STR_H\n');
f.write('#include <sstream>\n');
f.write('#include <string>\n');
# Generating dummy macro for non-public API
f.write('\n// Dummy API primitives\n')
f.write('#define INIT_NONE_CB_ARGS_DATA(cb_data) {};\n')
for name in opts_map:
if not name in api_map:
opts_lst = opts_map[name]
if len(opts_lst) != 0:
print ("Error: bad dummy API \"" + name + "\", args: ", opts_lst)
sys.exit(1)
f.write('#define INIT_'+ name + '_CB_ARGS_DATA(cb_data) {};\n')
priv_lst.append(name)
for name in priv_lst:
print "Private: ", name
# Generating the callbacks ID enumaration
f.write('\n// HIP API callbacks ID enumaration\n')
f.write('enum hip_api_id_t {\n')
cb_id = 0
for name in api_map.keys():
f.write(' HIP_API_ID_' + name + ' = ' + str(cb_id) + ',\n')
cb_id += 1
f.write(' HIP_API_ID_NUMBER = ' + str(cb_id) + ',\n')
f.write(' HIP_API_ID_ANY = ' + str(cb_id + 1) + ',\n')
f.write('\n')
f.write(' HIP_API_ID_NONE = HIP_API_ID_NUMBER,\n')
for name in priv_lst:
f.write(' HIP_API_ID_' + name + ' = HIP_API_ID_NUMBER,\n')
f.write('};\n')
# Generating the callbacks ID enumaration
f.write('\n// Return HIP API string\n')
f.write('static const char* hip_api_name(const uint32_t& id) {\n')
f.write(' switch(id) {\n')
for name in api_map.keys():
f.write(' case HIP_API_ID_' + name + ': return "' + name + '";\n')
f.write(' };\n')
f.write(' return "unknown";\n')
f.write('};\n')
# Generating the callbacks data structure
f.write('\n// HIP API callbacks data structure\n')
f.write(
'struct hip_api_data_t {\n' +
' uint64_t correlation_id;\n' +
' uint32_t phase;\n' +
' union {\n'
)
for name, args in api_map.items():
if len(args) != 0:
f.write(' struct {\n')
for arg_tuple in args:
f.write(' ' + arg_tuple[0] + ' ' + arg_tuple[1] + ';\n')
f.write(' } ' + name + ';\n')
f.write(
' } args;\n' +
'};\n'
)
# Generating the callbacks args data filling macros
f.write('\n// HIP API callbacks args data filling macros\n')
for name, args in api_map.items():
f.write('#define INIT_' + name + '_CB_ARGS_DATA(cb_data) { \\\n')
if name in opts_map:
opts_list = opts_map[name]
if len(args) != len(opts_list):
print ("Error: \"" + name + "\" API args and opts mismatch, args: ", args, ", opts: ", opts_list)
for ind in range(0, len(args)):
arg_tuple = args[ind]
arg_type = arg_tuple[0]
fld_name = arg_tuple[1]
arg_name = opts_list[ind]
f.write(' cb_data.args.' + name + '.' + fld_name + ' = (' + arg_type + ')' + arg_name + '; \\\n')
f.write('};\n')
f.write('#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data)\n')
# Generating the method for the API string, name and parameters
f.write('\n')
f.write('#if 0\n')
f.write('// HIP API string method, method name and parameters\n')
f.write('const char* hipApiString(hip_api_id_t id, const hip_api_data_t* data) {\n')
f.write(' std::ostringstream oss;\n')
f.write(' switch (id) {\n')
for name, args in api_map.items():
f.write(' case HIP_API_ID_' + name + ':\n')
f.write(' oss << "' + name + '("')
for ind in range(0, len(args)):
arg_tuple = args[ind]
arg_name = arg_tuple[1]
if ind != 0: f.write(' << ","')
f.write('\n << " ' + arg_name + '=" << data->args.' + name + '.' + arg_name)
f.write('\n << ")";\n')
f.write(' break;\n')
f.write(' default: oss << "unknown";\n')
f.write(' };\n')
f.write(' return strdup(oss.str().c_str());\n')
f.write('};\n')
f.write('#endif\n')
# # Generating the activity record type
# f.write('\
# \n\
# // HIP API activity record type\n\
# // Base record type\n\
# struct hip_act_record_t {\n\
# uint32_t domain; // activity domain id\n\
# uint32_t op_id; // operation id, dispatch/copy/barrier\n\
# uint32_t activity_kind; // activity kind\n\
# uint64_t correlation_id; // activity correlation ID\n\
# uint64_t begin_ns; // host begin timestamp, nano-seconds\n\
# uint64_t end_ns; // host end timestamp, nano-seconds\n\
# };\n\
# // Async record type\n\
# struct hip_async_record_t : hip_act_record_t {\n\
# int device_id;\n\
# uint64_t stream_id;\n\
# };\n\
# // Dispatch record type\n\
# struct hip_dispatch_record_t : hip_async_record_t {};\n\
# // Barrier record type\n\
# struct hip_barrier_record_t : hip_async_record_t {};\n\
# // Memcpy record type\n\
# struct hip_copy_record_t : hip_async_record_t {\n\
# size_t bytes;\n\
# };\n\
# // Generic async operation record\n\
# typedef hip_copy_record_t hip_ops_record_t;\n\
# ')
# # Generating the callbacks table
# f.write('\n// HIP API callbacks table\n')
# f.write('\
# struct hip_cb_table_t {\n\
# struct { hip_cb_fun_t act; hip_cb_fun_t fun; void* arg; } arr[HIP_API_ID_NUMBER];\n\
# };\n\
# #define HIP_CALLBACKS_TABLE hip_cb_table_t HIP_API_callbacks_table{};\n\
# ')
# f.write('\
# inline bool HIP_SET_ACTIVITY(uint32_t id, hip_cb_fun_t fun, void* arg = NULL) {\n\
# (void)arg;\n\
# extern hip_cb_table_t HIP_API_callbacks_table;\n\
# if (id < HIP_API_ID_NUMBER) {\n\
# HIP_API_callbacks_table.arr[id].act = fun;\n\
# return true;\n\
# }\n\
# return false;\n\
# }\n')
# f.write('\
# inline bool HIP_SET_CALLBACK(uint32_t id, hip_cb_fun_t fun, void* arg) {\n\
# extern hip_cb_table_t HIP_API_callbacks_table; \n\
# if (id < HIP_API_ID_NUMBER) {\n\
# HIP_API_callbacks_table.arr[id].fun = fun;\n\
# HIP_API_callbacks_table.arr[id].arg = arg;\n\
# return true;\n\
# }\n\
# return false;\n\
# }\n')
#
# # Generating the callback spawning class
# f.write('\n// HIP API callbacks spawning class macro\n\
# #define CB_SPAWNER_OBJECT(cb_id) \\\n\
# class api_callbacks_spawner_t { \\\n\
# public: \\\n\
# api_callbacks_spawner_t(hip_cb_data_t& cb_data) : cb_data_(cb_data) { \\\n\
# hip_cb_id_t id = HIP_API_ID_##cb_id; \\\n\
# cb_data_.id = id; \\\n\
# cb_data_.correlation_id = UINT_MAX; \\\n\
# cb_data_.name = #cb_id; \\\n\
# extern const hip_cb_table_t* getApiCallbackTabel(); \\\n\
# const hip_cb_table_t* cb_table = getApiCallbackTabel(); \\\n\
# cb_act_ = cb_table->arr[id].act; \\\n\
# cb_fun_ = cb_table->arr[id].fun; \\\n\
# cb_arg_ = cb_table->arr[id].arg; \\\n\
# cb_data_.on_enter = true; \\\n\
# if (cb_act_ != NULL) cb_act_(&cb_data_, NULL); \\\n\
# if (cb_fun_ != NULL) cb_fun_(&cb_data_, cb_arg_); \\\n\
# } \\\n\
# ~api_callbacks_spawner_t() { \\\n\
# cb_data_.on_enter = false; \\\n\
# if (cb_act_ != NULL) cb_act_(&cb_data_, NULL); \\\n\
# if (cb_fun_ != NULL) cb_fun_(&cb_data_, cb_arg_); \\\n\
# } \\\n\
# private: \\\n\
# hip_cb_data_t& cb_data_; \\\n\
# hip_cb_fun_t cb_act_; \\\n\
# hip_cb_fun_t cb_fun_; \\\n\
# void* cb_arg_; \\\n\
# }; \\\n\
# hip_cb_data_t cb_data{}; \\\n\
# INIT_CB_ARGS_DATA(cb_id, cb_data); \\\n\
# api_callbacks_spawner_t api_callbacks_spawner(cb_data); \n\
# ')
f.write('#endif // _HIP_PROF_STR_H\n');
print "Header '" + HEADER + "' is generated"
#############################################################