From c9b1ad72a54bf0805d4d82c95ed645ce83b692e4 Mon Sep 17 00:00:00 2001 From: ywang103-amd Date: Mon, 8 Sep 2025 01:30:30 -0400 Subject: [PATCH] scientific notion of memchart(CLI, TUI and GUI) (#764) --- projects/rocprofiler-compute/CHANGELOG.md | 2 + .../src/utils/gui_components/memchart.py | 216 +++++++++++------- .../src/utils/mem_chart.py | 71 ++++-- .../rocprofiler-compute/src/utils/utils.py | 77 ++++++- .../rocprofiler-compute/tests/test_utils.py | 69 ++++++ 5 files changed, 334 insertions(+), 101 deletions(-) diff --git a/projects/rocprofiler-compute/CHANGELOG.md b/projects/rocprofiler-compute/CHANGELOG.md index c36b4b297a..e8f3170946 100644 --- a/projects/rocprofiler-compute/CHANGELOG.md +++ b/projects/rocprofiler-compute/CHANGELOG.md @@ -25,6 +25,8 @@ Full documentation for ROCm Compute Profiler is available at [https://rocm.docs. ### Changed +* On the memory chart, long numbers are now displayed in scientific notation, which also fixes the overflow that occurred when displaying long numbers. + * Add notice for change in default output format to `rocpd` in a future release * This is displayed when `--format-rocprof-output rocpd` is not used in profile mode diff --git a/projects/rocprofiler-compute/src/utils/gui_components/memchart.py b/projects/rocprofiler-compute/src/utils/gui_components/memchart.py index b2bb4e614f..89ddff50c1 100644 --- a/projects/rocprofiler-compute/src/utils/gui_components/memchart.py +++ b/projects/rocprofiler-compute/src/utils/gui_components/memchart.py @@ -27,6 +27,7 @@ from dash import html from dash_svg import G, Path, Rect, Svg, Text from utils.logger import console_error +from utils.utils import format_scientific_notation_if_needed def insert_chart_data(mem_data, base_data): @@ -60,7 +61,9 @@ def insert_chart_data(mem_data, base_data): fill="#FFFF33", fontSize="20px", fontWeight="bold", - children=memchart_values["Wavefront Occupancy"], + children=format_value_for_display( + memchart_values.get("Wavefront Occupancy") + ), ), Text( 
x="49", @@ -69,7 +72,7 @@ def insert_chart_data(mem_data, base_data): fill="#FFFF33", fontSize="20px", fontWeight="bold", - children=memchart_values["Wave Life"], + children=format_value_for_display(memchart_values.get("Wave Life")), ), # ---------------------------------------- # Instr Dispatch Block @@ -79,7 +82,7 @@ def insert_chart_data(mem_data, base_data): id="salu", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["SALU"]), + children=format_value_for_display(memchart_values.get("SALU")), ), Text( x="386", @@ -87,7 +90,7 @@ def insert_chart_data(mem_data, base_data): id="smem", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["SMEM"]), + children=format_value_for_display(memchart_values.get("SMEM")), ), Text( x="386", @@ -95,7 +98,7 @@ def insert_chart_data(mem_data, base_data): id="valu", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["VALU"]), + children=format_value_for_display(memchart_values.get("VALU")), ), Text( x="386", @@ -103,7 +106,7 @@ def insert_chart_data(mem_data, base_data): id="mfma", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["MFMA"]), + children=format_value_for_display(memchart_values.get("MFMA")), ), Text( x="386", @@ -111,7 +114,7 @@ def insert_chart_data(mem_data, base_data): id="vmem", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["VMEM"]), + children=format_value_for_display(memchart_values.get("VMEM")), ), Text( x="386", @@ -119,7 +122,7 @@ def insert_chart_data(mem_data, base_data): id="lds", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["LDS"], + children=format_value_for_display(memchart_values.get("LDS")), ), Text( x="386", @@ -127,7 +130,7 @@ def insert_chart_data(mem_data, base_data): id="gws", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["GWS"], + 
children=format_value_for_display(memchart_values.get("GWS")), ), Text( x="386", @@ -135,7 +138,7 @@ def insert_chart_data(mem_data, base_data): id="br", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["BR"], + children=format_value_for_display(memchart_values.get("BR")), ), # ---------------------------------------- # Exec Block @@ -146,7 +149,7 @@ def insert_chart_data(mem_data, base_data): fill="#FFFF33", fontSize="20px", fontWeight="bold", - children=memchart_values["Active CUs"], + children=format_value_for_display(memchart_values.get("Active CUs")), ), # x=454 Text( x="580", @@ -154,7 +157,7 @@ def insert_chart_data(mem_data, base_data): id="vgpr", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["VGPR"], + children=format_value_for_display(memchart_values.get("VGPR")), ), Text( x="581", @@ -162,7 +165,7 @@ def insert_chart_data(mem_data, base_data): id="sgpr", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["SGPR"], + children=format_value_for_display(memchart_values.get("SGPR")), ), Text( x="580", @@ -170,7 +173,9 @@ def insert_chart_data(mem_data, base_data): id="lds_alloc", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["LDS Allocation"], + children=format_value_for_display( + memchart_values.get("LDS Allocation") + ), ), Text( x="580", @@ -178,7 +183,9 @@ def insert_chart_data(mem_data, base_data): id="scratch_alloc", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["Scratch Allocation"], + children=format_value_for_display( + memchart_values.get("Scratch Allocation") + ), ), Text( x="580", @@ -186,7 +193,7 @@ def insert_chart_data(mem_data, base_data): id="wavefronts", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["Wavefronts"], + children=format_value_for_display(memchart_values.get("Wavefronts")), ), Text( x="580", @@ -194,7 +201,7 @@ def insert_chart_data(mem_data, base_data): id="workgroups", fill="rgb(0, 0, 0)", fontSize="12px", - 
children=memchart_values["Workgroups"], + children=format_value_for_display(memchart_values.get("Workgroups")), ), # ---------------------------------------- # LDS Block @@ -204,7 +211,7 @@ def insert_chart_data(mem_data, base_data): id="lds_req", fill="#FFFFFF", fontSize="12px", - children=memchart_values["LDS Req"], + children=format_value_for_display(memchart_values.get("LDS Req")), ), Text( x="839", @@ -212,7 +219,7 @@ def insert_chart_data(mem_data, base_data): id="lds_util", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["LDS Util"], + children=format_value_for_display(memchart_values.get("LDS Util")), ), Text( x="839", @@ -220,7 +227,7 @@ def insert_chart_data(mem_data, base_data): id="lds_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["LDS Latency"], + children=format_value_for_display(memchart_values.get("LDS Latency")), ), # ---------------------------------------- # Vector L1 Cache Block @@ -230,7 +237,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_rd", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["VL1 Rd"]), + children=format_value_for_display(memchart_values.get("VL1 Rd")), ), Text( x="708", @@ -238,7 +245,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_wr", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["VL1 Wr"]), + children=format_value_for_display(memchart_values.get("VL1 Wr")), ), Text( x="716", @@ -246,7 +253,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_atom", fill="#FFFFFF", fontSize="12px", - children=memchart_values["VL1 Atomic"], + children=format_value_for_display(memchart_values.get("VL1 Atomic")), ), Text( x="840", @@ -254,7 +261,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_hit", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["VL1 Hit"], + children=format_value_for_display(memchart_values.get("VL1 Hit")), ), Text( x="840", @@ -262,7 +269,7 @@ def 
insert_chart_data(mem_data, base_data): id="vl1_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["VL1 Lat"], + children=format_value_for_display(memchart_values.get("VL1 Lat")), ), Text( x="840", @@ -270,7 +277,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_coales", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["VL1 Coalesce"], + children=format_value_for_display(memchart_values.get("VL1 Coalesce")), ), Text( x="838", @@ -278,7 +285,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_stall", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["VL1 Stall"], + children=format_value_for_display(memchart_values.get("VL1 Stall")), ), Text( x="1000", @@ -286,7 +293,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_l2_rd", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["VL1_L2 Rd"]), + children=format_value_for_display(memchart_values.get("VL1_L2 Rd")), ), Text( x="1000", @@ -294,7 +301,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_l2_wr", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["VL1_L2 Wr"]), + children=format_value_for_display(memchart_values.get("VL1_L2 Wr")), ), Text( x="1008", @@ -302,7 +309,7 @@ def insert_chart_data(mem_data, base_data): id="vl1_l2_atom", fill="#FFFFFF", fontSize="12px", - children=memchart_values["VL1_L2 Atomic"], + children=format_value_for_display(memchart_values.get("VL1_L2 Atomic")), ), # ---------------------------------------- # Scalar L1D Cache Block @@ -312,7 +319,7 @@ def insert_chart_data(mem_data, base_data): id="sl1_rd", fill="#FFFFFF", fontSize="12px", - children=memchart_values["sL1D Rd"], + children=format_value_for_display(memchart_values.get("sL1D Rd")), ), Text( x="838", @@ -320,7 +327,7 @@ def insert_chart_data(mem_data, base_data): id="sl1_hit", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["sL1D Hit"], + 
children=format_value_for_display(memchart_values.get("sL1D Hit")), ), Text( x="838", @@ -328,7 +335,7 @@ def insert_chart_data(mem_data, base_data): id="sl1_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["sL1D Lat"], + children=format_value_for_display(memchart_values.get("sL1D Lat")), ), Text( x="1000", @@ -336,7 +343,7 @@ def insert_chart_data(mem_data, base_data): id="sl1_l2_rd", fill="#FFFFFF", fontSize="12px", - children=memchart_values["sL1D_L2 Rd"], + children=format_value_for_display(memchart_values.get("sL1D_L2 Rd")), ), Text( x="1000", @@ -344,7 +351,7 @@ def insert_chart_data(mem_data, base_data): id="sl1_l2_wr", fill="#FFFFFF", fontSize="12px", - children=memchart_values["sL1D_L2 Wr"], + children=format_value_for_display(memchart_values.get("sL1D_L2 Wr")), ), Text( x="1008", @@ -352,7 +359,9 @@ def insert_chart_data(mem_data, base_data): id="sl1_l2_atom", fill="#FFFFFF", fontSize="12px", - children=memchart_values["sL1D_L2 Atomic"], + children=format_value_for_display( + memchart_values.get("sL1D_L2 Atomic") + ), ), # ---------------------------------------- # Instr L1 Cache Block @@ -362,7 +371,7 @@ def insert_chart_data(mem_data, base_data): id="il1_fetch", fill="#FFFFFF", fontSize="12px", - children=memchart_values["IL1 Fetch"], + children=format_value_for_display(memchart_values.get("IL1 Fetch")), ), Text( x="837", @@ -370,7 +379,7 @@ def insert_chart_data(mem_data, base_data): id="il1_hit", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["IL1 Hit"], + children=format_value_for_display(memchart_values.get("IL1 Hit")), ), Text( x="837", @@ -378,7 +387,7 @@ def insert_chart_data(mem_data, base_data): id="il1_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["IL1 Lat"], + children=format_value_for_display(memchart_values.get("IL1 Lat")), ), Text( x="1015", @@ -386,7 +395,7 @@ def insert_chart_data(mem_data, base_data): id="il1_l2_req", fill="#FFFFFF", fontSize="12px", - 
children=format_value_for_display(memchart_values["IL1_L2 Rd"]), + children=format_value_for_display(memchart_values.get("IL1_L2 Rd")), ), # ---------------------------------------- # L2 Cache Block(inside) @@ -396,7 +405,7 @@ def insert_chart_data(mem_data, base_data): id="l2_rd", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["L2 Rd"]), + children=format_value_for_display(memchart_values.get("L2 Rd")), ), Text( x="1145", @@ -404,7 +413,7 @@ def insert_chart_data(mem_data, base_data): id="l2_wr", fill="rgb(0, 0, 0)", fontSize="12px", - children=format_value_for_display(memchart_values["L2 Wr"]), + children=format_value_for_display(memchart_values.get("L2 Wr")), ), Text( x="1145", @@ -412,7 +421,7 @@ def insert_chart_data(mem_data, base_data): id="l2_atom", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["L2 Atomic"], + children=format_value_for_display(memchart_values.get("L2 Atomic")), ), Text( x="1145", @@ -420,7 +429,7 @@ def insert_chart_data(mem_data, base_data): id="l2_hit", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["L2 Hit"], + children=format_value_for_display(memchart_values.get("L2 Hit")), ), Text( x="1145", @@ -428,7 +437,7 @@ def insert_chart_data(mem_data, base_data): id="l2_rd_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["L2 Rd Lat"], + children=format_value_for_display(memchart_values.get("L2 Rd Lat")), ), Text( x="1145", @@ -436,7 +445,7 @@ def insert_chart_data(mem_data, base_data): id="l2_wr_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["L2 Wr Lat"], + children=format_value_for_display(memchart_values.get("L2 Wr Lat")), ), # ---------------------------------------- # Fabric Block @@ -446,7 +455,7 @@ def insert_chart_data(mem_data, base_data): id="l2_fabric_rd", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["Fabric_L2 Rd"]), + 
children=format_value_for_display(memchart_values.get("Fabric_L2 Rd")), ), Text( x="1317", @@ -454,7 +463,7 @@ def insert_chart_data(mem_data, base_data): id="l2_fabric_wr", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["Fabric_L2 Wr"]), + children=format_value_for_display(memchart_values.get("Fabric_L2 Wr")), ), Text( x="1319", @@ -462,7 +471,9 @@ def insert_chart_data(mem_data, base_data): id="l2_fabric_atom", fill="#FFFFFF", fontSize="12px", - children=memchart_values["Fabric_L2 Atomic"], + children=format_value_for_display( + memchart_values.get("Fabric_L2 Atomic") + ), ), Text( x="1435", @@ -470,7 +481,7 @@ def insert_chart_data(mem_data, base_data): id="fabric_rd_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["Fabric Rd Lat"], + children=format_value_for_display(memchart_values.get("Fabric Rd Lat")), ), Text( x="1435", @@ -478,7 +489,7 @@ def insert_chart_data(mem_data, base_data): id="fabric_wr_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["Fabric Wr Lat"], + children=format_value_for_display(memchart_values.get("Fabric Wr Lat")), ), Text( x="1435", @@ -486,7 +497,9 @@ def insert_chart_data(mem_data, base_data): id="fabric_atom_lat", fill="rgb(0, 0, 0)", fontSize="12px", - children=memchart_values["Fabric Atomic Lat"], + children=format_value_for_display( + memchart_values.get("Fabric Atomic Lat") + ), ), Text( x="1578", @@ -494,7 +507,7 @@ def insert_chart_data(mem_data, base_data): id="hbm_rd", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["HBM Rd"]), + children=format_value_for_display(memchart_values.get("HBM Rd")), ), Text( x="1577", @@ -502,7 +515,7 @@ def insert_chart_data(mem_data, base_data): id="hbm_wr", fill="#FFFFFF", fontSize="12px", - children=format_value_for_display(memchart_values["HBM Wr"]), + children=format_value_for_display(memchart_values.get("HBM Wr")), ), ], ) @@ -2028,42 +2041,91 @@ def get_memchart(mem_data, 
base_data): def format_value_for_display(value, max_length=6): """ - Format values to prevent overflow in SVG text elements. - """ - ##### - # TODO: this is quick fix to prevent value overflow. - # The long term solution should be dynamically adjust - # SVG dimensions and positions to maintain visual - # integrity while preventing overflow. - ##### + Format a value (int, float, or str) into a concise string suitable for display. + + The function attempts to convert strings to numeric types if possible. + It then decides between normal decimal notation and scientific notation + based on length constraints and value magnitude. + + If the formatted string is too long, it truncates it gracefully, + preserving scientific notation exponent parts where applicable. + + Parameters: + - value: The input value to format. Can be int, float, or string. + Strings representing numbers are converted to numeric types if possible. + - max_length: Maximum allowed length of the output string. + Longer strings are truncated with an ellipsis ('…'). + + Returns: + - A string representation of the input value, formatted either + in fixed-point or scientific notation, and truncated if too long. + Returns "N/A" if the value is invalid (e.g., None or NaN). + """ + + if value is None: + return "N/A" - # 1. If non-numerical if isinstance(value, str): try: if "." in value: + # when dot is in the string, we know it's a float number and convert with "float" value = float(value) else: + # without dot, we assume it's an integer and convert with "int" value = int(value) except ValueError: - pass # Keep as string - # 2. 
If numerical + # when conversion fails, the string is neither legit float or int, then assume it's invalid and display "N/A" + return "N/A" + if isinstance(value, (int, float)): - value = abs(value) - if value >= 1000000000: - value = f"{value / 1000000000:.1f}B" - elif value >= 1000000: - value = f"{value / 1000000:.1f}M" - elif value >= 1000: - value = f"{value / 1000:.1f}K" - elif value == int(value): - value = str(int(value)) + is_negative = value < 0 + abs_val = abs(value) + + if isinstance(abs_val, float): + if value != value: + return "N/A" + + if abs_val.is_integer(): + normal = str(int(abs_val)) + else: + normal = f"{abs_val:.1f}" else: - value = f"{value:.1f}" + normal = str(abs_val) + + sci = format_scientific_notation_if_needed( + abs_val, + align=">", + width_align=8, + precision=1, + fmt_type_align="e", + max_length=max_length, + ).strip() + + # Choose shorter notation or if normal too long + if len(sci) < len(normal) or len(normal) > max_length: + value = sci + else: + value = normal + + if is_negative: + value = "-" + value + else: value = str(value) - # 3. Truncate if needed + # Custom truncation logic: if len(value) > max_length: - value = value[: max_length - 1] + "…" + if "e" in value.lower(): + e_index = value.lower().index("e") + mantissa = value[:e_index] + exponent = value[e_index:] + max_mantissa_len = max_length - len(exponent) + if max_mantissa_len < 1: + value = exponent[: max_length - 1] + "…" + else: + truncated_mantissa = mantissa[:max_mantissa_len] + value = truncated_mantissa + exponent + else: + value = value[: max_length - 1] + "…" return value diff --git a/projects/rocprofiler-compute/src/utils/mem_chart.py b/projects/rocprofiler-compute/src/utils/mem_chart.py index 211f1ad220..1060c0e40f 100644 --- a/projects/rocprofiler-compute/src/utils/mem_chart.py +++ b/projects/rocprofiler-compute/src/utils/mem_chart.py @@ -22,12 +22,15 @@ # SOFTWARE. 
###############################################################################el +import re from dataclasses import dataclass, field from decimal import Decimal from typing import Dict from plotille import Canvas +from .utils import format_scientific_notation_if_needed + def make_format_spec(num, align=">"): """ @@ -83,25 +86,30 @@ def format_text( ): """ Format a text string for canvas to display according to - input key value pair and make proper aligment - For invalid value, it displays N/A - All strings to be displayed on Canvas need to use this method + input key-value pair and make proper alignment. + Uses scientific notation formatting when needed. + For invalid value, it displays N/A. """ + + # Step 1: Build format spec using make_format_spec value_format = make_format_spec(value_step_prec_rightalign, value_align) - if is_value_valid(value): - value_str = "{val:{format}}".format(val=value, format=value_format) + # Step 2: Extract width and precision as integer + match = re.match(r"([<>=^])(\d+)(?:\.(\d+))?([a-zA-Z])?", value_format) + if match: + align_char = match.group(1) + width_align = int(match.group(2)) + precision_digits = match.group(3) + fmt_type_align = match.group(4) or "f" + precision = int(precision_digits) if precision_digits else 0 else: - import re - - match = re.search(r"[<>=^](\d+)", value_format) - width = int(match.group(1)) if match else 6 - - # Use same alignment as in value_format (first char) - align = value_format[0] - - value_str = f"{'N/A':{align}{width}}" + # Fallback to default values + align_char = value_align + width_align = 6 + precision = 2 + fmt_type_align = "f" + # Step 3: Format the key using make_format_spec key_format = ( make_format_spec(key_step_prec_leftalign, key_align) if key is not None @@ -109,19 +117,36 @@ def format_text( ) key_str = ( "{key:{key_format}}".format(key=key, key_format=key_format) - if key and isinstance(key, (int, float)) + if key is not None and isinstance(key, (int, float)) else str(key) - if 
key + if key is not None else None ) - unit_string = post_description_with_space if not "N/A" in value_str else "" + # Step 4: Format the value or fallback to N/A + if is_value_valid(value): + formatted_value = format_scientific_notation_if_needed( + value, + align=align_char, + width_align=width_align, + precision=precision, + fmt_type_align=fmt_type_align, + max_length=width_align, + sci_lower_bound=1e-3, + sci_upper_bound=1e3, + ) + value_str = formatted_value + else: + value_str = f"{'N/A':{align_char}{width_align}}" + + # Step 5: Unit and Final Output + unit_string = post_description_with_space if "N/A" not in value_str else "" + + if key_str is not None: + result_str_no_unit = f"{key_str}{mark_between}{value_str}" + else: + result_str_no_unit = value_str - result_str_no_unit = ( - "{key}{mark}{value}".format(key=key_str, value=value_str, mark=mark_between) - if key is not None - else "{value}".format(value=value_str) - ) result_str = result_str_no_unit + unit_string return result_str @@ -603,7 +628,7 @@ class ScalarL1DCache(RectFrame): key="Hit", value=self.hit, key_step_prec_leftalign=6, - value_step_prec_rightalign=6, + value_step_prec_rightalign=6.0, post_description_with_space=" %", ), ) diff --git a/projects/rocprofiler-compute/src/utils/utils.py b/projects/rocprofiler-compute/src/utils/utils.py index 176d89b079..a17b3b9bb3 100644 --- a/projects/rocprofiler-compute/src/utils/utils.py +++ b/projects/rocprofiler-compute/src/utils/utils.py @@ -38,7 +38,7 @@ import tempfile import time import uuid from pathlib import Path as path -from typing import Optional +from typing import Optional, Union import pandas as pd import yaml @@ -1645,3 +1645,78 @@ def parse_sets_yaml(arch): def get_uuid(length=8): return uuid.uuid4().hex[:length] + + +def format_scientific_notation_if_needed( + value: Union[int, float], + align: str = ">", + width_align: int = 6, + precision: int = 2, + fmt_type_align: str = "f", + max_length: int = 6, + sci_lower_bound: float = 1e-2, + 
sci_upper_bound: float = 1e6, +) -> str: + """ + Format a numeric value as normal or scientific notation string. + + Uses scientific notation if: + - abs(value) < sci_lower_bound (but not zero) + - abs(value) >= sci_upper_bound + - formatted normal string length exceeds max_length + + Parameters: + - value: numeric value to format + - align: alignment character ('<', '>', '^', '=') + - width_align: total width of formatted output + - precision: number of digits after decimal point + - fmt_type_align: format type, e.g., 'f', 'e', 'g' + - max_length: max allowed length for normal format string (excluding padding) + - sci_lower_bound: lower bound for scientific notation usage + - sci_upper_bound: upper bound for scientific notation usage + + Returns: + - formatted string according to the criteria, respecting alignment + """ + + abs_val = abs(value) + use_sci = False + + # Build format specifiers + normal_format_spec = f"{align}{width_align}.{precision}{fmt_type_align}" + sci_format_spec = f"{align}{width_align}.{precision}e" + + normal_str = None # will hold formatted normal string (with padding) + sci_str = None # will hold formatted scientific string (with padding) + + if abs_val != 0: + if abs_val < sci_lower_bound or abs_val >= sci_upper_bound: + use_sci = True + else: + try: + normal_str = format(value, normal_format_spec) + normal_str_strip = normal_str.strip() + + sci_str = format(value, sci_format_spec) + sci_str_strip = sci_str.strip() + + # Decide based on length of stripped strings (ignore padding) + if ( + len(normal_str_strip) > len(sci_str_strip) + or len(normal_str_strip) > max_length + ): + use_sci = True + except Exception: + # Fallback to scientific if formatting fails + use_sci = True + + if use_sci: + if sci_str is None: + sci_str = format(value, sci_format_spec) + formatted = sci_str + else: + if normal_str is None: + normal_str = format(value, normal_format_spec) + formatted = normal_str + + return formatted diff --git 
a/projects/rocprofiler-compute/tests/test_utils.py b/projects/rocprofiler-compute/tests/test_utils.py index c53b17b062..5fbcc2065e 100644 --- a/projects/rocprofiler-compute/tests/test_utils.py +++ b/projects/rocprofiler-compute/tests/test_utils.py @@ -9290,3 +9290,72 @@ def test_set_parser(): assert "compute_thruput_util" in result assert result["compute_thruput_util"]["title"] == "Compute Throughput Utilization" + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_below_lower_bound(): + value = 0.0001 + result = utils.format_scientific_notation_if_needed(value) + assert pytest.approx(float(result.strip()), rel=1e-9) == value + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_at_lower_bound(): + value = 0.01 + result = utils.format_scientific_notation_if_needed(value) + assert pytest.approx(float(result.strip()), rel=1e-9) == value + + +def test_scientific_notation_trigger_above_upper_bound(): + value = 1234567890 + result = utils.format_scientific_notation_if_needed(value) + assert pytest.approx(float(result.strip()), rel=1e-9) == value + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_just_below_upper_bound(): + value = 999999 + result = utils.format_scientific_notation_if_needed(value, precision=6) + assert pytest.approx(float(result.strip()), rel=1e-6) == value + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_zero(): + value = 0 + result = utils.format_scientific_notation_if_needed(value) + assert float(result.strip()) == value # Exact match for zero + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_slightly_below_lower_bound(): + value = 0.009 + result = utils.format_scientific_notation_if_needed(value) + assert pytest.approx(float(result.strip()), rel=1e-9) == value + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_well_below_lower_bound(): + value = 1e-5 + result = utils.format_scientific_notation_if_needed(value) + assert pytest.approx(float(result.strip()), 
rel=1e-9) == value + + +@pytest.mark.sci_notion +def test_scientific_notation_trigger_well_above_upper_bound(): + value = 1e10 + result = utils.format_scientific_notation_if_needed(value) + assert pytest.approx(float(result.strip()), rel=1e-9) == value + + +@pytest.mark.sci_notion +def test_alignment_and_width(): + value = 1e10 + result = utils.format_scientific_notation_if_needed( + value, + align=">", + width_align=12, + precision=2, + fmt_type_align="f", + max_length=8, + ) + assert pytest.approx(float(result.strip()), rel=1e-9) == value