4dd21807c0
Change-Id: If893f84b0ff108cfa0ccdcf717ee8592aa621032
590 Zeilen
20 KiB
Python
590 Zeilen
20 KiB
Python
#!/usr/bin/env python3
|
|
import sys
|
|
|
|
# Refuse to run under Python 2: abort at import time with a clear message.
if sys.version_info.major < 3:
    raise Exception("Must be using Python 3")
|
|
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
|
|
# Cap on the number of main-loop iterations in stitch(): its loop only
# continues while `loops < MAX_STITCHED_TOKENS`.
MAX_STITCHED_TOKENS = 200000000
|
|
# NOTE(review): stitch() counts `num_failed_stitches` but nothing in the
# visible part of this file compares it against this limit - confirm whether
# it is used elsewhere.
MAX_FAILED_STITCHES = 256
|
|
|
|
# Category codes. Trace tokens carry one in `.type` and assembly rows carry
# one in column 1; stitch() compares the two to line tokens up with code.
SKIP = 0
SMEM = 1
SALU = 2
VMEM = 3
FLAT = 4
LDS = 5
VALU = 6
JUMP = 7
NEXT = 8
IMMED = 9
BRANCH = 10
GETPC = 11
SETPC = 12
SWAPPC = 13
LANEIO = 14
PCINFO = 15
WAVE_ENDED = 16
DONT_KNOW = 100

