Файли
Venkateshwar Reddy Kandula 1ab3e7930a SWDEV-480556:new HIP APIs with struct defined inside union.
- Fixes issue where types defined inside of a union were unable to find by function which generates hip_begin.cpp.i.
- Also works for futureHIP APIs with  nunion,  when defined new types inside it.

Change-Id: If1e1c4f80dc511450fd154d64d39a05f81565220


[ROCm/rocprofiler commit: 52125f7966]
2024-09-09 01:11:44 +00:00

719 рядки
21 KiB
Python

################################################################################
# Copyright (c) 2022 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
################################################################################
import os
import os.path
import sys
import re
import yaml
import CppHeaderParser
# Numeric field type (abstract).
class _NumericFt:
# Returns the C++ expression to cast the expression `expr` to the C
# type of this field type.
def cast(self, expr):
return f"static_cast<{self.c_type}>({expr})"
# Integer field type (abstract).
class _IntFt(_NumericFt):
def __init__(self, size, pref_disp_base="dec"):
self._size = size
self._pref_disp_base = pref_disp_base
# Size (bits).
@property
def size(self):
return self._size
# Preferred display base (`dec` or `hex`).
@property
def pref_disp_base(self):
return self._pref_disp_base
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
return {
"size": self._size,
"preferred-display-base": self._pref_disp_base,
}
# Signed integer field type.
class _SIntFt(_IntFt):
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
ret = super().barectf_yaml
ret["class"] = "sint"
return ret
# Equivalent C type
@property
def c_type(self):
return f"std::int{self._size}_t"
# Unsigned integer field type.
class _UIntFt(_IntFt):
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
ret = super().barectf_yaml
ret["class"] = "uint"
return ret
# Equivalent C type.
@property
def c_type(self):
return f"std::uint{self._size}_t"
# Pointer field type.
class _PointerFt(_UIntFt):
def __init__(self):
super().__init__(64, "hex")
# Returns the C++ expression to cast the expression `expr` to the C
# type of this field type.
def cast(self, expr):
return f"static_cast<{self.c_type}>(reinterpret_cast<std::uintptr_t>({expr}))"
# Enumeration field type (abstract).
class _EnumFt(_IntFt):
def __init__(self, size, mappings):
super().__init__(size)
self._mappings = mappings.copy()
# Mappings (names to integers).
@property
def mappings(self):
return self._mappings
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
ret = super().barectf_yaml
mappings = {}
for name, val in self._mappings.items():
mappings[name] = [val]
ret["mappings"] = mappings
return ret
# Unsigned enumeration field type.
class _UEnumFt(_EnumFt, _UIntFt):
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
ret = super().barectf_yaml
ret["class"] = "uenum"
return ret
# Signed enumeration field type.
class _SEnumFt(_EnumFt, _UIntFt):
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
ret = super().barectf_yaml
ret["class"] = "senum"
return ret
# Optional string field type.
class _OptStrFt:
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
return {
"class": "str",
}
# String field type.
class _StrFt(_OptStrFt):
pass
# Floating-point number field type.
class _FloatFt(_NumericFt):
def __init__(self, size):
self._size = size
# Size (bits): 32 or 64.
@property
def size(self):
return self._size
# Equivalent barectf field type in YAML.
@property
def barectf_yaml(self):
return {
"class": "real",
"size": self._size,
}
# Equivalent C type.
@property
def c_type(self):
if self._size == 32:
return "float"
else:
assert self._size == 64
return "double"
# Event record type.
class _Ert:
def __init__(self, api_func_name, members):
self._api_func_name = api_func_name
self._members = members
# API function name
@property
def api_func_name(self):
return self._api_func_name
# Parameters of function (list of `_ErtMember`).
@property
def members(self):
return self._members
# Beginning event record type.
class _BeginErt(_Ert):
# Name of event record type depending on the API prefix.
def name(self, api_prefix):
suffix = "_begin" if api_prefix == "hsa" else "Begin"
return f"{self._api_func_name}{suffix}"
# End event record type.
class _EndErt(_Ert):
# Name of event record type depending on the API prefix.
def name(self, api_prefix):
suffix = "_end" if api_prefix == "hsa" else "End"
return f"{self._api_func_name}{suffix}"
# Event record type member.
class _ErtMember:
def __init__(self, access, member_names, ft):
self._access = access
self._member_names = member_names.copy()
self._ft = ft
# C++ access expression.
@property
def access(self):
return self._access
# List of member names.
@property
def member_names(self):
return self._member_names
# Equivalent field type.
@property
def ft(self):
return self._ft
# Makes sure some condition is satisfied, or prints the error message
# `error_msg` and quits with exit status 1 otherwise.
#
# This is an unconditional assertion.
def _make_sure(cond, error_msg):
if not cond:
print(f"Error: {error_msg}", file=sys.stderr)
sys.exit(1)
def _enumerator_effective_val(enum_val):
# Try the value, but this value may be a string (an
# enumerator/definition).
val = enum_val.get("value")
if type(val) is int:
return val
# Try the raw value.
val = enum_val.get("raw_value")
if val is not None:
if type(val) is int:
# Raw value is already an integer.
return val
else:
# Try to parse the raw value string as an integer.
try:
return int(val, 0)
except:
pass
_make_sure(False, f'Cannot get the integral value of enumerator `{enum_val["name"]}`')
# Returns the equivalent field type of the C type `c_type`.
def _number_ft_from_c_type(cpp_header, c_type):
# Check for known enumeration.
m = re.match(r"(?:enum\s+)?(\w+)", c_type)
if m:
size = 32
for enum_info in cpp_header.enums:
if m.group(1) == enum_info.get("name"):
# Fill enumeration field type mappings.
mappings = {
str(v["name"]): _enumerator_effective_val(v)
for v in enum_info["values"]
}
if len(mappings) == 0:
return _SIntFt(64)
if max(mappings.values()) >= 2**31 or min(mappings.values()) < -(
2**31
):
size = 64
_make_sure(
len(mappings) > 0, f'Enumeration `{enum_info["name"]}` is empty'
)
# Create corresponding enumeration field type.
return _SEnumFt(size, mappings)
# Find corresponding basic field type.
is_unsigned = "unsigned" in c_type
if "long" in c_type:
if is_unsigned:
return _UIntFt(64)
else:
return _SIntFt(64)
elif "short" in c_type:
if is_unsigned:
return _UIntFt(16)
else:
return _SIntFt(16)
elif "char" in c_type:
if is_unsigned:
return _UIntFt(8)
else:
return _SIntFt(8)
elif "float" in c_type:
return _FloatFt(32)
elif "double" in c_type:
return _FloatFt(64)
else:
# Assume `int` (often an unresolved C enumeration).
if is_unsigned:
return _UIntFt(32)
else:
return _SIntFt(32)
# Returns whether or not a property has a pointer type.
def _prop_is_pointer(prop, c_type):
if prop["pointer"] or prop["function_pointer"]:
return True
if prop["array"] and "array_size" in prop:
return True
if prop["unresolved"]:
# HSA API function pointers.
if prop["name"] in ("callback", "handler"):
return True
# HIP API function pointers.
if c_type.endswith("Fn_t"):
return True
# Check the C type itself.
if "*" in c_type or "*" in prop.get("raw_type", ""):
return True
return False
# Returns a list of event record type member objects for the structure
# `struct` considering the initial C++ access expression `access` and
# member names `member_names`.
def _get_ert_members_for_struct(cpp_header, struct, access, member_names):
members = []
member_names = member_names.copy()
member_names.append(None)
props = struct["properties"]["public"]
for index, prop in enumerate(props):
# Property name.
name = prop["name"]
# Member names, access, and C type.
member_names[-1] = str(name)
this_access = f"{access}.{name}"
c_type = prop["type"]
aliases = prop["aliases"]
# Skip no type.
if c_type == "":
continue
# Skip unnamed or union.
if name == "" or "union" in name or re.match(r"\bunion\b", c_type):
continue
# Check for known C type alias.
while True:
c_type_alias = cpp_header.typedefs.get(c_type)
if c_type_alias is None:
break
c_type = c_type_alias
# Check for C string.
if re.match(r"^((const\s+char)|(char\s+const)|char)\s*\*$", c_type.strip()):
members.append(_ErtMember(this_access, member_names, _OptStrFt()))
continue
# Check for pointer.
if _prop_is_pointer(prop, c_type):
# Pointer: use numeric value.
members.append(_ErtMember(this_access, member_names, _PointerFt()))
continue
# Check for substructure.
sub_struct = cpp_header.classes.get(c_type)
if sub_struct is None and len(aliases) == 1:
sub_struct = cpp_header.classes.get(aliases[0])
if sub_struct is not None:
members += _get_ert_members_for_struct(
cpp_header, sub_struct, this_access, member_names
)
continue
# Use a basic field type.
members.append(
_ErtMember(
this_access, member_names, _number_ft_from_c_type(cpp_header, c_type)
)
)
return members
# Returns the beginning and end event record type objects for the
# callback data structure `struct`.
def _erts_from_cb_data_struct(api_prefix, cpp_header, retval_info, struct):
# The location of the `args` union within the nested structures of
# `struct`.
args_nested_cls_index = 0
# Create return value members (to be used later).
if retval_info is not None:
args_nested_cls_index = 1
retval_members = {}
nested_classes = struct["nested_classes"]
_make_sure(
len(nested_classes) >= 1,
f"Return value union doesn't exist in `{struct['name']}`",
)
retval_union = nested_classes[0]
for prop in retval_union["properties"]["public"]:
name = str(prop["name"])
member = _ErtMember(
f"GetApiData().{name}",
["retval"],
_number_ft_from_c_type(cpp_header, prop["type"]),
)
retval_members[prop["name"]] = member
# Make sure we have everything we need.
for api_func_name, retval_name in retval_info.items():
if retval_name is not None:
_make_sure(
retval_name in retval_members,
f"Return value union member `{retval_name}` doesn't exist (function {api_func_name}())",
)
# Create beginning/end event record type objects.
begin_erts = []
end_erts = []
nested_classes = struct["nested_classes"][args_nested_cls_index]["nested_classes"]
props = struct["nested_classes"][args_nested_cls_index]["properties"]["public"]
_make_sure(
len(nested_classes) == len(props),
f'Mismatch between nested structure and member count in `{struct["name"]}`',
)
for index, prop in enumerate(props):
# API function name is the name of the member.
api_func_name = str(prop["name"])
# Get the parameters.
members = _get_ert_members_for_struct(
cpp_header, nested_classes[index], f"GetApiData().args.{api_func_name}", []
)
# Append new beginning event record type object.
begin_erts.append(_BeginErt(api_func_name, members))
# Append new end event record type object if possible.
ret_members = []
if retval_info is not None:
retval_type = retval_info.get(api_func_name)
if retval_type is not None:
ret_members.append(retval_members[retval_type])
end_erts.append(_EndErt(api_func_name, ret_members))
return begin_erts, end_erts
# Creates and returns the return value information dictionary.
#
# This dictionary maps API function names to the member to get within
# the callback data structure.
#
# This only applies to the HSA API: for other APIs, this function
# returns `None`.
def _get_retval_info(path):
if "hsa" not in os.path.basename(path):
return
retval_info = {}
cur_api_func_name = None
with open(path) as f:
for line in f:
if 'out << ")' in line and cur_api_func_name is not None:
m = re.search(r"api_data.(\w+_retval)", line)
retval_info[cur_api_func_name] = m.group(1) if m else None
else:
m = re.search(r'out << "(hsa_\w+)\(";', line)
if m:
cur_api_func_name = m.group(1)
return retval_info
# Returns a partial barectf data stream type in YAML with the event
# record types `erts`.
def _yaml_dst_from_erts(api_prefix, erts):
# Base.
yaml_erts = {}
yaml_dst = {
"event-record-types": yaml_erts,
}
# Create one event record type per API function.
for ert in erts:
# Base.
yaml_members = []
yaml_ert = {
"payload-field-type": {
"class": "struct",
"members": yaml_members,
},
}
# Create one structure field type member per member.
for member in ert.members:
# barectf doesn't support nested CTF structures, so join
# individual member names with `__` to flatten.
yaml_members.append(
{
"_"
+ "__".join(member.member_names): {
"field-type": member.ft.barectf_yaml,
},
}
)
# Add event record type.
yaml_erts[ert.name(api_prefix)] = yaml_ert
# Convert to YAML.
return yaml.dump(yaml_dst)
# Returns the C++ switch statement which calls the correct barectf
# tracing function depending on the API function operation ID.
def _cpp_switch_statement_from_erts(api_prefix, erts):
lines = []
lines.append("switch (GetOp()) {")
for ert in erts:
if api_prefix == 'hip' and 'R0600' in ert.api_func_name:
continue
lines.append(f" case {api_prefix.upper()}_API_ID_{ert.api_func_name}:")
lines.append(f" barectf_{api_prefix}_api_trace_{ert.name(api_prefix)}(")
lines.append(f" &barectf_ctx,")
lines.append(f" GetThreadId(),")
lines.append(f" GetQueueId(),")
lines.append(f" GetAgentId(),")
lines.append(f" GetCorrelationId(),")
if api_prefix == "hip":
lines.append(f" GetKernelName().c_str(),")
if len(ert.members) == 0:
# Remove last comma.
lines[-1] = lines[-1].replace(",", "")
for index, member in enumerate(ert.members):
if type(member.ft) is _OptStrFt:
# Only dereference C string if not null, otherwise use
# an empty string.
lines.append(f' {member.access} ? {member.access} : ""')
elif type(member.ft) is _StrFt:
lines.append(f" {member.access}")
else:
lines.append(f" {member.ft.cast(member.access)}")
if index + 1 < len(ert.members):
lines[-1] += ","
lines.append(" );")
lines.append(" break;")
lines.append("}")
return lines
# Returns a set of expected API function names based on the
# enumerators of the `*_api_id_t` enumeration.
def _api_func_names(api_prefix, cpp_header):
# Find the `*_api_id_t` enumeration.
for enum in cpp_header.enums:
if enum.get('name') == f'{api_prefix}_api_id_t':
break
# Create the set of API function names based on enumerators.
func_names = set()
pat = re.compile(rf'{api_prefix.upper()}_API_ID_(_*{api_prefix}.+)$')
for entry in enum['values']:
if type(entry['value']) is str and 'API_ID_NONE' in entry['value']:
# An enumerator may have the value `*_API_ID_NONE` which
# means the corresponding API function is not available.
continue
m = pat.match(entry['name'])
if m is not None:
func_names.add(m.group(1))
# Return API function names
return func_names
# Processes the complete API header file `path`.
def _process_file(api_prefix, path):
# Create `CppHeader` object.
try:
cpp_header = CppHeaderParser.CppHeader(path)
except CppHeaderParser.CppParseError as exc:
print(exc, file=sys.stderr)
sys.exit(1)
# Get return value information dictionary.
retval_info = _get_retval_info(path)
# add support for structures defined inside union.
new_items = []
for struct_name, struct in cpp_header.classes.items():
# Check if the struct_name starts with 'union '.
if re.match(r'^union \w+', struct_name) is not None:
parts = struct_name.split('::')
simplified = parts[-1]
if simplified != "union ":
new_items.append((simplified, cpp_header.classes.get(struct_name)))
for key, value in new_items:
cpp_header.classes[key] = value
# Find callback data structure.
for struct_name, struct in cpp_header.classes.items():
if re.match(r'^' + api_prefix + r'_api_data\w+$', struct_name) is not None:
break
# Process callback data structure.
begin_erts, end_erts = _erts_from_cb_data_struct(api_prefix,
cpp_header,
retval_info,
struct)
# API functions without parameters are not part of the callback data
# structure, but they have an ID in the `*_api_id_t` enumeration.
#
# Add missing event record types to `begin_erts` and `end_erts`
# considering the `*_api_id_t` enumeration.
processed_api_func_names = set([ert.api_func_name for ert in begin_erts])
for func_name in _api_func_names(api_prefix, cpp_header):
if func_name not in processed_api_func_names:
begin_erts.append(_BeginErt(func_name, []))
end_erts.append(_EndErt(func_name, []))
# Write barectf YAML file.
with open(f'{api_prefix}_erts.yaml', 'w') as f:
f.write(_yaml_dst_from_erts(api_prefix, begin_erts + end_erts))
# Write C++ code (beginning event record).
with open(f'{api_prefix}_begin.cpp.i', 'w') as f:
f.write('\n'.join(_cpp_switch_statement_from_erts(api_prefix,
begin_erts)))
# Write C++ code (end event record).
with open(f'{api_prefix}_end.cpp.i', 'w') as f:
f.write('\n'.join(_cpp_switch_statement_from_erts(api_prefix,
end_erts)))
if __name__ == "__main__":
# Disable `CppHeaderParser` printing to standard output.
CppHeaderParser.CppHeaderParser.print_warnings = 0
CppHeaderParser.CppHeaderParser.print_errors = 0
CppHeaderParser.CppHeaderParser.debug = 0
CppHeaderParser.CppHeaderParser.debug_trace = 0
# Process the complete API header file.
_process_file(sys.argv[1], sys.argv[2])