Added support for Navi ATT

Change-Id: If65afd850b1a63fdda6382133c6269c8d17bfb4d


[ROCm/rocprofiler commit: a6a61c5f51]
This commit is contained in:
Giovanni LB
2023-05-02 05:17:47 -03:00
zatwierdzone przez Giovanni Baraldi
rodzic 1d95e00954
commit 0d25fd5727
2 zmienionych plików z 140 dodań i 72 usunięć
+38 -29
Wyświetl plik
@@ -117,22 +117,20 @@ class Wave(ctypes.Structure):
('timeline_string', ctypes.c_char_p),
('instructions_string', ctypes.c_char_p)]
# Flags :
# IS_NAVI = 0x1
class ReturnInfo(ctypes.Structure):
_fields_ = [('num_waves', ctypes.c_uint64),
('wavedata', POINTER(Wave)),
('num_events', ctypes.c_uint64),
('perfevents', POINTER(PerfEvent)),
('occupancy', POINTER(ctypes.c_uint64)),
('num_occupancy', ctypes.c_uint64)]
('num_occupancy', ctypes.c_uint64),
('flags', ctypes.c_uint64)]
rocprofv2_att_lib = os.getenv('ROCPROFV2_ATT_LIB_PATH')
try: # For build dir
path_to_parser = os.path.abspath(rocprofv2_att_lib)
SO = CDLL(path_to_parser)
except: # For installed dir
path_to_parser = os.path.abspath('/usr/lib/hsa-amd-aqlprofile/librocprofv2_att.so')
SO = CDLL(path_to_parser)
path_to_parser = os.path.abspath(rocprofv2_att_lib)
SO = CDLL(path_to_parser)
SO.AnalyseBinary.restype = ReturnInfo
SO.AnalyseBinary.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_bool]
@@ -182,18 +180,11 @@ def getWaves(filename, target_cu, verbose):
events = [deepcopy(info.perfevents[k]) for k in range(info.num_events)]
occupancy = [int(info.occupancy[k]) for k in range(int(info.num_occupancy))]
'''occupancy = np.asarray([f for f in occupancy if (f&0xFF) == 3])
print(occupancy.size, occupancy.dtype)
token_time = occupancy >> 16
value = (occupancy >> 8) & 0xFF
plt.plot(token_time, value); plt.show()
quit()'''
for wave in waves:
wave.timeline = deepcopy(wave.timeline_string.decode("utf-8"))
wave.instructions = deepcopy(wave.instructions_string.decode("utf-8"))
return waves, events, occupancy
return waves, events, occupancy, 'navi' if (info.flags & 0x1) else 'vega'
def persist(trace_file, SIMD):
@@ -328,9 +319,6 @@ def draw_wave_metrics(selections, normalize):
delta_step = 8
quad_delta_time = max(delta_step,int(0.5+np.min([get_delta_time(events) for events in EVENTS])))
maxtime = np.max([np.max([e.time for e in events]) for events in EVENTS])/quad_delta_time+1
event_timeline = np.zeros((16, maxtime), dtype=np.int32)
print('Delta:', quad_delta_time)
print('Max_cycles:', maxtime)
if maxtime*delta_step >= COUNTERS_MAX_CAPTURES:
delta_step = 1
@@ -468,7 +456,8 @@ if __name__ == "__main__":
for line in lines:
if 'PERFCOUNTER=' in line:
EVENT_NAMES += [clean(line).split('SQ_')[1].lower()]
if len(EVENT_NAMES) == 0:
EVENT_NAMES = ['SPI', 'Vdata', 'Sdata', 'LDS']
if args.target_cu is None:
args.target_cu = 1
@@ -515,13 +504,14 @@ if __name__ == "__main__":
DBFILES = []
global TIMELINES
global EVENTS
TIMELINES = [np.zeros(int(1E4),dtype=np.int32) for k in range(5)]
TIMELINES = [np.zeros(int(1E4),dtype=np.int16) for k in range(5)]
EVENTS = []
OCCUPANCY = []
analysed_filenames = []
SIMD_list = []
for name in filenames:
SIMD, perfevents, occupancy = getWaves(name, args.target_cu, False)
SIMD, perfevents, occupancy, gfxv = getWaves(name, args.target_cu, False)
if len(SIMD) == 0:
print("Error parsing ", name)
continue
@@ -529,17 +519,36 @@ if __name__ == "__main__":
EVENTS.append(perfevents)
DBFILES.append( persist(name, SIMD) )
OCCUPANCY.append( occupancy )
for wave in SIMD:
SIMD_list.append( SIMD )
min_event_time = 2**62
for df in DBFILES:
if len(df['begin_time']) > 0:
min_event_time = min(min_event_time, np.min(df['begin_time']))
for perf in EVENTS:
for p in perf:
min_event_time = min(min_event_time, p.time)
for occ in OCCUPANCY:
min_event_time = min(min_event_time, np.min(np.array(occ)>>16))
print("Min time:", min_event_time)
for perf in EVENTS:
for p in perf:
p.time -= min_event_time
OCCUPANCY = [[max(min(int((u>>16)-min_event_time)<<16,2**42),0) | (u&0xFFFFF) for u in occ] for occ in OCCUPANCY]
for df in DBFILES:
for T in range(len(df['timeline'])):
timeline = df['timeline'][T]
time_acc = 0
tuples1 = wave.timeline.split('(')
tuples1 = timeline.split('(')
tuples2 = [t.split(')')[0].split(',') for t in tuples1 if t != '']
tuples3 = [(int(t[0]),int(t[1])) for t in tuples2]
tuples3 = [(0,df['begin_time'][T]-min_event_time)]+[(int(t[0]),int(t[1])) for t in tuples2]
for state in tuples3:
if state[1] > 1E7:
if state[1] > 50E6:
print('Warning: Time limit reached for ',state[0], state[1])
break
if time_acc+state[1] > TIMELINES[state[0]].size:
TIMELINES[state[0]] = np.hstack([
TIMELINES[state[0]],
@@ -549,7 +558,7 @@ if __name__ == "__main__":
time_acc += state[1]
if args.genasm and len(args.genasm) > 0:
flight_count = view_trace(args, code, jumps, DBFILES, analysed_filenames, True, None, OCCUPANCY, args.dumpfiles)
flight_count = view_trace(args, code, jumps, DBFILES, analysed_filenames, True, None, OCCUPANCY, args.dumpfiles, min_event_time, gfxv)
with open(args.assembly_code, 'r') as file:
lines = file.readlines()
@@ -561,4 +570,4 @@ if __name__ == "__main__":
for k in keys:
file.write(assembly_code[k]+'\n')
else:
view_trace(args, code, jumps, DBFILES, analysed_filenames, False, GeneratePIC, OCCUPANCY, args.dumpfiles)
view_trace(args, code, jumps, DBFILES, analysed_filenames, False, GeneratePIC, OCCUPANCY, args.dumpfiles, min_event_time, gfxv)
@@ -3,6 +3,7 @@ import sys
if sys.version_info[0] < 3:
raise Exception("Must be using Python 3")
import os
import sys
import time
@@ -23,6 +24,7 @@ from copy import deepcopy
from http import HTTPStatus
from io import BytesIO
class Readable:
def __init__(self, jsonstring):
self.jsonstr = json.dumps(jsonstring)
@@ -42,10 +44,12 @@ class Readable:
def __len__(self):
return len(self.jsonstr)
MAX_STITCHED_TOKENS = 3000000
MAX_FAILED_STITCHES = 256
STACK_SIZE_LIMIT = 64
UNKNOWN = 0
SMEM = 1
SALU = 2
VMEM = 3
@@ -63,6 +67,7 @@ LANEIO = 14
DONT_KNOW = 100
WaveInstCategory = {
UNKNOWN: "UNKNOWN",
SMEM: "SMEM",
SALU: "SALU",
VMEM: "VMEM",
@@ -85,6 +90,7 @@ WaveInstCategory = {
JSON_GLOBAL_DICTIONARY = {}
class RegisterWatchList:
def __init__(self, labels):
self.registers = {'v'+str(k): [[] for m in range(64)] for k in range(64)}
@@ -154,15 +160,17 @@ class RegisterWatchList:
def try_match_swapped(insts, code, i, line):
return insts[i+1][1] == code[line][1] and insts[i][1] == code[line+1][1]
def Match(inst_value, code_value):
if code_value == inst_value:
return True
if code_value in [GETPC, SWAPPC, SETPC] and inst_value==SALU:
if code_value in [GETPC, SWAPPC, SETPC] and inst_value in [SALU, JUMP]:
return True
if code_value == BRANCH and inst_value in [JUMP, NEXT]: # TODO: Maybe lets not reorder branches?
return True
return False
def get_match_lookahead(insts, code, i, line):
if try_match_swapped(insts, code, i, line):
return [i+1, i]
@@ -183,15 +191,20 @@ def get_match_lookahead(insts, code, i, line):
new_inst_order += [j for j in list(range(i, max(new_inst_order)+1)) if j not in new_inst_order]
return new_inst_order
def stitch(insts, raw_code, jumps):
def stitch(insts, raw_code, jumps, gfxv):
bGFX9 = gfxv == 'vega'
result, i, line, loopCount, N = [], 0, 0, defaultdict(int), len(insts)
SMEM_INST = []
VMEM_INST = []
SMEM_INST = [] # scalar memory
VLMEM_INST = [] # vector memory load
VSMEM_INST = [] # vector memory store
FLAT_INST = []
NUM_SMEM = 0
NUM_VMEM = 0
NUM_VLMEM = 0
NUM_VSMEM = 0
NUM_FLAT = 0
skipped_immed = 0
mem_unroll = []
flight_count = []
@@ -217,7 +230,8 @@ def stitch(insts, raw_code, jumps):
jumps = {jump_map[j]+1: j for j in jumps}
smem_ordering = 0
vmem_ordering = 0
vlmem_ordering = 0
vsmem_ordering = 0
max_line = 0
watchlist = RegisterWatchList(labels=labels)
@@ -242,39 +256,48 @@ def stitch(insts, raw_code, jumps):
next = line+1
if as_line[1] == GETPC: # TODO: @ can put you ahead of label!
watchlist.getpc(as_line[0], code[line+1][0])
matched = inst[1] == SALU
matched = inst[1] in [SALU, JUMP]
elif as_line[1] == LANEIO:
watchlist.updatelane(as_line[0])
matched = inst[1] == VALU
elif as_line[1] == SETPC:
next = watchlist.setpc(as_line[0])
matched = inst[1] == SALU
matched = inst[1] in [SALU, JUMP]
elif as_line[1] == SWAPPC:
next = watchlist.swappc(as_line[0], line)
#print('Next:', next, code[next])
matched = inst[1] == SALU
matched = inst[1] in [SALU, JUMP]
elif inst[1] == as_line[1]:
if line in jumps:
loopCount[jumps[line]-1] += 1 # label is the previous line
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VMEM
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
if inst[1] == SMEM or inst[1] == LDS:
smem_ordering = 1 if inst[1] == SMEM else smem_ordering
SMEM_INST.append([reverse_map[line], num_inflight])
NUM_SMEM += 1
elif inst[1] == VMEM or (inst[1] == FLAT and 'global_' in as_line[0]):
VMEM_INST.append([reverse_map[line], num_inflight])
NUM_VMEM += 1
if 'buffer_' in as_line[0]:
#watchlist.LDS_buffer_op(as_line[0])
vmem_ordering = 1
inc_ordering = False
if 'buffer_' in as_line[0] or 'flat_' in as_line[0]:
inc_ordering = True
if bGFX9 or 'load' in as_line[0]:
VLMEM_INST.append([reverse_map[line], num_inflight])
NUM_VLMEM += 1
if inc_ordering:
vlmem_ordering = 1
else:
VSMEM_INST.append([reverse_map[line], num_inflight])
NUM_VSMEM += 1
if inc_ordering:
vsmem_ordering = 1
elif inst[1] == FLAT:
smem_ordering = 1
vmem_ordering = 1
vlmem_ordering = 1
vsmem_ordering = 1
FLAT_INST.append([reverse_map[line], num_inflight])
NUM_FLAT += 1
elif inst[1] == IMMED and 'waitcnt' in as_line[0]:
elif inst[1] == IMMED and 's_waitcnt ' in as_line[0]:
if 'lgkmcnt' in as_line[0]:
wait_N = int(as_line[0].split('lgkmcnt(')[1].split(')')[0])
flight_count.append([as_line[-1], num_inflight, wait_N])
@@ -290,23 +313,44 @@ def stitch(insts, raw_code, jumps):
else:
NUM_SMEM = min(max(wait_N-NUM_FLAT, 0), NUM_SMEM)
NUM_FLAT = min(max(wait_N-NUM_SMEM, 0), NUM_FLAT)
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VMEM
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
if 'vmcnt' in as_line[0]:
wait_N = int(as_line[0].split('vmcnt(')[1].split(')')[0])
flight_count.append([as_line[-1], num_inflight, wait_N])
if wait_N == 0:
vmem_ordering = 0
if vmem_ordering == 0:
offset = len(VMEM_INST)-wait_N
mem_unroll.append( [reverse_map[line], VMEM_INST[:offset]+FLAT_INST] )
VMEM_INST = VMEM_INST[offset:]
NUM_VMEM = len(VMEM_INST)
vlmem_ordering = 0
if vlmem_ordering == 0:
offset = len(VLMEM_INST)-wait_N
mem_unroll.append( [reverse_map[line], VLMEM_INST[:offset]+FLAT_INST] )
VLMEM_INST = VLMEM_INST[offset:]
NUM_VLMEM = len(VLMEM_INST)
FLAT_INST = []
NUM_FLAT = 0
else:
NUM_VMEM = min(max(wait_N-NUM_FLAT, 0), NUM_VMEM)
NUM_FLAT = min(max(wait_N-NUM_VMEM, 0), NUM_FLAT)
NUM_VLMEM = min(max(wait_N-NUM_FLAT, 0), NUM_VLMEM)
NUM_FLAT = min(max(wait_N-NUM_VLMEM, 0), NUM_FLAT)
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
if 'vscnt' in as_line[0] or (bGFX9 and 'vmcnt' in as_line[0]):
try:
wait_N = int(as_line[0].split('vscnt(')[1].split(')')[0])
except:
wait_N = int(as_line[0].split('vmcnt(')[1].split(')')[0])
flight_count.append([as_line[-1], num_inflight, wait_N])
if wait_N == 0:
vsmem_ordering = 0
if vsmem_ordering == 0:
offset = len(VSMEM_INST)-wait_N
mem_unroll.append( [reverse_map[line], VSMEM_INST[:offset]+FLAT_INST] )
VSMEM_INST = VSMEM_INST[offset:]
NUM_VSMEM = len(VSMEM_INST)
FLAT_INST = []
NUM_FLAT = 0
else:
NUM_VSMEM = min(max(wait_N-NUM_FLAT, 0), NUM_VSMEM)
NUM_FLAT = min(max(wait_N-NUM_VSMEM, 0), NUM_FLAT)
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
elif inst[1] == JUMP and as_line[1] == BRANCH:
next = jump_map[as_line[2]]
@@ -319,20 +363,32 @@ def stitch(insts, raw_code, jumps):
matched = False
next = line + 1
if i+1 < N and line+1 < len(code):
#print('Swap:', try_match_swapped(insts, code, i, line))
if try_match_swapped(insts, code, i, line):
temp = insts[i]
insts[i] = insts[i+1]
insts[i+1] = temp
next = line
elif 's_waitcnt' in as_line[0] or '_load_' in as_line[0]:
print(as_line)
break
elif 's_waitcnt ' in as_line[0] or '_load_' in as_line[0]:
if skipped_immed > 0 and 's_waitcnt ' in as_line[0]:
matched = True
skipped_immed -= 1
else:
print('Parsing terminated at:', as_line)
break
#print(matched, WaveInstCategory[inst[1]], WaveInstCategory[as_line[1]], as_line, inst)
#print([WaveInstCategory[insts[i+k][1]] for k in range(20) if i+k < len(insts)])
if matched:
new_res = inst + (reverse_map[line],) # (line,)
result.append(new_res)
result.append(inst + (reverse_map[line],))
i += 1
num_failed_stitches = 0
elif inst[1] == IMMED and line != next:
skipped_immed += 1
result.append(inst + (reverse_map[line],))
next = line
i += 1
else:
num_failed_stitches += 1
line = next
@@ -340,7 +396,7 @@ def stitch(insts, raw_code, jumps):
N = max(N, 1)
if len(result) != N:
print('Warning - Stitching rate: '+str(len(result) * 100 / N)+'% matched')
print('Leftovers:', [WaveInstCategory[insts[i+k][1]] for k in range(5) if i+k < len(insts)])
print('Leftovers:', [WaveInstCategory[insts[i+k][1]] for k in range(20) if i+k < len(insts)])
try:
print(line, code[line])
except:
@@ -348,7 +404,7 @@ def stitch(insts, raw_code, jumps):
else:
while line < len(code):
if 's_endpgm' in code[line]:
mem_unroll.append( [reverse_map[line], SMEM_INST+VMEM_INST+FLAT_INST] )
mem_unroll.append( [reverse_map[line], SMEM_INST+VLMEM_INST+VSMEM_INST+FLAT_INST] )
break
line += 1
@@ -372,6 +428,7 @@ IPAddr = get_ip()
PORT, WebSocketPort = 8000, 18000
SP = '\u00A0'
def extract_tuple(content, num):
vals = content.split(',')
assert (len(vals) == num)
@@ -443,7 +500,7 @@ def extract_waves(waves):
return result
def extract_data(df, se_number, code, jumps):
def extract_data(df, se_number, code, jumps, gfxv):
if len(df['id']) == 0 or len(df['instructions']) == 0 or len(df['timeline']) == 0:
return None
@@ -467,7 +524,7 @@ def extract_data(df, se_number, code, jumps):
for x in df['timeline'][wave_id].split('),'):
timeline.append(extract_tuple(x, 2))
stitched, loopCount, mem_unroll, count, maxline = stitch(insts, code, jumps)
stitched, loopCount, mem_unroll, count, maxline = stitch(insts, code, jumps, gfxv)
srate = len(stitched)**2 / max(len(insts), 1)
if srate <= maxgrade[df['simd'][wave_id]][df['wave_slot'][wave_id]]:
continue
@@ -527,7 +584,7 @@ class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
counters_json, imagebytes, _, _ = PICTURE_CALLBACK(selections[1:], selections[0])
JSON_GLOBAL_DICTIONARY['counters.json'] = counters_json
JSON_GLOBAL_DICTIONARY[self.path.split('/')[-1]] = imagebytes
if '.json' in self.path or 'timeline.png' in self.path or 'wstates' in self.path:
try:
response_file = JSON_GLOBAL_DICTIONARY[self.path.split('/')[-1]]
@@ -535,7 +592,6 @@ class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
except:
print('Invalid json request:', self.path)
self.send_error(HTTPStatus.NOT_FOUND, "File not found")
print(JSON_GLOBAL_DICTIONARY.keys())
return
self.send_response(HTTPStatus.OK)
self.send_header("Content-Length", str(len(response_file)))
@@ -627,7 +683,8 @@ def call_picture_callback(return_dict):
for n, e in enumerate(counter_events):
return_dict['se'+str(n)+'_perfcounter.json'] = Readable({"data": [v.toTuple() for v in e]})
def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callback, OCCUPANCY, bDumpOnly):
def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callback, OCCUPANCY, bDumpOnly, se_time_begin, gfxv):
global PICTURE_CALLBACK
PICTURE_CALLBACK = pic_callback
manager = Manager()
@@ -647,7 +704,7 @@ def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callba
if len(dbname['id']) == 0:
continue
count, wv_filenames = extract_data(dbname, se_number, code, jumps)
count, wv_filenames = extract_data(dbname, se_number, code, jumps, gfxv)
if count is not None:
flight_count.append(count)
@@ -659,7 +716,7 @@ def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callba
for key in simd_wave_filenames.keys():
wv_array = [[
int(s.split('_sm')[1].split('_wv')[0]),
int(s.split('_wv')[1][0]),
int(s.split('_wv')[1].split('.')[0]),
s
] for s in simd_wave_filenames[key]]
@@ -675,7 +732,9 @@ def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callba
simd_wave_filenames[key] = wv_dict
JSON_GLOBAL_DICTIONARY['filenames.json'] = Readable({"filenames": simd_wave_filenames})
JSON_GLOBAL_DICTIONARY['filenames.json'] = Readable({"filenames": simd_wave_filenames,
"global_begin_time": int(se_time_begin),
"gfxv": gfxv})
if pic_thread is not None:
pic_thread.join()