# Printable name of every category code (used by stitch() when it lists the
# tokens it could not match). Each key appears exactly once.
WaveInstCategory = {
    SKIP: "SKIP",
    SMEM: "SMEM",
    SALU: "SALU",
    VMEM: "VMEM",
    FLAT: "FLAT",
    LDS: "LDS",
    VALU: "VALU",
    JUMP: "JUMP",
    NEXT: "NEXT",
    IMMED: "IMMED",
    BRANCH: "BRANCH",
    GETPC: "GETPC",
    SETPC: "SETPC",
    SWAPPC: "SWAPPC",
    LANEIO: "LANEIO",
    PCINFO: "PCINFO",
    WAVE_ENDED: "WAVE_ENDED",
    DONT_KNOW: "DONT_KNOW",
}
|
|
|
|
# Keeps track of register states for hipcc-generated assembly
|
|
class RegisterWatchList:
    """Track register contents while walking an assembly listing line by line.

    Positions are indices into ``code``; every instruction advances the
    position by one.  All register-tracking methods are best effort: a line
    that cannot be parsed, or a register that is not tracked, is ignored
    (or reported as ``-1`` by the methods that return a jump target).
    """

    def __init__(self, labels, code, jump_map, insts):
        """Store the listing and create the register table.

        labels   -- mapping of label name to an index into ``code``
        code     -- assembly rows; row[0] is the text, row[1] the category
        jump_map -- sequence indexed by row[2] of a branch row (see ``jump``)
        insts    -- trace tokens; only ``.type`` is read by this class
        """
        # v0..v63 hold 64 slots each (indexed by lane number in updatelane);
        # s0..s127 hold a single value each.
        self.registers = {"v" + str(k): [[] for _ in range(64)] for k in range(64)}
        for k in range(128):
            self.registers["s" + str(k)] = []
        self.labels = labels
        self.code = code
        self.jump_map = jump_map
        self.insts = insts

    def jump(self, as_line):
        """Return the jump target stored in ``jump_map`` for a branch row."""
        return self.jump_map[as_line[2]]

    def getcode(self, line):
        """Return ``(row, increment)`` for position ``line``; increment is 1."""
        return self.code[line], 1

    def getincrement(self, line):
        """Return the distance to the next position (always 1 here)."""
        return 1

    def try_translate(self, tok):
        """Resolve an operand token.

        A token starting with ``s`` yields the tracked value of that scalar
        register (the first one of a range); a token containing ``@`` yields
        the label's position plus one.  Anything else yields ``None``.
        """
        if tok[0] == "s":
            return self.registers[self.range(tok)[0]]
        if "@" in tok:
            return self.labels[tok.split("@")[0]] + 1
        return None

    def range(self, r):
        """Expand a register operand into a list of register names.

        ``"s4"`` -> ``["s4"]``; ``"s[4:6]"`` -> ``["s4", "s5", "s6"]``.
        """
        reg = r.split(":")
        if len(reg) == 1:
            return reg
        prefix, first = reg[0].split("[")[:2]
        last = int(reg[1][:-1])  # drop the closing bracket
        # `range` below is the builtin: class attributes are not visible
        # from inside a method body.
        return [prefix + str(k) for k in range(int(first), last + 1)]

    def tokenize(self, line):
        """Split a line on spaces, keeping the part of each piece before its
        first comma and dropping empty pieces."""
        pieces = (t.split(",")[0].strip() for t in line.split(" "))
        return [p for p in pieces if p]

    def getpc(self, line, next_line):
        """Handle a get-PC instruction.

        The destination registers named on ``line`` are set to the position
        of whichever label can be found in the text of ``next_line`` (the
        comma-separated operands, plus the last operand cut at ``@``).
        """
        try:
            dst = line.split(" ")[1].strip()
            try:
                operands = next_line[0].split(", ")
                label_dests = operands + [operands[-1].split("@")[0]]
            except Exception:
                label_dests = []

            for label_dst in label_dests:
                try:
                    cur_label = self.labels[label_dst]
                    for reg in self.range(dst):
                        self.registers[reg] = deepcopy(cur_label)
                except Exception:
                    # Not a known label (or not a tracked register): skip it.
                    pass
        except Exception:
            # Best effort only: an unparsable line leaves the state untouched.
            pass

    def swappc(self, line, line_num, inst_num):
        """Handle a swap-PC instruction.

        Returns the value tracked for the source register and stores
        ``line_num + 1`` (the return position) in the destination register.
        Returns ``-1`` when the line cannot be parsed or the registers are
        not tracked.
        """
        try:
            tokens = self.tokenize(line)
            dst = tokens[1]
            src = tokens[2]

            popped = deepcopy(self.registers[self.range(src)[0]])
            self.registers[self.range(dst)[0]] = line_num + 1
            return popped
        except Exception:
            return -1

    def setpc(self, line, inst_num):
        """Return the value tracked for the register a set-PC instruction
        jumps through, or ``-1`` when it cannot be determined."""
        try:
            src = line.split(' ')[1].strip()
            return deepcopy(self.registers[self.range(src)[0]])
        except Exception:
            return -1

    def scratch(self, line):
        """Copy tracked state between a register and a scratch slot.

        The scratch slot is keyed by the concatenation of two operand tokens
        (tokens 3 and 4).  Instructions whose mnemonic contains ``_load``
        copy slot -> register; all others copy register -> slot.
        """
        try:
            tokens = self.tokenize(line)
            if "_load" in tokens[0]:
                dst = tokens[1]
                src = tokens[3] + tokens[4]
            else:
                src = tokens[2]
                dst = tokens[3] + tokens[4]
            self.registers[dst] = deepcopy(self.registers[src])
        except Exception:
            pass

    def move(self, line):
        """Propagate tracked state for a move whose destination and source
        tokens both start with ``s`` or ``d``."""
        try:
            tokens = self.tokenize(line)
            if tokens[2][0] in ["s", "d"] and tokens[1][0] in ["s", "d"]:
                self.registers[self.range(tokens[1])[0]] = deepcopy(
                    self.registers[self.range(tokens[2])[0]]
                )
        except Exception:
            pass

    def updatelane(self, line):
        """Propagate tracked state through v_readlane / v_writelane.

        readlane:  registers[dst] = registers[src][lane]
        writelane: registers[dst][lane] = registers[src]
        """
        tokens = self.tokenize(line)
        try:
            if "v_readlane" in tokens[0]:
                self.registers[tokens[1]] = deepcopy(self.registers[tokens[2]][int(tokens[3])])
            elif "v_writelane" in tokens[0]:
                self.registers[tokens[1]][int(tokens[3])] = deepcopy(self.registers[tokens[2]])
        except Exception:
            pass

    def try_match_swapped(self, i, line, increment):
        """Tell whether tokens ``i`` and ``i + 1`` match rows ``line`` and
        ``line + 1`` in reverse order.

        Returns False (instead of raising IndexError) when there is no next
        token or no next row.  ``increment`` is accepted for interface
        compatibility and is not used.
        """
        try:
            return (
                self.insts[i + 1].type == self.code[line][1]
                and self.insts[i].type == self.code[line + 1][1]
            )
        except IndexError:
            return False
