diff --git a/projects/rocprofiler/plugin/att/att.py b/projects/rocprofiler/plugin/att/att.py index 78123a3f5a..901dd83d41 100755 --- a/projects/rocprofiler/plugin/att/att.py +++ b/projects/rocprofiler/plugin/att/att.py @@ -117,22 +117,20 @@ class Wave(ctypes.Structure): ('timeline_string', ctypes.c_char_p), ('instructions_string', ctypes.c_char_p)] - +# Flags : +# IS_NAVI = 0x1 class ReturnInfo(ctypes.Structure): _fields_ = [('num_waves', ctypes.c_uint64), ('wavedata', POINTER(Wave)), ('num_events', ctypes.c_uint64), ('perfevents', POINTER(PerfEvent)), ('occupancy', POINTER(ctypes.c_uint64)), - ('num_occupancy', ctypes.c_uint64)] + ('num_occupancy', ctypes.c_uint64), + ('flags', ctypes.c_uint64)] rocprofv2_att_lib = os.getenv('ROCPROFV2_ATT_LIB_PATH') -try: # For build dir - path_to_parser = os.path.abspath(rocprofv2_att_lib) - SO = CDLL(path_to_parser) -except: # For installed dir - path_to_parser = os.path.abspath('/usr/lib/hsa-amd-aqlprofile/librocprofv2_att.so') - SO = CDLL(path_to_parser) +path_to_parser = os.path.abspath(rocprofv2_att_lib) +SO = CDLL(path_to_parser) SO.AnalyseBinary.restype = ReturnInfo SO.AnalyseBinary.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_bool] @@ -182,18 +180,11 @@ def getWaves(filename, target_cu, verbose): events = [deepcopy(info.perfevents[k]) for k in range(info.num_events)] occupancy = [int(info.occupancy[k]) for k in range(int(info.num_occupancy))] - '''occupancy = np.asarray([f for f in occupancy if (f&0xFF) == 3]) - print(occupancy.size, occupancy.dtype) - token_time = occupancy >> 16 - value = (occupancy >> 8) & 0xFF - plt.plot(token_time, value); plt.show() - quit()''' - for wave in waves: wave.timeline = deepcopy(wave.timeline_string.decode("utf-8")) wave.instructions = deepcopy(wave.instructions_string.decode("utf-8")) - return waves, events, occupancy + return waves, events, occupancy, 'navi' if (info.flags & 0x1) else 'vega' def persist(trace_file, SIMD): @@ -328,9 +319,6 @@ def draw_wave_metrics(selections, normalize): delta_step = 8 quad_delta_time = max(delta_step,int(0.5+np.min([get_delta_time(events) for events in EVENTS]))) maxtime = np.max([np.max([e.time for e in events]) for events in EVENTS])/quad_delta_time+1 - event_timeline = np.zeros((16, maxtime), dtype=np.int32) - print('Delta:', quad_delta_time) - print('Max_cycles:', maxtime) if maxtime*delta_step >= COUNTERS_MAX_CAPTURES: delta_step = 1 @@ -468,7 +456,8 @@ if __name__ == "__main__": for line in lines: if 'PERFCOUNTER=' in line: EVENT_NAMES += [clean(line).split('SQ_')[1].lower()] - + if len(EVENT_NAMES) == 0: + EVENT_NAMES = ['SPI', 'Vdata', 'Sdata', 'LDS'] if args.target_cu is None: args.target_cu = 1 @@ -515,13 +504,14 @@ if __name__ == "__main__": DBFILES = [] global TIMELINES global EVENTS - TIMELINES = [np.zeros(int(1E4),dtype=np.int32) for k in range(5)] + TIMELINES = [np.zeros(int(1E4),dtype=np.int16) for k in range(5)] EVENTS = [] OCCUPANCY = [] analysed_filenames = [] + SIMD_list = [] for name in filenames: - SIMD, perfevents, occupancy = getWaves(name, args.target_cu, False) + SIMD, perfevents, occupancy, gfxv = getWaves(name, args.target_cu, False) if len(SIMD) == 0: print("Error parsing ", name) continue @@ -529,17 +519,36 @@ if __name__ == "__main__": EVENTS.append(perfevents) DBFILES.append( persist(name, SIMD) ) OCCUPANCY.append( occupancy ) - for wave in SIMD: + SIMD_list.append( SIMD ) + + min_event_time = 2**62 + for df in DBFILES: + if len(df['begin_time']) > 0: + min_event_time = min(min_event_time, np.min(df['begin_time'])) + for perf in EVENTS: + for p in perf: + min_event_time = min(min_event_time, p.time) + for occ in OCCUPANCY: + min_event_time = min(min_event_time, np.min(np.array(occ)>>16)) + print("Min time:", min_event_time) + for perf in EVENTS: + for p in perf: + p.time -= min_event_time + + OCCUPANCY = [[max(min(int((u>>16)-min_event_time)<<16,2**42),0) | (u&0xFFFFF) for u in occ] for occ in OCCUPANCY] + + for df in DBFILES: + for T in range(len(df['timeline'])): + timeline = df['timeline'][T] time_acc = 0 - tuples1 = wave.timeline.split('(') + tuples1 = timeline.split('(') tuples2 = [t.split(')')[0].split(',') for t in tuples1 if t != ''] - tuples3 = [(int(t[0]),int(t[1])) for t in tuples2] + tuples3 = [(0,df['begin_time'][T]-min_event_time)]+[(int(t[0]),int(t[1])) for t in tuples2] for state in tuples3: - if state[1] > 1E7: + if state[1] > 50E6: print('Warning: Time limit reached for ',state[0], state[1]) break - if time_acc+state[1] > TIMELINES[state[0]].size: TIMELINES[state[0]] = np.hstack([ TIMELINES[state[0]], @@ -549,7 +558,7 @@ if __name__ == "__main__": time_acc += state[1] if args.genasm and len(args.genasm) > 0: - flight_count = view_trace(args, code, jumps, DBFILES, analysed_filenames, True, None, OCCUPANCY, args.dumpfiles) + flight_count = view_trace(args, code, jumps, DBFILES, analysed_filenames, True, None, OCCUPANCY, args.dumpfiles, min_event_time, gfxv) with open(args.assembly_code, 'r') as file: lines = file.readlines() @@ -561,4 +570,4 @@ if __name__ == "__main__": for k in keys: file.write(assembly_code[k]+'\n') else: - view_trace(args, code, jumps, DBFILES, analysed_filenames, False, GeneratePIC, OCCUPANCY, args.dumpfiles) + view_trace(args, code, jumps, DBFILES, analysed_filenames, False, GeneratePIC, OCCUPANCY, args.dumpfiles, min_event_time, gfxv) diff --git a/projects/rocprofiler/plugin/att/trace_view.py b/projects/rocprofiler/plugin/att/trace_view.py index 393de612bd..3482b062a5 100755 --- a/projects/rocprofiler/plugin/att/trace_view.py +++ b/projects/rocprofiler/plugin/att/trace_view.py @@ -3,6 +3,7 @@ import sys if sys.version_info[0] < 3: raise Exception("Must be using Python 3") + import os import sys import time @@ -23,6 +24,7 @@ from copy import deepcopy from http import HTTPStatus from io import BytesIO + class Readable: def __init__(self, jsonstring): self.jsonstr = json.dumps(jsonstring) @@ -42,10 +44,12 @@ class Readable: def __len__(self): return len(self.jsonstr) + MAX_STITCHED_TOKENS = 3000000 MAX_FAILED_STITCHES = 256 STACK_SIZE_LIMIT = 64 +UNKNOWN = 0 SMEM = 1 SALU = 2 VMEM = 3 @@ -63,6 +67,7 @@ LANEIO = 14 DONT_KNOW = 100 WaveInstCategory = { + UNKNOWN: "UNKNOWN", SMEM: "SMEM", SALU: "SALU", VMEM: "VMEM", @@ -85,6 +90,7 @@ WaveInstCategory = { JSON_GLOBAL_DICTIONARY = {} + class RegisterWatchList: def __init__(self, labels): self.registers = {'v'+str(k): [[] for m in range(64)] for k in range(64)} @@ -154,15 +160,17 @@ class RegisterWatchList: def try_match_swapped(insts, code, i, line): return insts[i+1][1] == code[line][1] and insts[i][1] == code[line+1][1] + def Match(inst_value, code_value): if code_value == inst_value: return True - if code_value in [GETPC, SWAPPC, SETPC] and inst_value==SALU: + if code_value in [GETPC, SWAPPC, SETPC] and inst_value in [SALU, JUMP]: return True if code_value == BRANCH and inst_value in [JUMP, NEXT]: # TODO: Maybe lets not reorder branches? return True return False + def get_match_lookahead(insts, code, i, line): if try_match_swapped(insts, code, i, line): return [i+1, i] @@ -183,15 +191,20 @@ def get_match_lookahead(insts, code, i, line): new_inst_order += [j for j in list(range(i, max(new_inst_order)+1)) if j not in new_inst_order] return new_inst_order -def stitch(insts, raw_code, jumps): + +def stitch(insts, raw_code, jumps, gfxv): + bGFX9 = gfxv == 'vega' result, i, line, loopCount, N = [], 0, 0, defaultdict(int), len(insts) - - SMEM_INST = [] - VMEM_INST = [] + + SMEM_INST = [] # scalar memory + VLMEM_INST = [] # vector memory load + VSMEM_INST = [] # vector memory store FLAT_INST = [] NUM_SMEM = 0 - NUM_VMEM = 0 + NUM_VLMEM = 0 + NUM_VSMEM = 0 NUM_FLAT = 0 + skipped_immed = 0 mem_unroll = [] flight_count = [] @@ -217,7 +230,8 @@ def stitch(insts, raw_code, jumps): jumps = {jump_map[j]+1: j for j in jumps} smem_ordering = 0 - vmem_ordering = 0 + vlmem_ordering = 0 + vsmem_ordering = 0 max_line = 0 watchlist = RegisterWatchList(labels=labels) @@ -242,39 +256,48 @@ def stitch(insts, raw_code, jumps): next = line+1 if as_line[1] == GETPC: # TODO: @ can put you ahead of label! watchlist.getpc(as_line[0], code[line+1][0]) - matched = inst[1] == SALU + matched = inst[1] in [SALU, JUMP] elif as_line[1] == LANEIO: watchlist.updatelane(as_line[0]) matched = inst[1] == VALU elif as_line[1] == SETPC: next = watchlist.setpc(as_line[0]) - matched = inst[1] == SALU + matched = inst[1] in [SALU, JUMP] elif as_line[1] == SWAPPC: next = watchlist.swappc(as_line[0], line) #print('Next:', next, code[next]) - matched = inst[1] == SALU + matched = inst[1] in [SALU, JUMP] elif inst[1] == as_line[1]: if line in jumps: loopCount[jumps[line]-1] += 1 # label is the previous line - num_inflight = NUM_FLAT + NUM_SMEM + NUM_VMEM + num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM if inst[1] == SMEM or inst[1] == LDS: smem_ordering = 1 if inst[1] == SMEM else smem_ordering SMEM_INST.append([reverse_map[line], num_inflight]) NUM_SMEM += 1 elif inst[1] == VMEM or (inst[1] == FLAT and 'global_' in as_line[0]): - VMEM_INST.append([reverse_map[line], num_inflight]) - NUM_VMEM += 1 - if 'buffer_' in as_line[0]: - #watchlist.LDS_buffer_op(as_line[0]) - vmem_ordering = 1 + inc_ordering = False + if 'buffer_' in as_line[0] or 'flat_' in as_line[0]: + inc_ordering = True + + if bGFX9 or 'load' in as_line[0]: + VLMEM_INST.append([reverse_map[line], num_inflight]) + NUM_VLMEM += 1 + if inc_ordering: + vlmem_ordering = 1 + else: + VSMEM_INST.append([reverse_map[line], num_inflight]) + NUM_VSMEM += 1 + if inc_ordering: + vsmem_ordering = 1 elif inst[1] == FLAT: smem_ordering = 1 - vmem_ordering = 1 + vlmem_ordering = 1 + vsmem_ordering = 1 FLAT_INST.append([reverse_map[line], num_inflight]) NUM_FLAT += 1 - elif inst[1] == IMMED and 'waitcnt' in as_line[0]: - + elif inst[1] == IMMED and 's_waitcnt ' in as_line[0]: if 'lgkmcnt' in as_line[0]: wait_N = int(as_line[0].split('lgkmcnt(')[1].split(')')[0]) flight_count.append([as_line[-1], num_inflight, wait_N]) @@ -290,23 +313,44 @@ def stitch(insts, raw_code, jumps): else: NUM_SMEM = min(max(wait_N-NUM_FLAT, 0), NUM_SMEM) NUM_FLAT = min(max(wait_N-NUM_SMEM, 0), NUM_FLAT) - num_inflight = NUM_FLAT + NUM_SMEM + NUM_VMEM + num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM if 'vmcnt' in as_line[0]: wait_N = int(as_line[0].split('vmcnt(')[1].split(')')[0]) flight_count.append([as_line[-1], num_inflight, wait_N]) if wait_N == 0: - vmem_ordering = 0 - if vmem_ordering == 0: - offset = len(VMEM_INST)-wait_N - mem_unroll.append( [reverse_map[line], VMEM_INST[:offset]+FLAT_INST] ) - VMEM_INST = VMEM_INST[offset:] - NUM_VMEM = len(VMEM_INST) + vlmem_ordering = 0 + if vlmem_ordering == 0: + offset = len(VLMEM_INST)-wait_N + mem_unroll.append( [reverse_map[line], VLMEM_INST[:offset]+FLAT_INST] ) + VLMEM_INST = VLMEM_INST[offset:] + NUM_VLMEM = len(VLMEM_INST) FLAT_INST = [] NUM_FLAT = 0 else: - NUM_VMEM = min(max(wait_N-NUM_FLAT, 0), NUM_VMEM) - NUM_FLAT = min(max(wait_N-NUM_VMEM, 0), NUM_FLAT) + NUM_VLMEM = min(max(wait_N-NUM_FLAT, 0), NUM_VLMEM) + NUM_FLAT = min(max(wait_N-NUM_VLMEM, 0), NUM_FLAT) + num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM + + if 'vscnt' in as_line[0] or (bGFX9 and 'vmcnt' in as_line[0]): + try: + wait_N = int(as_line[0].split('vscnt(')[1].split(')')[0]) + except: + wait_N = int(as_line[0].split('vmcnt(')[1].split(')')[0]) + flight_count.append([as_line[-1], num_inflight, wait_N]) + if wait_N == 0: + vsmem_ordering = 0 + if vsmem_ordering == 0: + offset = len(VSMEM_INST)-wait_N + mem_unroll.append( [reverse_map[line], VSMEM_INST[:offset]+FLAT_INST] ) + VSMEM_INST = VSMEM_INST[offset:] + NUM_VSMEM = len(VSMEM_INST) + FLAT_INST = [] + NUM_FLAT = 0 + else: + NUM_VSMEM = min(max(wait_N-NUM_FLAT, 0), NUM_VSMEM) + NUM_FLAT = min(max(wait_N-NUM_VSMEM, 0), NUM_FLAT) + num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM elif inst[1] == JUMP and as_line[1] == BRANCH: next = jump_map[as_line[2]] @@ -319,20 +363,32 @@ def stitch(insts, raw_code, jumps): matched = False next = line + 1 if i+1 < N and line+1 < len(code): + #print('Swap:', try_match_swapped(insts, code, i, line)) if try_match_swapped(insts, code, i, line): temp = insts[i] insts[i] = insts[i+1] insts[i+1] = temp next = line - elif 's_waitcnt' in as_line[0] or '_load_' in as_line[0]: - print(as_line) - break + elif 's_waitcnt ' in as_line[0] or '_load_' in as_line[0]: + if skipped_immed > 0 and 's_waitcnt ' in as_line[0]: + matched = True + skipped_immed -= 1 + else: + print('Parsing terminated at:', as_line) + break + + #print(matched, WaveInstCategory[inst[1]], WaveInstCategory[as_line[1]], as_line, inst) + #print([WaveInstCategory[insts[i+k][1]] for k in range(20) if i+k < len(insts)]) if matched: - new_res = inst + (reverse_map[line],) # (line,) - result.append(new_res) + result.append(inst + (reverse_map[line],)) i += 1 num_failed_stitches = 0 + elif inst[1] == IMMED and line != next: + skipped_immed += 1 + result.append(inst + (reverse_map[line],)) + next = line + i += 1 else: num_failed_stitches += 1 line = next @@ -340,7 +396,7 @@ def stitch(insts, raw_code, jumps): N = max(N, 1) if len(result) != N: print('Warning - Stitching rate: '+str(len(result) * 100 / N)+'% matched') - print('Leftovers:', [WaveInstCategory[insts[i+k][1]] for k in range(5) if i+k < len(insts)]) + print('Leftovers:', [WaveInstCategory[insts[i+k][1]] for k in range(20) if i+k < len(insts)]) try: print(line, code[line]) except: @@ -348,7 +404,7 @@ def stitch(insts, raw_code, jumps): else: while line < len(code): if 's_endpgm' in code[line]: - mem_unroll.append( [reverse_map[line], SMEM_INST+VMEM_INST+FLAT_INST] ) + mem_unroll.append( [reverse_map[line], SMEM_INST+VLMEM_INST+VSMEM_INST+FLAT_INST] ) break line += 1 @@ -372,6 +428,7 @@ IPAddr = get_ip() PORT, WebSocketPort = 8000, 18000 SP = '\u00A0' + def extract_tuple(content, num): vals = content.split(',') assert (len(vals) == num) @@ -443,7 +500,7 @@ def extract_waves(waves): return result -def extract_data(df, se_number, code, jumps): +def extract_data(df, se_number, code, jumps, gfxv): if len(df['id']) == 0 or len(df['instructions']) == 0 or len(df['timeline']) == 0: return None @@ -467,7 +524,7 @@ def extract_data(df, se_number, code, jumps): for x in df['timeline'][wave_id].split('),'): timeline.append(extract_tuple(x, 2)) - stitched, loopCount, mem_unroll, count, maxline = stitch(insts, code, jumps) + stitched, loopCount, mem_unroll, count, maxline = stitch(insts, code, jumps, gfxv) srate = len(stitched)**2 / max(len(insts), 1) if srate <= maxgrade[df['simd'][wave_id]][df['wave_slot'][wave_id]]: continue @@ -527,7 +584,7 @@ class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): counters_json, imagebytes, _, _ = PICTURE_CALLBACK(selections[1:], selections[0]) JSON_GLOBAL_DICTIONARY['counters.json'] = counters_json JSON_GLOBAL_DICTIONARY[self.path.split('/')[-1]] = imagebytes - + if '.json' in self.path or 'timeline.png' in self.path or 'wstates' in self.path: try: response_file = JSON_GLOBAL_DICTIONARY[self.path.split('/')[-1]] @@ -535,7 +592,6 @@ class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): except: print('Invalid json request:', self.path) self.send_error(HTTPStatus.NOT_FOUND, "File not found") - print(JSON_GLOBAL_DICTIONARY.keys()) return self.send_response(HTTPStatus.OK) self.send_header("Content-Length", str(len(response_file))) @@ -627,7 +683,8 @@ def call_picture_callback(return_dict): for n, e in enumerate(counter_events): return_dict['se'+str(n)+'_perfcounter.json'] = Readable({"data": [v.toTuple() for v in e]}) -def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callback, OCCUPANCY, bDumpOnly): + +def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callback, OCCUPANCY, bDumpOnly, se_time_begin, gfxv): global PICTURE_CALLBACK PICTURE_CALLBACK = pic_callback manager = Manager() @@ -647,7 +704,7 @@ def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callba if len(dbname['id']) == 0: continue - count, wv_filenames = extract_data(dbname, se_number, code, jumps) + count, wv_filenames = extract_data(dbname, se_number, code, jumps, gfxv) if count is not None: flight_count.append(count) @@ -659,7 +716,7 @@ def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callba for key in simd_wave_filenames.keys(): wv_array = [[ int(s.split('_sm')[1].split('_wv')[0]), - int(s.split('_wv')[1][0]), + int(s.split('_wv')[1].split('.')[0]), s ] for s in simd_wave_filenames[key]] @@ -675,7 +732,9 @@ def view_trace(args, code, jumps, dbnames, att_filenames, bReturnLoc, pic_callba simd_wave_filenames[key] = wv_dict - JSON_GLOBAL_DICTIONARY['filenames.json'] = Readable({"filenames": simd_wave_filenames}) + JSON_GLOBAL_DICTIONARY['filenames.json'] = Readable({"filenames": simd_wave_filenames, + "global_begin_time": int(se_time_begin), + "gfxv": gfxv}) if pic_thread is not None: pic_thread.join()