diff --git a/projects/rdc/python_binding/rdc_prometheus.py b/projects/rdc/python_binding/rdc_prometheus.py
index ec9e733e30..77975b2606 100644
--- a/projects/rdc/python_binding/rdc_prometheus.py
+++ b/projects/rdc/python_binding/rdc_prometheus.py
@@ -15,7 +15,7 @@ default_field_ids = [
 
 class PrometheusReader(RdcReader):
     def __init__(self, rdc_ip_port, field_ids, update_freq, max_keep_age, max_keep_samples,
-            gpu_indexes, rdc_unauth, enable_plugin_monitoring):
+            gpu_indexes, rdc_unauth, enable_pci_id, enable_plugin_monitoring):
         group_name = "rdc_prometheus_plugin_group"
         field_group_name = "rdc_prometheus_plugin_fieldgroup"
         if rdc_unauth:
@@ -32,6 +32,30 @@ class PrometheusReader(RdcReader):
             REGISTRY.unregister(PROCESS_COLLECTOR)
             REGISTRY.unregister(PLATFORM_COLLECTOR)
 
+        # Use the PCI id as the gpu_index label value if enabled
+        self.enable_pci_id = False
+        if enable_pci_id:
+            try:
+                import sys, os
+                # Relative path of rocm_smi, used to map gpu index to PCI id;
+                # change smi_lib_relative_path if rocm_smi is installed in a different folder
+                smi_lib_relative_path = "../../bin"
+                smi_lib_path = os.path.join(sys.path[0], smi_lib_relative_path)
+                if os.path.exists(os.path.join(smi_lib_path, "rocm_smi.py")):
+                    sys.path.append(smi_lib_path)
+                    from rocm_smi import getBus, initializeRsmi
+                    initializeRsmi()
+                    # Map between gpu indexes and PCIe bus addresses
+                    self.index_to_bus_addr = {
+                        item: getBus(item) for item in self.gpu_indexes
+                    }
+                    self.enable_pci_id = True
+                else:
+                    print("cannot find smi_lib to map the PCI id")
+            except Exception as error:
+                print("Fail to get the PCI id", error)
+
+
         # Create the guages
         self.guages = {}
         for fid in self.field_ids:
@@ -39,8 +63,11 @@ class PrometheusReader(RdcReader):
             self.guages[fid] = Gauge(field_name, field_name, labelnames=['gpu_index'])
 
     def handle_field(self, gpu_index, value):
+        gpu_label = gpu_index
+        if self.enable_pci_id:  # fall back to the raw index when it has no bus mapping
+            gpu_label = self.index_to_bus_addr.get(gpu_index, gpu_index)
         if value.field_id.value in self.guages:
-            self.guages[value.field_id.value].labels(gpu_index).set(value.value.l_int)
+            self.guages[value.field_id.value].labels(gpu_label).set(value.value.l_int)
 
 def get_field_ids(args):
     field_ids = []
@@ -80,6 +107,7 @@ if __name__ == '__main__':
     parser.add_argument('--rdc_fields', default=None, nargs='+', help='The list of fields name needs to be watched, for example, " --rdc_fields RDC_FI_GPU_TEMP RDC_FI_POWER_USAGE " (default: predefined fields in the plugin)')
     parser.add_argument('--rdc_fields_file', default=None, help='The list of fields name can also be read from a file with each field name in a separated line (default: None)')
     parser.add_argument('--rdc_gpu_indexes', default=None, nargs='+', help='The list of GPUs to be watched (default: All GPUs)')
+    parser.add_argument('--enable_pci_id', default=False, action='store_true', help = 'Use the PCI Device Identifier to identify GPU (default: false)')
     parser.add_argument('--enable_plugin_monitoring', default=False, action='store_true', help = 'Set this option to collect process metrics of the plugin itself (default: false)')
 
     args = parser.parse_args()
@@ -94,7 +122,7 @@ if __name__ == '__main__':
 
     reader = PrometheusReader(rdc_ip_port, field_ids, args.rdc_update_freq*1000000,
         args.rdc_max_keep_age, args.rdc_max_keep_samples,
-        args.rdc_gpu_indexes, args.rdc_unauth, args.enable_plugin_monitoring)
+        args.rdc_gpu_indexes, args.rdc_unauth, args.enable_pci_id, args.enable_plugin_monitoring)
     start_http_server(args.listen_port)
     print("The RDC Prometheus plugin listen at port %d" % (args.listen_port))
     time.sleep(3)