diff --git a/projects/rocm-smi-lib/python_smi_tools/rocm_smi.py b/projects/rocm-smi-lib/python_smi_tools/rocm_smi.py index de2e09ec04..db5a23dbbd 100755 --- a/projects/rocm-smi-lib/python_smi_tools/rocm_smi.py +++ b/projects/rocm-smi-lib/python_smi_tools/rocm_smi.py @@ -17,7 +17,7 @@ import logging import os import sys import subprocess -import _thread +import threading import time import multiprocessing import trace @@ -87,6 +87,9 @@ validClockNames = clk_type_names[1:-2] validClockNames.append('pcie') validClockNames.sort() +# Thread stop condition +stop_threads = False + def driverInitialized(): """ Returns true if amdgpu is found in the list of initialized modules """ @@ -867,7 +870,7 @@ def printEventList(device, delay, eventList): if not rsmi_ret_ok(ret, device, 'set_event_notification_mask'): printErrLog(device, 'Unable to set event notification mask.') return - while 1: # Exit condition from user keyboard input of 'q' or 'ctrl + c' + while not stop_threads: # Exit condition from user keyboard input of 'q' or 'ctrl + c' num_elements = c_uint32(1) data = rsmi_evt_notification_data_t(1) rocmsmi.rsmi_event_notification_get(delay, byref(num_elements), byref(data)) @@ -2999,6 +3002,7 @@ def showEvents(deviceList, eventTypes): printLogSpacer(' Show Events ') printLog(None, 'press \'q\' or \'ctrl + c\' to quit', None) eventTypeList = [] + thread_list = [] for event in eventTypes: # Cleaning list from wrong values if event.replace(',', '').upper() in notification_type_names: eventTypeList.append(event.replace(',', '').upper()) @@ -3010,7 +3014,9 @@ def showEvents(deviceList, eventTypes): # Create a separate thread for each GPU for device in deviceList: try: - _thread.start_new_thread(printEventList, (device, 1000, eventTypeList)) + thread = threading.Thread(target=printEventList, args=(device, 1000, eventTypeList)) + thread.start() + thread_list.append(thread) time.sleep(0.25) except Exception as e: printErrLog(device, 'Unable to start new thread. %s' % (e)) @@ -3019,6 +3025,8 @@ def showEvents(deviceList, eventTypes): getch = _Getch() user_input = getch() # Catch user input for q or Ctrl + c + global stop_threads + stop_threads = True if user_input == 'q' or user_input == '\x03': for device in deviceList: ret = rocmsmi.rsmi_event_notification_stop(device) @@ -3026,6 +3034,8 @@ def showEvents(deviceList, eventTypes): printErrLog(device, 'Unable to end event notifications.') print('\r') break + for thread in thread_list: + thread.join() def printTempGraph(deviceList, delay, temp_type): @@ -3038,7 +3048,7 @@ def printTempGraph(deviceList, delay, temp_type): for i in range(devices): printEmptyLine() originalTerminalWidth = os.get_terminal_size()[0] - while 1: # Exit condition from user keyboard input of 'q' or 'ctrl + c' + while not stop_threads: # Exit condition from user keyboard input of 'q' or 'ctrl + c' terminalWidth = os.get_terminal_size()[0] printStrings = list() for device in deviceList: @@ -3118,19 +3128,26 @@ def showTempGraph(deviceList): deviceList.sort() temp_type = getTemperatureLabel(deviceList) printLogSpacer(' Temperature Graph ' + temp_type.capitalize() + ' ') + thread_list = [] # Start a thread for constantly printing try: # Create a thread (call print function, devices, delay in ms) - _thread.start_new_thread(printTempGraph, (deviceList, 150, temp_type)) + thread = threading.Thread(target=printTempGraph, args=(deviceList, 150, temp_type)) + thread.start() + thread_list.append(thread) except Exception as e: printErrLog(device, 'Unable to start new thread. %s' % (e)) # Catch user input for program termination while 1: # Exit condition from user keyboard input of 'q' or 'ctrl + c' getch = _Getch() user_input = getch() + global stop_threads + stop_threads = True; # Catch user input for q or Ctrl + c if user_input == 'q' or user_input == '\x03': break + for thread in thread_list: + thread.join() # Reset color to default before exit print('\033[A\x1b[0m\r') printLogSpacer()