|
|
|
|
# Translates PC values to instructions, for auto captured ISA
|
|
class PCTranslator:
    """Translate program-counter addresses into assembly rows.

    Positions are addresses.  A row keeps its address in column ``-3``, its
    text in column 0 and its category in column 1.  Addresses that are not
    known yet are looked up through ``codeservice`` and appended to both
    ``code`` and ``raw_code``.

    The register-tracking methods of the line-based watch list
    (``try_translate``, ``range``, ``tokenize``, ``getpc``, ``scratch``,
    ``move``, ``updatelane``) exist only for interface compatibility and do
    nothing.
    """

    def __init__(self, insts, code, raw_code, reverse_map, codeservice):
        """Index the already known rows by address.

        insts       -- trace tokens (``.type`` and ``.cycles`` are read)
        code        -- cleaned rows; extended in place by ``getcode``
        raw_code    -- original rows; extended in place
        reverse_map -- address -> raw_code index; extended in place
        codeservice -- object providing ``GetInstruction(addr)`` and
                       ``getSymbolName(addr)``
        """
        self.codeservice = codeservice

        self.insts = insts
        # address -> (row, element 3 of GetInstruction's result), which
        # getincrement() returns as the step to the next address.
        self.addrmap = {c[-3] : (c, self.codeservice.GetInstruction(c[-3])[3]) for c in code if c[-3] > 0}

        self.code = code
        self.raw_code = raw_code
        self.reverse_map = reverse_map
        self.jump_map = {c[-3] : self.getjump_loc(c) for c in code if c[1] == BRANCH}

    def addsymbol(self, addr):
        """Append a comment row naming the symbol at ``addr`` to ``raw_code``.

        Does nothing when ``addr`` is already a known instruction address.
        """
        if addr in self.addrmap:
            return

        symbol = self.codeservice.getSymbolName(addr)
        if symbol is None:
            # hex() already supplies the "0x" prefix.
            symbol = "Unknown symbol at " + hex(addr)

        last_line = self.raw_code[-1]
        newline = ['; ' + symbol, DONT_KNOW, last_line[2], '', last_line[4], last_line[5], 0, 0, 0]
        self.raw_code.append(newline)

    def getcode(self, addr):
        """Return ``(row, size)`` for the instruction at ``addr``.

        Unknown addresses are fetched from the code service and cached.
        Raises KeyError (carrying ``addr``) when the service has no
        instruction of non-zero size there.
        """
        try:
            return self.addrmap[addr]
        except KeyError:
            new_inst = self.codeservice.GetInstruction(addr)
            if not (new_inst and new_inst[3]):  # Check returned size > 0
                raise

            last_line = self.raw_code[-1]
            newline = [new_inst[1], new_inst[0], len(self.raw_code), new_inst[2], last_line[4]+1, last_line[5]+1, addr, 0, 0]
            if new_inst[0] == BRANCH:
                self.jump_map[addr] = self.getjump_loc(newline)
            self.addrmap[addr] = (newline, new_inst[3])

            self.reverse_map[addr] = len(self.raw_code)
            self.raw_code.append(newline)
            self.code.append(newline)
            return newline, new_inst[3]

    def jump(self, asm_line):
        """Return the target address of a branch row, computing and caching
        it on first use."""
        try:
            return self.jump_map[asm_line[-3]]
        except KeyError:
            loc = self.getjump_loc(asm_line)
            self.jump_map[asm_line[-3]] = loc
            return loc

    def getjump_loc(self, asm_line):
        """Compute a branch target from the row text.

        The last space-separated token is read as an unsigned 16-bit value,
        reinterpreted as signed, and applied as
        ``address + 4 * offset + 4``.  Returns ``-1`` when the token is not
        an integer.
        """
        try:
            dest = int(asm_line[0].split(' ')[-1])
            if dest >= 32768:
                dest -= 65536
            return asm_line[-3] + 4*dest+4
        except Exception:
            return -1

    def getincrement(self, addr):
        """Return the size recorded for the instruction at ``addr``."""
        return self.getcode(addr)[1]

    def try_translate(self, tok):
        """No-op in address mode."""
        pass

    def range(self, r):
        """No-op in address mode."""
        pass

    def tokenize(self, line):
        """No-op in address mode."""
        pass

    def getpc(self, line, next_line):
        """No-op in address mode."""
        pass

    def _pc_after(self, kind, line, inst_index):
        """Return the address held in ``.cycles`` of the token following
        ``inst_index``, or ``-1`` (after printing a warning tagged with
        ``kind``) when there is no such token or no code at that address."""
        try:
            addr = self.insts[inst_index + 1].cycles
        except (IndexError, AttributeError):
            # Previously the warning itself re-indexed the list and raised.
            print(kind + ' warning: No token with an address after', inst_index, line)
            return -1

        try:
            return self.getcode(addr)[0][-3]
        except Exception:
            shown = hex(addr) if isinstance(addr, int) else repr(addr)
            print(kind + ' warning: Could not find addr', shown, 'for', inst_index, line)
            return -1

    def swappc(self, line, line_num, inst_index):
        """Resolve the target of a swap-PC from the next trace token;
        returns ``-1`` when it cannot be resolved."""
        return self._pc_after('SWAPPC', line, inst_index)

    def setpc(self, line, inst_index):
        """Resolve the target of a set-PC from the next trace token;
        returns ``-1`` when it cannot be resolved."""
        return self._pc_after('SETPC', line, inst_index)

    def scratch(self, line):
        """No-op in address mode."""
        pass

    def move(self, line):
        """No-op in address mode."""
        pass

    def updatelane(self, line):
        """No-op in address mode."""
        pass

    def try_match_swapped(self, i, addr, increment):
        """Tell whether tokens ``i`` and ``i + 1`` match the instructions at
        ``addr`` and ``addr + increment`` in reverse order; any lookup
        failure counts as "no match"."""
        try:
            return self.insts[i + 1].type == self.getcode(addr)[0][1] and \
                self.insts[i].type == self.getcode(addr + increment)[0][1]
        except Exception:
            return False
|
|
|
|
|
|
# stitch: walk a wave's trace tokens (`insts`) alongside its assembly
# (`raw_code`) and attach to each matched token the raw_code row it
# corresponds to (`inst.asmline`).
#
# Parameters, as used below:
#   insts       - trace tokens; `.type` and `.cycles` are read, `.asmline` is written
#   raw_code    - rows whose column 0 is text and column 1 a category code;
#                 in auto mode column -3 is used as the key of reverse_map
#   jumps       - iterable of indices into the row list (re-keyed below)
#   gfxv        - the value 'vega' sets bGFX9
#   bIsAuto     - True: follow addresses with PCTranslator;
#                 False: follow line numbers with RegisterWatchList
#   codeservice - handed to PCTranslator unchanged
#
# Returns (result, loopCount, mem_unroll, flight_count, maxline, len(result),
# pcskip). In auto mode it returns None when the first token is not PCINFO,
# when its cycles field is 0, or on an unexpected error; on KeyError it
# returns a 7-tuple carrying the non-PCINFO tokens and the PCINFO indices.
#
# NOTE(review): this copy of the file has lost its indentation, so the
# comments below describe individual statements; confirm the nesting against
# the original source before relying on them for control flow.
def stitch(insts, raw_code, jumps, gfxv, bIsAuto, codeservice):
|
|
bGFX9 = gfxv == 'vega'
|
|
|
|
result, i, loopCount = [], 0, defaultdict(int)
|
|
|
|
# Pending memory operations per kind. Entries appended further down have the
# form [reverse_map[line], num_inflight]; the NUM_* counters track how many
# are considered in flight.
SMEM_INST = [] # scalar memory
|
|
VLMEM_INST = [] # vector memory load
|
|
VSMEM_INST = [] # vector memory store
|
|
FLAT_INST = []
|
|
NUM_SMEM = 0
|
|
NUM_VLMEM = 0
|
|
NUM_VSMEM = 0
|
|
NUM_FLAT = 0
|
|
skipped_immed = 0
|
|
|
|
mem_unroll = []
|
|
flight_count = []
|
|
|
|
labels = {}
|
|
jump_map = [0]
|
|
|
|
# Clean the code and remove comments
|
|
code = [raw_code[0]]
|
|
for c in raw_code[1:]:
|
|
c = list(c)
|
|
c[0] = c[0].split(";")[0].split("//")[0].strip()
|
|
# jump_map gets one entry per raw_code row: the length of `code` at the time
# the row is visited, i.e. the index the row receives if it is kept.
jump_map.append(len(code))
|
|
|
|
if c[1] != DONT_KNOW:
|
|
code.append(c)
|
|
# DONT_KNOW rows are not copied into `code`; those containing ':' are
# recorded as labels mapped to the current length of `code`.
elif ":" in c[0]:
|
|
labels[c[0].split(":")[0]] = len(code)
|
|
|
|
# reverse_map is what later turns `line` into the value stored in
# inst.asmline: keyed by code[v][-3] in auto mode, by the index into `code`
# otherwise; the value is the position in jump_map.
reverse_map = {}
|
|
if bIsAuto:
|
|
for k, v in enumerate(jump_map):
|
|
try:
|
|
reverse_map[code[v][-3]] = k
|
|
except:
|
|
pass
|
|
else:
|
|
for k, v in enumerate(jump_map):
|
|
reverse_map[v] = k
|
|
|
|
# Re-key the caller's jumps by jump_map[j] + 1, keeping j as the value.
jumps = {jump_map[j] + 1: j for j in jumps}
|
|
|
|
# Checks if we have guaranteed ordering in memory operations
|
|
smem_ordering = 0
|
|
vlmem_ordering = 0
|
|
vsmem_ordering = 0
|
|
|
|
# NOTE(review): num_failed_stitches is updated below but never read in the
# visible code.
num_failed_stitches = 0
|
|
loops = 0
|
|
maxline = 0
|
|
|
|
# Indices recorded in pcskip: PCINFO tokens seen by the main loop, plus the
# indices appended while handling SETPC / SWAPPC.
pcskip = []
|
|
# Auto mode: start at the value in the first token's cycles field and use a
# PCTranslator. Otherwise start at line 0 with a RegisterWatchList.
if bIsAuto:
|
|
try:
|
|
firstinst = insts[0]
|
|
|
|
if firstinst.type != PCINFO:
|
|
print('Warning: Waves without PCINFO')
|
|
return None
|
|
elif firstinst.cycles == 0:
|
|
print('Info: Some waves started before the trace')
|
|
return None
|
|
|
|
watchlist = PCTranslator(insts, code, raw_code, reverse_map, codeservice)
|
|
watchlist.addsymbol(firstinst.cycles)
|
|
line = firstinst.cycles
|
|
lineincrement = watchlist.getincrement(line)
|
|
# KeyError: e.args[0] is printed as the address that has no code object
# info; every token gets asmline 0 and the tokens are returned unstitched.
except KeyError as e:
|
|
print('Warning: Waves from addr', hex(e.args[0]), 'have no codeobj info.')
|
|
for i in range(len(insts)):
|
|
insts[i].asmline = 0
|
|
return [i for k, i in enumerate(insts) if i.type != PCINFO], [], [], [], 1, 0, [k for k, i in enumerate(insts) if i.type == PCINFO]
|
|
except Exception as e:
|
|
print('Unknown error', e)
|
|
return None
|
|
else:
|
|
line = 0
|
|
lineincrement = 1
|
|
watchlist = RegisterWatchList(labels=labels, code=code, jump_map=jump_map, insts=insts)
|
|
|
|
N = len(insts)
|
|
|
|
# Main loop: runs while tokens remain, `line` is non-negative and the
# iteration cap has not been reached.
while i < N and line >= 0 and loops < MAX_STITCHED_TOKENS:
|
|
# PCINFO tokens are only recorded in pcskip and skipped.
if insts[i].type == PCINFO:
|
|
pcskip.append(i)
|
|
i += 1
|
|
continue
|
|
loops += 1
|
|
|
|
inst = insts[i]
|
|
# Any failure to fetch the assembly row for `line` ends the loop.
try:
|
|
as_line, lineincrement = watchlist.getcode(line)
|
|
except:
|
|
break
|
|
|
|
matched = True
|
|
# `next` (shadows the builtin) is the position to continue from.
next = line + lineincrement
|
|
|
|
# Non-auto mode: let the watchlist follow '_mov_' and 'scratch_' lines.
if not bIsAuto:
|
|
if '_mov_' in as_line[0]:
|
|
watchlist.move(as_line[0])
|
|
elif 'scratch_' in as_line[0]:
|
|
watchlist.scratch(as_line[0])
|
|
|
|
# Dispatch on the category of the assembly row (as_line[1]).
if as_line[1] == DONT_KNOW or (as_line[1] == SKIP and not bGFX9):
|
|
matched = False
|
|
elif as_line[1] == GETPC:
|
|
try:
|
|
watchlist.getpc(as_line[0], watchlist.getcode(next)[0])
|
|
matched = inst.type in [SALU, JUMP]
|
|
except:
|
|
matched = False
|
|
elif as_line[1] == LANEIO:
|
|
watchlist.updatelane(as_line[0])
|
|
matched = inst.type == VALU
|
|
# SETPC: the watchlist supplies the target. In auto mode, while the target is
# negative, later tokens are scanned: a PCINFO token triggers a retry, any
# other token has its cycles added to the current instruction.
elif as_line[1] == SETPC:
|
|
next = watchlist.setpc(as_line[0], i)
|
|
matched = inst.type in [SALU, JUMP]
|
|
i += 1
|
|
pcskip.append(i)
|
|
while bIsAuto and next < 0 and i+1 < len(insts):
|
|
i += 1
|
|
if insts[i].type == PCINFO:
|
|
pcskip.append(i)
|
|
next = watchlist.setpc(as_line[0], i-1)
|
|
else:
|
|
inst.cycles += insts[i].cycles
|
|
# NOTE(review): in non-auto mode setpc returns the tracked register value,
# which RegisterWatchList initialises to a list - confirm `next < 0` is
# always an int comparison here.
if next < 0:
|
|
print('Jump to unknown location in line', as_line[0])
|
|
break
|
|
# SWAPPC: same scheme as SETPC, using watchlist.swappc.
elif as_line[1] == SWAPPC:
|
|
matched = inst.type in [SALU, JUMP]
|
|
next = watchlist.swappc(as_line[0], line, i)
|
|
i += 1
|
|
pcskip.append(i)
|
|
while bIsAuto and next < 0 and i+1 < len(insts):
|
|
i += 1
|
|
if insts[i].type == PCINFO:
|
|
next = watchlist.swappc(as_line[0], line, i-1)
|
|
pcskip.append(i)
|
|
else:
|
|
inst.cycles += insts[i].cycles
|
|
if next < 0:
|
|
print('Jump to unknown location in line', as_line[0])
|
|
break
|
|
# Token category equals the row category: count loop entries and update the
# memory-operation bookkeeping.
elif inst.type == as_line[1]:
|
|
if line in jumps:
|
|
loopCount[jumps[line] - 1] += 1
|
|
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
|
|
|
|
# SMEM and LDS tokens share the SMEM_INST list; only SMEM sets smem_ordering.
if inst.type == SMEM or inst.type == LDS:
|
|
smem_ordering = 1 if inst.type == SMEM else smem_ordering
|
|
SMEM_INST.append([reverse_map[line], num_inflight])
|
|
NUM_SMEM += 1
|
|
# VMEM tokens, and FLAT tokens whose text contains "global_".
elif inst.type == VMEM or (inst.type == FLAT and "global_" in as_line[0]):
|
|
inc_ordering = False
|
|
if "flat_" in as_line[0]:
|
|
inc_ordering = True
|
|
|
|
# Lines containing "_inv" or "_wb" are not recorded. Outside GFX9, lines
# containing "store" go to VSMEM_INST; everything else goes to VLMEM_INST.
if not "_inv" in as_line[0] and not "_wb" in as_line[0]:
|
|
if not bGFX9 and "store" in as_line[0]:
|
|
VSMEM_INST.append([reverse_map[line], num_inflight])
|
|
NUM_VSMEM += 1
|
|
if inc_ordering:
|
|
vsmem_ordering = 1
|
|
else:
|
|
VLMEM_INST.append([reverse_map[line], num_inflight])
|
|
NUM_VLMEM += 1
|
|
if inc_ordering:
|
|
vlmem_ordering = 1
|
|
# Remaining FLAT tokens set all three ordering flags.
elif inst.type == FLAT:
|
|
smem_ordering = 1
|
|
vlmem_ordering = 1
|
|
vsmem_ordering = 1
|
|
FLAT_INST.append([reverse_map[line], num_inflight])
|
|
NUM_FLAT += 1
|
|
# s_waitcnt: for each counter named in the text, parse its argument (0 when
# it cannot be parsed), record [as_line[5], num_inflight, wait_N] and either
# move waited-for operations into mem_unroll (ordering flag is 0) or only
# clamp the in-flight counters.
elif inst.type == IMMED and "s_waitcnt" in as_line[0]:
|
|
if "lgkmcnt" in as_line[0]:
|
|
try:
|
|
wait_N = int(as_line[0].split("lgkmcnt(")[1].split(")")[0])
|
|
except:
|
|
wait_N = 0
|
|
flight_count.append([as_line[5], num_inflight, wait_N])
|
|
if wait_N == 0:
|
|
smem_ordering = 0
|
|
if smem_ordering == 0:
|
|
offset = len(SMEM_INST) - wait_N
|
|
mem_unroll.append(
|
|
[reverse_map[line], SMEM_INST[:offset] + FLAT_INST]
|
|
)
|
|
SMEM_INST = SMEM_INST[offset:]
|
|
NUM_SMEM = len(SMEM_INST)
|
|
FLAT_INST = []
|
|
NUM_FLAT = 0
|
|
else:
|
|
NUM_SMEM = min(max(wait_N - NUM_FLAT, 0), NUM_SMEM)
|
|
NUM_FLAT = min(max(wait_N - NUM_SMEM, 0), NUM_FLAT)
|
|
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
|
|
|
|
if "vmcnt" in as_line[0]:
|
|
try:
|
|
wait_N = int(as_line[0].split("vmcnt(")[1].split(")")[0])
|
|
except:
|
|
wait_N = 0
|
|
flight_count.append([as_line[5], num_inflight, wait_N])
|
|
if wait_N == 0:
|
|
vlmem_ordering = 0
|
|
if vlmem_ordering == 0:
|
|
offset = len(VLMEM_INST) - wait_N
|
|
mem_unroll.append(
|
|
[reverse_map[line], VLMEM_INST[:offset] + FLAT_INST]
|
|
)
|
|
VLMEM_INST = VLMEM_INST[offset:]
|
|
NUM_VLMEM = len(VLMEM_INST)
|
|
FLAT_INST = []
|
|
NUM_FLAT = 0
|
|
else:
|
|
NUM_VLMEM = min(max(wait_N - NUM_FLAT, 0), NUM_VLMEM)
|
|
NUM_FLAT = min(max(wait_N - NUM_VLMEM, 0), NUM_FLAT)
|
|
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
|
|
|
|
# The store counter is taken from "vscnt(", falling back to "vmcnt(" (the
# condition also admits vmcnt lines when bGFX9 is set).
if "vscnt" in as_line[0] or (bGFX9 and "vmcnt" in as_line[0]):
|
|
try:
|
|
wait_N = int(as_line[0].split('vscnt(')[1].split(')')[0])
|
|
except:
|
|
try:
|
|
wait_N = int(as_line[0].split('vmcnt(')[1].split(')')[0])
|
|
except:
|
|
wait_N = 0
|
|
flight_count.append([as_line[5], num_inflight, wait_N])
|
|
if wait_N == 0:
|
|
vsmem_ordering = 0
|
|
if vsmem_ordering == 0:
|
|
offset = len(VSMEM_INST) - wait_N
|
|
mem_unroll.append(
|
|
[reverse_map[line], VSMEM_INST[:offset] + FLAT_INST]
|
|
)
|
|
VSMEM_INST = VSMEM_INST[offset:]
|
|
NUM_VSMEM = len(VSMEM_INST)
|
|
FLAT_INST = []
|
|
NUM_FLAT = 0
|
|
else:
|
|
NUM_VSMEM = min(max(wait_N - NUM_FLAT, 0), NUM_VSMEM)
|
|
NUM_FLAT = min(max(wait_N - NUM_VSMEM, 0), NUM_FLAT)
|
|
num_inflight = NUM_FLAT + NUM_SMEM + NUM_VLMEM + NUM_VSMEM
|
|
|
|
# A JUMP token on a BRANCH row continues at the watchlist's jump target;
# a NEXT token on a BRANCH row just continues at the following position.
elif inst.type == JUMP and as_line[1] == BRANCH:
|
|
next = watchlist.jump(as_line)
|
|
if next is None or next == 0:
|
|
print("Jump to unknown location!", as_line)
|
|
break
|
|
elif inst.type == NEXT and as_line[1] == BRANCH:
|
|
pass
|
|
# No match: if this token and the next one fit the current and following
# rows in reverse order, swap the two tokens and retry the same position.
else:
|
|
matched = False
|
|
# NOTE(review): RegisterWatchList.try_match_swapped indexes insts[i + 1]
# without a bounds check - confirm behaviour on the last token.
if watchlist.try_match_swapped(i, line, lineincrement):
|
|
temp = insts[i]
|
|
insts[i] = insts[i + 1]
|
|
insts[i + 1] = temp
|
|
next = line
|
|
elif "s_waitcnt" in as_line[0] or "_load_" in as_line[0]:
|
|
# A previously skipped IMMED token can pay for an unmatched s_waitcnt row.
if skipped_immed > 0 and "s_waitcnt" in as_line[0]:
|
|
matched = True
|
|
skipped_immed -= 1
|
|
elif 'scratch_' not in as_line[0]:
|
|
print('WARNING: Parsing terminated at:', as_line)
|
|
break
|
|
|
|
# Bookkeeping: a matched token is appended to result with its asmline and
# consumed; outside GFX9 an unmatched IMMED token is also appended and
# counted in skipped_immed.
if matched or as_line[1] != DONT_KNOW:
|
|
if matched:
|
|
inst.asmline = reverse_map[line]
|
|
result.append(inst)
|
|
i += 1
|
|
num_failed_stitches = 0
|
|
elif not bGFX9 and inst.type == IMMED and line != next:
|
|
skipped_immed += 1
|
|
inst.asmline = reverse_map[line]
|
|
result.append(inst)
|
|
if 's_barrier' in as_line[0]:
|
|
next = line + lineincrement
|
|
i += 1
|
|
else:
|
|
num_failed_stitches += 1
|
|
|
|
maxline = max(reverse_map[line], maxline)
|
|
line = next
|
|
|
|
# Post-loop reporting. N is forced to at least 1 so the percentage below
# cannot divide by zero.
# NOTE(review): insts[i] is read whenever i != N; with an empty `insts`
# (i == 0, N == 1) or with i already past N this would raise IndexError -
# confirm callers never hit that.
N = max(N, 1)
|
|
if i != N and insts[i].type == WAVE_ENDED:
|
|
print('Warning - Wave ended.')
|
|
elif i < N:
|
|
print('Warning - Stitching rate: '+str(i * 100 / N)+'% matched', i, ' of ', N)
|
|
print('Leftovers:', [WaveInstCategory[insts[i+k].type] for k in range(20) if i+k < len(insts)])
|
|
try:
|
|
print(line, code[line])
|
|
except:
|
|
pass
|
|
else:
|
|
# All tokens consumed: scan forward for s_endpgm and attribute every still
# pending memory operation to it.
# NOTE(review): `"s_endpgm" in code[line]` tests the elements of the row
# for equality with "s_endpgm" (code[line] is a row, not a string) - confirm
# that a substring test on code[line][0] was not intended.
while line < len(code):
|
|
if "s_endpgm" in code[line]:
|
|
mem_unroll.append(
|
|
[reverse_map[line], SMEM_INST + VLMEM_INST + VSMEM_INST + FLAT_INST]
|
|
)
|
|
break
|
|
line += 1
|
|
print('Success: Parsed', i, 'tokens')
|
|
|
|
return result, loopCount, mem_unroll, flight_count, maxline, len(result), pcskip
|